高级图操作
内容
高级图操作¶
在某些情况下,使用 Dask collections 进行计算可能会导致内存使用不理想(例如,整个 Dask DataFrame 被加载到内存中)。这可能发生在 Dask 的调度器没有自动延迟任务图中节点的计算,以避免其输出长时间占用内存,或者在重新计算节点比将其输出保存在内存中便宜得多的场景中。
本页重点介绍一组图操作工具,可用于帮助避免这些情况。特别是,下面描述的工具重写了 Dask collections 底层的 Dask 图,产生具有不同键集的等效 collections。
考虑以下示例
>>> import dask.array as da
>>> x = da.random.default_rng().normal(size=500_000_000, chunks=100_000)
>>> x_mean = x.mean()
>>> y = (x - x_mean).max().compute()
上述示例计算了一个分布在移除其偏差后的最大值。这涉及将 x
的分块加载到内存中,以计算 x_mean
。然而,由于计算 y
时稍后需要 x
数组,因此整个 x
数组都被保留在内存中。对于大型 Dask 数组来说,这可能会带来很大问题。
为了减轻将整个 x
数组保留在内存中的需要,可以将最后一行改写如下
>>> from dask.graph_manipulation import bind
>>> xb = bind(x, x_mean)
>>> y = (xb - x_mean).max().compute()
在这里,我们使用 bind()
创建了一个新的 Dask 数组 xb
,它产生的输出与 x
完全相同,但其底层的 Dask 图具有与 x
不同的键,并且只会在 x_mean
计算完成后才计算。
这导致 x
的分块被计算并立即由 mean
单独规约;然后重新计算并再次立即通过流水线进入减法,接着由 max
规约。这导致峰值内存使用量大大降低,因为不再将完整的 x
数组加载到内存中。然而,代价是计算时间增加了,因为 x
被计算了两次。
API¶
|
构建一个 Dask Delayed,它会等待输入 collection 的所有分块计算完成后才返回 None。 |
|
确保所有输入 collection 的所有分块都已计算完毕,然后才计算任何分块的依赖项。 |
|
使 |
|
克隆 dask collections,返回由独立计算生成的等效 collections。 |
定义¶
- dask.graph_manipulation.checkpoint(*collections, split_every: Optional[Union[float, Literal[False]]] = None) dask.delayed.Delayed [源代码]¶
构建一个 Dask Delayed,它会等待输入 collection 的所有分块计算完成后才返回 None。
- 参数
- collections
零个或多个 Dask collections 或包含零个或多个 collections 的嵌套数据结构
- split_every: int >= 2 or False, 可选
确定递归聚合的深度。如果大于输入键的数量,聚合将分多个步骤执行;聚合图的深度将为 \(log_{split_every}(input keys)\)。设置为较低的值可以减少缓存大小和网络传输,代价是消耗更多的 CPU 和更大的 dask 图。
设置为 False 可禁用。默认为 8。
- 返回
- Dask Delayed 产生 None
- dask.graph_manipulation.wait_on(*collections, split_every: Optional[Union[float, Literal[False]]] = None)[源代码]¶
确保所有输入 collection 的所有分块都已计算完毕,然后才计算任何分块的依赖项。
以下示例创建了一个 dask 数组
u
,当它用于计算时,只有在数组x
的所有分块计算完成后才会继续,否则与x
匹配>>> import dask.array as da >>> x = da.ones(10, chunks=5) >>> u = wait_on(x)
以下示例将创建两个数组
u
和v
,当它们用于计算时,只有在数组x
和y
的所有分块计算完成后才会继续,否则与x
和y
匹配>>> x = da.ones(10, chunks=5) >>> y = da.zeros(10, chunks=5) >>> u, v = wait_on(x, y)
- 参数
- collections
零个或多个 Dask collections 或 Dask collections 的嵌套结构
- split_every
参见
checkpoint()
- 返回
- 与
collections
相同 与输入类型相同的 Dask collection,其计算结果与输入相同,或与输入等效的嵌套结构,其中原始 collections 已被替换。新 collections 中重新生成节点的键将与原始节点不同,以便它们可以在同一个图中使用。
- 与
- dask.graph_manipulation.bind(children: dask.graph_manipulation.T, parents, *, omit=None, seed: collections.abc.Hashable | None = None, assume_layers: bool =True, split_every: Optional[Union[float, Literal[False]]] = None) dask.graph_manipulation.T [源代码]¶
使
children
collection(s),可选地省略子 collection,依赖于parents
collection(s)。以下是两个示例。第一个示例创建了一个数组
b2
,其计算首先完全计算数组a
,然后完全计算b
,在此过程中重新计算a
>>> import dask >>> import dask.array as da >>> a = da.ones(4, chunks=2) >>> b = a + 1 >>> b2 = bind(b, a) >>> len(b2.dask) 9 >>> b2.compute() array([2., 2., 2., 2.])
第二个示例创建了数组
b3
和c3
,其计算首先计算数组a
,然后计算加法,此时不再重新计算a
>>> c = a + 2 >>> b3, c3 = bind((b, c), a, omit=a) >>> len(b3.dask), len(c3.dask) (7, 7) >>> dask.compute(b3, c3) (array([2., 2., 2., 2.]), array([3., 3., 3., 3.]))
- 参数
- children
Dask collection 或 Dask collections 的嵌套结构
- parents
Dask collection 或 Dask collections 的嵌套结构
- omit
Dask collection 或 Dask collections 的嵌套结构
- seed
用于种子密钥重新生成的 Hashable。省略则默认为一个随机数,每次调用都会产生不同的键。
- assume_layers
- True
使用在层级别工作的快速算法,该算法假定
children
和omit
中的所有 collections使用
HighLevelGraph
,并且定义了
__dask_layers__()
方法,并且在
omit
collections 和children
collections 创建之间从未压平并重建它们的图;换句话说,如果在children
collections 的键中可以找到omit
collections 的键,那么对于层也必须如此。
- False
使用在键级别工作的较慢算法,该算法不做上述任何假设。
- split_every
参见
checkpoint()
- 返回
- 与
children
相同 与
children
等效的 Dask collection 或结构,其计算结果相同。除了omit
中的节点外,children
的所有节点都将被重新生成。紧邻omit
上方的节点,或者如果找不到omit
中的 collections,则叶节点,在parents
中的所有 collections 完全计算完毕之前,将阻止计算。重新生成节点的键将与原始节点不同,以便它们可以在同一个图中使用。
- 与
- dask.graph_manipulation.clone(*collections, omit=None, seed: collections.abc.Hashable = None, assume_layers: bool =True)[源代码]¶
克隆 dask collections,返回由独立计算生成的等效 collections。
- 参数
- 返回
- 与
collections
相同 与输入类型相同的 Dask collections,其计算结果与输入相同,或与输入等效的嵌套结构,其中原始 collections 已被替换。新 collections 中重新生成节点的键将与原始节点不同,以便它们可以在同一个图中使用。
- 与
示例
(为简洁起见,令牌已简化)
>>> import dask.array as da >>> x_i = da.asarray([1, 1, 1, 1], chunks=2) >>> y_i = x_i + 1 >>> z_i = y_i + 2 >>> dict(z_i.dask) {('array-1', 0): array([1, 1]), ('array-1', 1): array([1, 1]), ('add-2', 0): (<function operator.add>, ('array-1', 0), 1), ('add-2', 1): (<function operator.add>, ('array-1', 1), 1), ('add-3', 0): (<function operator.add>, ('add-2', 0), 1), ('add-3', 1): (<function operator.add>, ('add-2', 1), 1)} >>> w_i = clone(z_i, omit=x_i) >>> w_i.compute() array([4, 4, 4, 4]) >>> dict(w_i.dask) {('array-1', 0): array([1, 1]), ('array-1', 1): array([1, 1]), ('add-4', 0): (<function operator.add>, ('array-1', 0), 1), ('add-4', 1): (<function operator.add>, ('array-1', 1), 1), ('add-5', 0): (<function operator.add>, ('add-4', 0), 1), ('add-5', 1): (<function operator.add>, ('add-4', 1), 1)}
clone() 的典型用法模式如下
>>> x = cheap_computation_with_large_output() >>> y = expensive_and_long_computation(x) >>> z = wrap_up(clone(x), y)
在上面的代码中,x 的分块一旦被 y 的分块消耗掉就会被遗忘,然后在计算结束时完全重新生成。如果没有 clone(),x 将只计算一次,然后在整个 y 的计算过程中保留在内存中,不必要地消耗内存。