当前位置: 首页 > news >正文

Dask心得与笔记【2】

文章目录

  • 计算
  • 参考文献

计算

数组切片如下

import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[:,0:3])

切片结果是前3列

dask.array<getitem, shape=(10, 3), dtype=int64, chunksize=(5, 3), chunktype=numpy.ndarray>

Dask是懒惰计算,就是说,当你要求结果时,它才会计算。
调用这方法设置任务图,然后调用compute方法得到结果。

import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[:,0:3].compute())
print(sum(a[0,0:5].compute()))#0+1+2+3+4=10
[[  0   1   2][100 101 102][200 201 202][300 301 302][400 401 402][500 501 502][600 601 602][700 701 702][800 801 802][900 901 902]]
10

按列求和

import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[1:3,0:3].compute())
print(sum(a[1:3,0:3].compute()))
[[100 101 102][200 201 202]]
[300 302 304]

调用numpy的函数

import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[1:3,0:3].compute())
print(a[1:3,0:3].mean().compute())
print(a[1:3,0:3].sum().compute())
print(np.cos(a[1:3,0:3]).compute())
print(a[1:3,0:3].T.compute())
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/l2.py
[[100 101 102][200 201 202]]
151.0
906
[[0.86231887 0.89200487 0.1015857 ][0.48718768 0.99808296 0.59134538]]
[[100 200][101 201][102 202]]

可以调用JAX的函数试下

import dask.array as da
import jax.numpy as jnp
from dask import delayed# 创建 Dask 数组
x = da.random.random((1000, 1000), chunks=(100, 100))# 定义一个使用 JAX 的函数
@delayed
def jax_computation(arr):jax_arr = jnp.array(arr)  # 转换为 JAX 数组return jnp.sum(jax_arr * 2).block_until_ready()  # 使用 JAX 计算# 应用计算
result = jax_computation(x.compute())  # 先计算 Dask 数组,再传给 JAX
from dask import compute
import jax# 在多个设备上并行运行 JAX 函数
@delayed
def jax_operation(data):device = jax.devices()[0]  # 可以使用不同设备with jax.default_device(device):return jnp.sum(data * 2)# 创建多个延迟任务
tasks = [jax_operation(jnp.ones(100)) for _ in range(10)]
results = compute(*tasks)  # 并行计算

另外,分布式 JAX 计算,可以考虑使用 JAX 的 pmap 进行多设备并行

import jax
import jax.numpy as jnp
from jax import pmap# 检查可用设备
print(jax.devices())  # 例如: [GpuDevice(id=0), GpuDevice(id=1)]# 定义一个简单的函数
def f(x):return x * 2 + 1# 创建并行化版本
parallel_f = pmap(f)# 准备输入数据 (注意: 第一维对应设备数量)
x = jnp.array([[1., 2.], [3., 4.]])  # 形状 (2, 2)# 并行执行
result = parallel_f(x)  # 在2个设备上并行计算
print(result)

TensorFlow Probability (TFP) 可以与 TensorFlow 的分布式策略结合使用,实现大规模的统计计算和概率建模。
基本概率分布计算

def run_in_distributed_environment():# 在策略范围内创建变量和计算with strategy.scope():# 创建TFP正态分布normal = tfp.distributions.Normal(loc=0., scale=1.)# 分布式计算samples = normal.sample(1000)mean = tf.reduce_mean(samples)stddev = tf.math.reduce_std(samples)return mean, stddevmean, stddev = run_in_distributed_environment()
print(f"均值: {mean.numpy()}, 标准差: {stddev.numpy()}")

参考文献

  1. https://docs.dask.org/en/stable
  2. deepseek

相关文章:

  • 《卷积神经网络到Vision Transformer:计算机视觉的十年架构革命》
  • LeetCode--38.外观数列
  • docker部署后端服务的脚本
  • 华为交换机SSH登录报错--Key exchange failed.
  • Java-Scanner类
  • 深入解析Java 内部类
  • 单电机FOC与多电机协同交叉耦合控制Simulink仿真方案
  • 深入浅出:AWS Cognito 认证机制详解
  • cf 禁止http/1.0和http/1.1的访问 是否会更安全?
  • easywechat 6.X AccessToken刷新问题
  • Linux【9】-----Linux系统编程(线程池和并发socket编程 c语言)
  • vue3中的Treeshaking特性是什么,并举例说明
  • TCP 在高速网络下的大数据量传输优化:拥塞控制、效率保障与协议演进​
  • 咨询进阶——125页麦肯锡业务流程规划方法论及流程规划案例【附全文阅读】
  • progress telerik fiddler解决微软账户登录80190001错误问题
  • docker解析
  • 函数指针与指针函数
  • 操作系统的概述之三
  • 【LeetCode 热题 100】438. 找到字符串中所有字母异位词——(解法三)不定长滑动窗口+数组
  • 【Linux】理解进程状态与优先级:操作系统中的调度原理