Dask心得与笔记【2】
文章目录
- 计算
- 参考文献
计算
数组切片如下
import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[:,0:3])
切片结果是前3列
dask.array<getitem, shape=(10, 3), dtype=int64, chunksize=(5, 3), chunktype=numpy.ndarray>
Dask是懒惰计算,就是说,当你要求结果时,它才会计算。
调用这方法设置任务图,然后调用compute方法得到结果。
import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[:,0:3].compute())
print(sum(a[0,0:5].compute()))#0+1+2+3+4=10
[[ 0 1 2][100 101 102][200 201 202][300 301 302][400 401 402][500 501 502][600 601 602][700 701 702][800 801 802][900 901 902]]
10
按列求和
import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[1:3,0:3].compute())
print(sum(a[1:3,0:3].compute()))
[[100 101 102][200 201 202]]
[300 302 304]
调用numpy的函数
import numpy as np
import dask.array as dadata = np.arange(1000).reshape(10, 100)
a = da.from_array(data, chunks=(5, 20))
print(a[1:3,0:3].compute())
print(a[1:3,0:3].mean().compute())
print(a[1:3,0:3].sum().compute())
print(np.cos(a[1:3,0:3]).compute())
print(a[1:3,0:3].T.compute())
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/l2.py
[[100 101 102][200 201 202]]
151.0
906
[[0.86231887 0.89200487 0.1015857 ][0.48718768 0.99808296 0.59134538]]
[[100 200][101 201][102 202]]
可以调用JAX的函数试下
import dask.array as da
import jax.numpy as jnp
from dask import delayed# 创建 Dask 数组
x = da.random.random((1000, 1000), chunks=(100, 100))# 定义一个使用 JAX 的函数
@delayed
def jax_computation(arr):jax_arr = jnp.array(arr) # 转换为 JAX 数组return jnp.sum(jax_arr * 2).block_until_ready() # 使用 JAX 计算# 应用计算
result = jax_computation(x.compute()) # 先计算 Dask 数组,再传给 JAX
from dask import compute
import jax# 在多个设备上并行运行 JAX 函数
@delayed
def jax_operation(data):device = jax.devices()[0] # 可以使用不同设备with jax.default_device(device):return jnp.sum(data * 2)# 创建多个延迟任务
tasks = [jax_operation(jnp.ones(100)) for _ in range(10)]
results = compute(*tasks) # 并行计算
另外,分布式 JAX 计算,可以考虑使用 JAX 的 pmap 进行多设备并行
import jax
import jax.numpy as jnp
from jax import pmap# 检查可用设备
print(jax.devices()) # 例如: [GpuDevice(id=0), GpuDevice(id=1)]# 定义一个简单的函数
def f(x):return x * 2 + 1# 创建并行化版本
parallel_f = pmap(f)# 准备输入数据 (注意: 第一维对应设备数量)
x = jnp.array([[1., 2.], [3., 4.]]) # 形状 (2, 2)# 并行执行
result = parallel_f(x) # 在2个设备上并行计算
print(result)
TensorFlow Probability (TFP) 可以与 TensorFlow 的分布式策略结合使用,实现大规模的统计计算和概率建模。
基本概率分布计算
def run_in_distributed_environment():# 在策略范围内创建变量和计算with strategy.scope():# 创建TFP正态分布normal = tfp.distributions.Normal(loc=0., scale=1.)# 分布式计算samples = normal.sample(1000)mean = tf.reduce_mean(samples)stddev = tf.math.reduce_std(samples)return mean, stddevmean, stddev = run_in_distributed_environment()
print(f"均值: {mean.numpy()}, 标准差: {stddev.numpy()}")
参考文献
- https://docs.dask.org/en/stable
- deepseek