# FullyShardedDataParallel > 原文:[`pytorch.org/docs/stable/fsdp.html`](https://pytorch.org/docs/stable/fsdp.html) ```py class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None) ``` 用于在数据并行工作者之间分片模块参数的包装器。 这受到[Xu 等人](https://arxiv.org/abs/2004.13336)以及[DeepSpeed](https://www.deepspeed.ai/)的 ZeRO 阶段 3 的启发。FullyShardedDataParallel 通常缩写为 FSDP。 示例: ```py >>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step() ``` 警告 优化器必须在模块被 FSDP 包装之后初始化,因为 FSDP 将以一种可能不保留原始参数变量的方式对模块的参数进行分片和转换。因此,先前初始化的优化器可能会对参数有过时的引用。 警告 如果目标 CUDA 设备的 ID 为`dev_id`,则(1)`module`应该已经放置在该设备上,(2)可以使用`torch.cuda.set_device(dev_id)`设置设备,或者(3)应该将`dev_id`传递给`device_id`构造函数参数。此 FSDP 实例的计算设备将是该目标设备。对于(1)和(3),FSDP 初始化始终在 GPU 上进行。对于(2),FSDP 初始化发生在`module`当前的设备上,可能是 CPU。 警告 在使用 CPU 卸载时,FSDP 当前不支持在`no_sync()`之外支持梯度累积。尝试这样做会产生不正确的结果,因为 FSDP 将使用新减少的梯度而不是与任何现有梯度累积。 警告 在构造之后更改原始参数变量名称将导致未定义的行为。 警告 传递`sync_module_states=True`标志需要`module`在 GPU 上或使用`device_id`参数来指定 FSDP 将`module`移动到的 CUDA 设备。这是因为`sync_module_states=True`需要 GPU 通信。 警告 截至 PyTorch 1.12,FSDP 仅对共享参数提供有限支持(例如,将一个`Linear`层的权重设置为另一个的)。特别是,共享参数的模块必须作为同一 FSDP 单元的一部分进行包装。如果您的用例需要增强的共享参数支持,请访问[`github.com/pytorch/pytorch/issues/77724`](https://github.com/pytorch/pytorch/issues/77724) 警告 FSDP 对冻结参数(即设置`param.requires_grad=False`)有一些约束。对于`use_orig_params=False`,每个 FSDP 实例必须管理所有冻结或所有非冻结的参数。对于`use_orig_params=True`,FSDP 支持混合冻结和非冻结,但我们建议不要这样做,因为梯度内存使用量将高于预期(即等同于不冻结这些参数)。这意味着理想情况下,冻结参数应该被隔离到自己的`nn.Module`中,并分别用 FSDP 包装。 注意 尝试运行包含在 FSDP 实例中的子模块的前向传递不受支持,将导致错误。这是因为子模块的参数将被分片,但它本身不是 FSDP 实例,因此其前向传递将不会适当地聚集所有参数。当尝试仅运行编码器-解码器模型的编码器时,可能会发生这种情况,并且编码器未包装在自己的 FSDP 实例中。要解决此问题,请将子模块包装在自己的 FSDP 单元中。 注意 FSDP 将输入张量移动到 GPU 计算设备的`forward`方法中,因此用户不需要手动将它们从 CPU 移动。 警告 用户不应在前向和后向之间修改参数,而不使用`summon_full_params()`上下文,因为修改可能不会持久。此外,对于`use_orig_params=False`,在前向和后向之间访问原始参数可能会引发非法内存访问。 警告 对于`use_orig_params=True`,`ShardingStrategy.SHARD_GRAD_OP`在前向传播后暴露未分片的参数,而不是分片的参数,因为它不释放未分片的参数,不像`ShardingStrategy.FULL_SHARD`。一个注意事项是,由于梯度总是被分片或为`None`,`ShardingStrategy.SHARD_GRAD_OP`在前向传播后不会暴露带有未分片参数的分片梯度。如果要检查梯度,请尝试使用`with_grads=True`调用`summon_full_params()`。 警告 FSDP 在前向和后向计算期间用`torch.Tensor`视图替换托管模块的参数,出于自动求导相关原因。如果您的模块的前向依赖于对参数的保存引用,而不是在每次迭代中重新获取引用,则它将看不到 FSDP 新创建的视图,并且自动求导将无法正常工作。 注意 使用`limit_all_gathers=True`,您可能会看到 FSDP 在前向传播中存在一个 CPU 线程不发出任何内核的间隙。这是有意的,并显示了速率限制器的效果。以这种方式同步 CPU 线程可以防止为后续全聚合过度分配内存,并且实际上不应延迟 GPU 内核执行。 注意 当使用`sharding_strategy=ShardingStrategy.HYBRID_SHARD`,分片进程组为节点内,复制进程组为节点间时,设置`NCCL_CROSS_NIC=1`可以帮助改善某些集群设置下复制进程组的全聚合时间。 参数 + **module** (*nn.Module*) – 这是要用 FSDP 包装的模块。 + **process_group** (*可选**[**Union**[**ProcessGroup**,* *Tuple**[**ProcessGroup**,* *ProcessGroup**]**]**]*) – 这是模型被分片的进程组,因此也是 FSDP 的全聚合和减少散播集体通信所使用的进程组。如果为`None`,则 FSDP 使用默认进程组。对于混合分片策略,如`ShardingStrategy.HYBRID_SHARD`,用户可以传入一个进程组的元组,分别表示分片和复制的组。如果为`None`,则 FSDP 为用户构建进程组,以便在节点内进行分片和在节点间进行复制。(默认值:`None`) + **sharding_strategy** (*可选***[*ShardingStrategy**]*) – 这配置了分片策略,可能会权衡内存节省和通信开销。详细信息请参见`ShardingStrategy`。(默认值:`FULL_SHARD`) + **cpu_offload** (*可选***[*CPUOffload**]*) – 这配置了 CPU 卸载。如果设置为`None`,则不会发生 CPU 卸载。详细信息请参见`CPUOffload`。(默认值:`None`) + **auto_wrap_policy** (*可选****Union**[**Callable**[**[*[*nn.Module**,* [*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")*,* [*int*](https://docs.python.org/3/library/functions.html#int "(in Python v3.12)")*]**,* [*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")*]**,* *ModuleWrapPolicy**,* *CustomPolicy**]**]*) – 这指定了一个策略,将 FSDP 应用于`module`的子模块,这对通信和计算重叠至关重要,从而影响性能。如果为`None`,则 FSDP 仅应用于`module`,用户应手动将 FSDP 应用于父模块(自下而上进行)。为方便起见,这直接接受`ModuleWrapPolicy`,允许用户指定要包装的模块类(例如变换器块)。否则,这应该是一个可调用对象,接受三个参数`module: nn.Module`、`recurse: bool`和`nonwrapped_numel: int`,并应返回一个`bool`,指定是否应在`recurse=False`时应用 FSDP 到传入的`module`,或者如果`recurse=True`,遍历应继续到模块的子树。用户可以向可调用对象添加其他参数。`torch.distributed.fsdp.wrap.py`中的`size_based_auto_wrap_policy`提供了一个示例可调用对象,如果其子树中的参数超过 100M 个元素,则将 FSDP 应用于模块。我们建议在应用 FSDP 后打印模型,并根据需要进行调整。 示例: ```py >>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> nonwrapped_numel: int, >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return nonwrapped_numel >= min_num_params >>> # Configure a custom `min_num_params` >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5)) ``` + **backward_prefetch** (*可选***[*BackwardPrefetch**]*) – 这配置了所有 gather 的显式向后预取。如果为`None`,则 FSDP 不进行向后预取,在向后传递中没有通信和计算重叠。详细信息请参见`BackwardPrefetch`。(默认值:`BACKWARD_PRE`) + **mixed_precision** (*可选***[*MixedPrecision**]*) – 这配置了 FSDP 的本机混合精度。如果设置为`None`,则不使用混合精度。否则,可以设置参数、缓冲区和梯度减少的数据类型。详细信息请参见`MixedPrecision`。(默认值:`None`) + **ignored_modules** (*可选****可迭代**[*[*torch.nn.Module**]**]*) – 忽略此实例的参数和子模块的参数和缓冲区的模块。`ignored_modules`中直接的模块都不应该是`FullyShardedDataParallel`实例,如果已构建的子模块是`FullyShardedDataParallel`实例,并且它们嵌套在此实例下,则不会被忽略。当使用`auto_wrap_policy`时,或者如果参数的分片不是由 FSDP 管理时,可以使用此参数避免以模块粒度分片特定参数。(默认值:`None`) + **param_init_fn** (*可选****可调用**[**[*[*nn.Module**]**,* *None**]**]*) – 一个 `Callable[torch.nn.Module] -> None`,指定了当前在元设备上的模块应该如何初始化到实际设备上。从 v1.12 开始,FSDP 通过 `is_meta` 检测在元设备上具有参数或缓冲区的模块,如果指定了 `param_init_fn`,则应用它,否则调用 `nn.Module.reset_parameters()`。对于这两种情况,实现应该 *仅* 初始化模块的参数/缓冲区,而不是其子模块的参数/缓冲区。这是为了避免重新初始化。此外,FSDP 还支持通过 torchdistX 的 `deferred_init()` API 延迟初始化,其中延迟模块通过调用 `param_init_fn`(如果指定)或 torchdistX 的默认 `materialize_module()` 进行初始化。如果指定了 `param_init_fn`,则它将应用于所有元设备模块,这意味着它可能会根据模块类型进行分类。FSDP 在参数展平和分片之前调用初始化函数。 示例: ```py >>> module = MyModule(device="meta") >>> def my_init_fn(module: nn.Module): >>> # E.g. initialize depending on the module type >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy) ``` + **device_id**(*Optional**[**Union**[*[*int*](https://docs.python.org/3/library/functions.html#int "(in Python v3.12)")*,* *torch.device**]***) – 一个 `int` 或 `torch.device`,指定 FSDP 初始化所在的 CUDA 设备,包括如果需要的话模块初始化和参数分片。如果模块在 CPU 上,则应指定此项以提高初始化速度。如果设置了默认的 CUDA 设备(例如通过 `torch.cuda.set_device`),则用户可以将 `torch.cuda.current_device` 传递给此项。(默认值:`None`) + **sync_module_states**([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)"))- 如果为 `True`,则每个 FSDP 模块将从排名 0 广播模块参数和缓冲区,以确保它们在各个排名之间复制(为此构造函数增加通信开销)。这可以帮助以内存高效的方式通过 `load_state_dict` 加载 `state_dict` 检查点。请参阅 `FullStateDictConfig` 以获取此示例。(默认值:`False`) + **forward_prefetch**([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)"))- 如果为 `True`,则 FSDP *明确* 在当前前向计算之前预取下一个前向传递的所有聚集。这仅对 CPU 绑定的工作负载有用,在这种情况下,提前发出下一个所有聚集可能会提高重叠。这仅适用于静态图模型,因为预取遵循第一次迭代的执行顺序。(默认值:`False`) + **limit_all_gathers**([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)"))- 如果为 `True`,则 FSDP 明确同步 CPU 线程,以确保 GPU 内存使用仅来自 *两个* 连续的 FSDP 实例(当前实例运行计算和下一个实例,其所有聚集都是预取的)。如果为 `False`,则 FSDP 允许 CPU 线程发出所有聚集而无需任何额外的同步。(默认值:`True`)我们通常将此功能称为“速率限制器”。此标志应仅针对具有低内存压力的特定 CPU 绑定工作负载设置为 `False`,在这种情况下,CPU 线程可以积极发出所有内核,而不必担心 GPU 内存使用。 + **use_orig_params**([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)"))- 将此设置为`True`会使 FSDP 使用`module`的原始参数。FSDP 通过`nn.Module.named_parameters()`向用户公开这些原始参数,而不是通过 FSDP 的内部`FlatParameter`。这意味着优化器步骤在原始参数上运行,从而实现每个原始参数的超参数。FSDP 保留原始参数变量,并在未分片和分片形式之间操作它们的数据,其中它们始终是底层未分片或分片`FlatParameter`的视图。根据当前算法,分片形式始终是 1D,丢失了原始张量结构。对于给定等级,原始参数可能具有全部、部分或没有数据。在没有数据的情况下,其数据将类似于大小为 0 的空张量。用户不应编写依赖于给定原始参数在其分片形式中存在哪些数据的程序。`True`是使用`torch.compile()`所必需的。将其设置为`False`通过`nn.Module.named_parameters()`向用户公开 FSDP 的内部`FlatParameter`。 (默认值:`False`) + **ignored_states**(*可选**[**Iterable**[**torch.nn.Parameter**]**]**,* *可选****Iterable**[*[*torch.nn.Module**]**]*)- 不受此 FSDP 实例管理的被忽略的参数或模块,这意味着参数未分片,它们的梯度未在等级之间减少。此参数与现有的`ignored_modules`参数统一,我们可能很快会弃用`ignored_modules`。为了向后兼容,我们保留`ignored_states`和`ignored_modules`,但是 FSDP 只允许其中一个被指定为非`None`。 ```py apply(fn) ``` 递归地将`fn`应用于每个子模块(由`.children()`返回)以及自身。 典型用法包括初始化模型的参数(另请参阅 torch.nn.init)。 与`torch.nn.Module.apply`相比,此版本在应用`fn`之前还会收集完整的参数。不应在另一个`summon_full_params`上下文中调用它。 参数 **fn**(`Module` -> None)- 要应用于每个子模块的函数 返回 self 返回类型 Module ```py check_is_root() ``` 检查此实例是否为根 FSDP 模块。 返回类型 [bool](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)") ```py clip_grad_norm_(max_norm, norm_type=2.0) ``` 裁剪所有参数的梯度规范。 规范是计算所有参数的梯度作为单个向量的规范,并且梯度会就地修改。 参数 + **max_norm**([*float*](https://docs.python.org/3/library/functions.html#float "(in Python v3.12)") *或* [*int*](https://docs.python.org/3/library/functions.html#int "(in Python v3.12)"))- 梯度的最大规范 + **norm_type**([*float*](https://docs.python.org/3/library/functions.html#float "(in Python v3.12)") *或* [*int*](https://docs.python.org/3/library/functions.html#int "(in Python v3.12)"))- 使用的 p-范数类型。可以是`'inf'`表示无穷范数。 返回 参数的总规范(视为单个向量)。 返回类型 *Tensor* 注意 如果每个 FSDP 实例都使用`NO_SHARD`,即没有梯度在等级之间分片,则可以直接使用`torch.nn.utils.clip_grad_norm_()`。 注意 如果至少有一些 FSDP 实例使用分片策略(即`NO_SHARD`之外的策略),则应使用此方法而不是`torch.nn.utils.clip_grad_norm_()`,因为此方法处理了梯度在等级之间分片的事实。 注意 返回的总规范将具有 PyTorch 类型提升语义定义的所有参数/梯度中的“最大”dtype。例如,如果*所有*参数/梯度使用低精度 dtype,则返回的规范的 dtype 将是该低精度 dtype,但如果至少存在一个使用 FP32 的参数/梯度,则返回的规范的 dtype 将是 FP32。 警告 由于使用集体通信,因此需要在所有秩上调用此函数。 ```py static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim) ``` 展平分片的优化器状态字典。 API 类似于`shard_full_optim_state_dict()`。唯一的区别是输入的`sharded_optim_state_dict`应该从`sharded_optim_state_dict()`返回。因此,每个秩上都会有全聚合调用以收集`ShardedTensor` s。 参数 + **sharded_optim_state_dict**(*Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Any**]*) - 对应于未展平参数并保存分片优化器状态的优化器状态字典。 + **model**(*torch.nn.Module*) - 参考`shard_full_optim_state_dict()`。 + **optim**(*torch.optim.Optimizer*) - 用于`model`的参数的优化器。 返回 参考`shard_full_optim_state_dict()`。 返回类型 [*Dict*](https://docs.python.org/3/library/typing.html#typing.Dict "(in Python v3.12)")[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)")] ```py forward(*args, **kwargs) ``` 运行包装模块的前向传递,插入 FSDP 特定的前向和后向分片逻辑。 返回类型 [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)") ```py static fsdp_modules(module, root_only=False) ``` 返回所有嵌套的 FSDP 实例。 这可能包括`module`本身,仅在`root_only=True`时包括 FSDP 根模块。 参数 + **module**(*torch.nn.Module*) - 根模块,可能是或可能不是`FSDP`模块。 + **root_only**([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) - 是否仅返回 FSDP 根模块。 (默认值:`False`) 返回 嵌套在输入`module`中的 FSDP 模块。 返回类型 List[FullyShardedDataParallel] ```py static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None) ``` 返回完整的优化器状态字典。 在秩 0 上合并完整的优化器状态,并将其作为[`dict`](https://docs.python.org/3/library/stdtypes.html#dict "(in Python v3.12)")返回,遵循`torch.optim.Optimizer.state_dict()`的约定,即具有键`"state"`和`"param_groups"`。`model`中包含的`FSDP`模块中的展平参数将映射回其未展平参数。 警告 由于使用集体通信,因此需要在所有秩上调用此函数。但是,如果`rank0_only=True`,则状态字典仅在秩为 0 时填充,并且所有其他秩返回一个空的[`dict`](https://docs.python.org/3/library/stdtypes.html#dict "(in Python v3.12)")。 警告 与`torch.optim.Optimizer.state_dict()`不同,此方法使用完整的参数名称作为键,而不是参数 ID。 注意 与`torch.optim.Optimizer.state_dict()`中一样,优化器状态字典中包含的张量不会被克隆,因此可能会出现别名意外。为了最佳实践,考虑立即保存返回的优化器状态字典,例如使用`torch.save()`。 参数 + **模型**(*torch.nn.Module*实例),其参数被传递给优化器`optim`。 + **optim**(*torch.optim.Optimizer*)- 用于`model`的参数的优化器。 + **optim_input**(*Optional**[**Union**[**List**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Any**]**]**,** *Iterable**[**torch.nn.Parameter**]**]**]*) - 传递给优化器`optim`的输入,表示参数组的[`list`](https://docs.python.org/3/library/stdtypes.html#list "(in Python v3.12)")或参数的可迭代对象;如果为`None`,则此方法假定输入为`model.parameters()`。此参数已被弃用,不再需要传递它。 (默认值:`None`) + **rank0_only**([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)"))- 如果为`True`,则仅在 rank 0 上保存填充的[`dict`](https://docs.python.org/3/library/stdtypes.html#dict "(in Python v3.12)");如果为`False`,则在所有 rank 上保存。 (默认值:`True`) + **group**(*dist.ProcessGroup*)- 模型的进程组或如果使用默认进程组则为`None`。 (默认值:`None`) 返回 包含`model`原始未扁平化参数的优化器状态和包括“state”和“param_groups”键的[`dict`](https://docs.python.org/3/library/stdtypes.html#dict "(in Python v3.12)")。如果`rank0_only=True`,则非零 rank 返回一个空的[`dict`](https://docs.python.org/3/library/stdtypes.html#dict "(in Python v3.12)")。 返回类型 Dict[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), Any] ```py static get_state_dict_type(module) ``` 获取根据`module`根模块的 FSDP 模块的 state_dict_type 和相应的配置。 目标模块不必是 FSDP 模块。 返回 包含当前设置的 state_dict_type 和 state_dict / optim_state_dict 配置的`StateDictSettings`。 引发 + **如果 StateDictSettings 不同,则会引发 AssertionError** - + **FSDP 子模块不同。** - 返回类型 *StateDictSettings* ```py property module: Module ``` 返回包装的模块。 ```py named_buffers(*args, **kwargs) ``` 返回一个模块缓冲区的迭代器,同时产生缓冲区的名称和缓冲区本身。 在`summon_full_params()`上下文管理器中,拦截缓冲区名称并删除所有 FSDP 特定的扁平化缓冲区前缀的出现。 返回类型 [*Iterator*](https://docs.python.org/3/library/typing.html#typing.Iterator "(in Python v3.12)")[[*Tuple*](https://docs.python.org/3/library/typing.html#typing.Tuple "(in Python v3.12)")[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), *Tensor*]] ```py named_parameters(*args, **kwargs) ``` 返回一个模块参数的迭代器,同时产生参数的名称和参数本身。 拦截参数名称,并在`summon_full_params()`上下文管理器内部删除所有 FSDP 特定的扁平化参数前缀的出现。 返回类型 [*Iterator*](https://docs.python.org/3/library/typing.html#typing.Iterator "(在 Python v3.12 中)")[[*Tuple*](https://docs.python.org/3/library/typing.html#typing.Tuple "(在 Python v3.12 中)")[[str](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)"), *Parameter*]] ```py no_sync() ``` 禁用 FSDP 实例之间的梯度同步。 在此上下文中,梯度将在模块变量中累积,稍后在退出上下文后的第一个前向-后向传递中进行同步。这应仅用于根 FSDP 实例,并将递归应用于所有子 FSDP 实例。 注意 这可能导致更高的内存使用,因为 FSDP 将累积完整模型梯度(而不是梯度片段),直到最终同步。 注意 当与 CPU 卸载一起使用时,在上下文管理器内部梯度不会被卸载到 CPU。相反,它们只会在最终同步之后被卸载。 返回类型 [*Generator*](https://docs.python.org/3/library/typing.html#typing.Generator "(在 Python v3.12 中)") ```py static optim_state_dict(model, optim, optim_state_dict=None, group=None) ``` 转换与分片模型对应的优化器的状态字典。 给定的状态字典可以转换为三种类型之一:1)完整的优化器状态字典,2)分片的优化器状态字典,3)本地的优化器状态字典。 对于完整的优化器状态字典,所有状态都是未扁平化且未分片的。可以通过`state_dict_type()`指定仅 Rank0 和仅 CPU,以避免 OOM。 对于分片的优化器状态字典,所有状态都是未扁平化但是分片的。可以通过`state_dict_type()`指定仅 CPU,以进一步节省内存。 对于本地 state_dict,不会执行任何转换。但是,状态将从 nn.Tensor 转换为 ShardedTensor 以表示其分片性质(目前不支持)。 示例: ```py >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> optim_state_dict, model, optim >>> ) >>> optim.load_state_dict(optim_state_dict) ``` 参数 + **model** (*torch.nn.Module*) – 根模块(可能是或可能不是`FullyShardedDataParallel`实例)其参数被传递给优化器`optim`。 + **optim** (*torch.optim.Optimizer*) – 用于`model`参数的优化器。 + **optim_state_dict** (*Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")*,* *Any**]*) – 要转换的目标优化器状态字典。如果值为 None,则将使用 optim.state_dict()。(默认值:`None`) + **group** (*dist.ProcessGroup*) – 模型的进程组,参数被分片到该组中,如果使用默认进程组则为`None`。(默认值:`None`) 返回 包含`model`的优化器状态的[`dict`](https://docs.python.org/3/library/stdtypes.html#dict "(在 Python v3.12)")。优化器状态的分片基于`state_dict_type`。 返回类型 Dict[[str](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)"), Any] ```py static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None) ``` 将优化器状态字典转换为可以加载到与 FSDP 模型关联的优化器中的格式。 给定通过`optim_state_dict()`转换的`optim_state_dict`,它被转换为可以加载到`optim`的扁平化优化器 state_dict,该`optim`是`model`的优化器。`model`必须由 FullyShardedDataParallel 进行分片。 ```py >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> original_osd = optim.state_dict() >>> optim_state_dict = FSDP.optim_state_dict( >>> model, >>> optim, >>> optim_state_dict=original_osd >>> ) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict) ``` 参数 + **model** (*torch.nn.Module*) – 根模块(可能是或不是`FullyShardedDataParallel`实例),其参数传递给了优化器`optim`。 + **optim** (*torch.optim.Optimizer*) – `model`的参数的优化器。 + **optim_state_dict** (*Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Any**]*) – 要加载的优化器状态。 + **is_named_optimizer** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 这个优化器是 NamedOptimizer 还是 KeyedOptimizer。只有在`optim`是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 时才设置为 True。 + **load_directly** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果设置为 True,则此 API 在返回结果之前还将调用 optim.load_state_dict(result)。否则,用户需要调用`optim.load_state_dict()`(默认值:`False`) + **group** (*dist.ProcessGroup*) – 模型的进程组,参数在其中进行分片,如果使用默认进程组,则为`None`。(默认值:`None`) 返回类型 [*Dict*](https://docs.python.org/3/library/typing.html#typing.Dict "(in Python v3.12)")[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)")] ```py register_comm_hook(state, hook) ``` 注册通信钩子。 这是一个增强功能,为用户提供了一个灵活的钩子,他们可以在其中指定 FSDP 如何在多个工作进程之间聚合梯度。这个钩子可以用于实现几种算法,如[GossipGrad](https://arxiv.org/abs/1803.05880)和涉及不同通信策略的梯度压缩,这些策略用于参数同步,同时使用`FullyShardedDataParallel`进行训练。 警告 在进行初始前向传递之前,应注册 FSDP 通信钩子,并且只注册一次。 参数 + **state** ([*object*](https://docs.python.org/3/library/functions.html#object "(in Python v3.12)")) – 传递给钩子以在训练过程中维护任何状态信息。示例包括梯度压缩中的错误反馈,与[GossipGrad](https://arxiv.org/abs/1803.05880)中下一个通信的对等体等。它由每个工作进程本地存储,并由工作进程上的所有梯度张量共享。 + **hook** (*Callable*) – 可调用函数,具有以下签名之一:1) `hook: Callable[torch.Tensor] -> None`:此函数接受一个 Python 张量,表示与此 FSDP 单元包装的模型对应的所有变量的全面、扁平化、未分片梯度(未被其他 FSDP 子单元包装)。然后执行所有必要的处理并返回`None`;2) `hook: Callable[torch.Tensor, torch.Tensor] -> None`:此函数接受两个 Python 张量,第一个表示与此 FSDP 单元包装的模型对应的所有变量的全面、扁平化、未分片梯度(未被其他 FSDP 子单元包装)。后者表示一个预先大小的张量,用于存储分片梯度的一部分。在这两种情况下,可调用函数执行所有必要的处理并返回`None`。具有签名 1 的可调用函数应处理 NO_SHARD 情况的梯度通信。具有签名 2 的可调用函数应处理分片情况的梯度通信。 ```py static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None) ``` 重新调整优化器状态字典`optim_state_dict`以使用键类型`optim_state_key_type`。 这可以用于实现具有 FSDP 实例和没有 FSDP 实例的模型的优化器状态字典之间的兼容性。 重新调整 FSDP 全优化器状态字典(即从`full_optim_state_dict()`)以使用参数 ID,并且可以加载到非包装模型中: ```py >>> wrapped_model, wrapped_optim = ... >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) >>> nonwrapped_model, nonwrapped_optim = ... >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) >>> nonwrapped_optim.load_state_dict(rekeyed_osd) ``` 重新调整来自非包装模型的普通优化器状态字典,以便加载到包装模型中: ```py >>> nonwrapped_model, nonwrapped_optim = ... >>> osd = nonwrapped_optim.state_dict() >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) >>> wrapped_model, wrapped_optim = ... >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) >>> wrapped_optim.load_state_dict(sharded_osd) ``` 返回 使用由`optim_state_key_type`指定的参数键重新调整的优化器状态字典。 返回类型 字典[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), Any] ```py static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None) ``` 将来自排名 0 的完整优化器状态字典分散到所有其他排名。 在每个排名上返回分片的优化器状态字典。返回值与`shard_full_optim_state_dict()`相同,在排名 0 上,第一个参数应该是`full_optim_state_dict()`的返回值。 示例: ```py >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 >>> # Define new model with possibly different world size >>> new_model, new_optim, new_group = ... >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) >>> new_optim.load_state_dict(sharded_osd) ``` 注意 `shard_full_optim_state_dict()`和`scatter_full_optim_state_dict()`都可以用于获取分片优化器状态字典以进行加载。假设完整的优化器状态字典驻留在 CPU 内存中,前者要求每个排名在 CPU 内存中具有完整字典,其中每个排名单独对字典进行分片而无需任何通信,而后者只要求排名 0 在 CPU 内存中具有完整字典,其中排名 0 将每个分片移动到 GPU 内存(用于 NCCL)并适当地将其通信给排名。因此,前者具有更高的总体 CPU 内存成本,而后者具有更高的通信成本。 参数 + **full_optim_state_dict** (*Optional**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Any**]**]*) – 与未扁平化参数对应并保存完整非分片优化器状态的优化器状态字典(如果在排名 0 上);在非零排名上忽略该参数。 + **model**(*torch.nn.Module*)- 根模块(可能是或可能不是`FullyShardedDataParallel`实例),其参数对应于`full_optim_state_dict`中的优化器状态。 + **optim_input**(*可选**[**Union**[**List**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Any**]**]**,* *Iterable**[**torch.nn.Parameter**]**]**]*) - 传递给优化器的输入,表示参数组的[`list`](https://docs.python.org/3/library/stdtypes.html#list "(in Python v3.12)")或参数的可迭代对象;如果为`None`,则此方法假定输入为`model.parameters()`。此参数已被弃用,不再需要传递它。(默认值:`None`) + **optim**(*可选***[*torch.optim.Optimizer**]*)- 由此方法返回的状态字典将加载的优化器。这是优选的参数,优于`optim_input`。(默认值:`None`) + **group**(*dist.ProcessGroup*)- 模型的进程组或如果使用默认进程组则为`None`。(默认值:`None`) 返回值 现在完整的优化器状态字典已经重新映射到扁平化的参数,而不是未扁平化的参数,并且仅限于包括此排名部分的优化器状态。 返回类型 Dict[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), Any] ```py static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None) ``` 设置目标模块的所有后代 FSDP 模块的`state_dict_type`。 还接受(可选)模型和优化器状态字典的配置。目标模块不必是一个 FSDP 模块。如果目标模块是一个 FSDP 模块,它的`state_dict_type`也将被更改。 注意 此 API 应仅用于顶层(根)模块。 注意 此 API 使用户能够透明地使用传统的`state_dict` API,在根 FSDP 模块被另一个`nn.Module`包装的情况下进行模型检查点。例如,以下代码将确保在所有非 FSDP 实例上调用`state_dict`,同时将分派到 FSDP 的 sharded_state_dict 实现: 示例: ```py >>> model = DDP(FSDP(...)) >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>> ) >>> param_state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) ``` 参数 + **module**(*torch.nn.Module*)- 根模块。 + **state_dict_type**(*StateDictType*)- 要设置的期望的`state_dict_type`。 + **state_dict_config**(*可选***[*StateDictConfig**]*)- 目标`state_dict_type`的配置。 + **optim_state_dict_config**(*可选***[*OptimStateDictConfig**]*)- 优化器状态字典的配置。 返回值 包括模块的先前 state_dict 类型和配置的 StateDictSettings。 返回类型 *StateDictSettings* ```py static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None) ``` 分片完整的优化器状态字典。 将`full_optim_state_dict`中的状态重新映射为扁平化的参数,而不是未扁平化的参数,并且限制为仅此排名部分的优化器状态。第一个参数应该是`full_optim_state_dict()`的返回值。 示例: ```py >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) >>> torch.save(full_osd, PATH) >>> # Define new model with possibly different world size >>> new_model, new_optim = ... >>> full_osd = torch.load(PATH) >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) >>> new_optim.load_state_dict(sharded_osd) ``` 注意 `shard_full_optim_state_dict()`和`scatter_full_optim_state_dict()`都可用于获取分片优化器状态字典以加载。假设完整的优化器状态字典驻留在 CPU 内存中,前者要求每个排名在 CPU 内存中具有完整的字典,其中每个排名单独对字典进行分片而无需任何通信,而后者仅要求排名 0 在 CPU 内存中具有完整的字典,其中排名 0 将每个分片移动到 GPU 内存(用于 NCCL)并适当地将其通信给排名。因此,前者具有更高的总体 CPU 内存成本,而后者具有更高的通信成本。 参数 + **full_optim_state_dict**(*字典**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Any**]*) - 与未扁平化参数对应的优化器状态字典,保存完整的非分片优化器状态。 + **模型**(*torch.nn.Module*)- 根模块(可能是或不是`FullyShardedDataParallel`实例),其参数对应于`full_optim_state_dict`中的优化器状态。 + **optim_input**(*可选**[**Union**[**List**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Any**]**]**,* *Iterable**[**torch.nn.Parameter**]**]**]*) - 传递给优化器的输入,表示参数组的[`list`](https://docs.python.org/3/library/stdtypes.html#list "(in Python v3.12)")或参数的可迭代对象;如果为`None`,则此方法假定输入为`model.parameters()`。此参数已被弃用,不再需要传递它。(默认值:`None`) + **optim**(*可选***[*torch.optim.Optimizer**]*) - 将由此方法返回的状态字典加载的优化器。这是优选的参数,用于覆盖`optim_input`。(默认值:`None`) 返回 现在完整的优化器状态字典已重新映射为扁平化参数,而不是未扁平化参数,并且仅限于包括此排名部分的优化器状态。 返回类型 字典[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), Any] ```py static sharded_optim_state_dict(model, optim, group=None) ``` 以其分片形式返回优化器状态字典。 API 类似于`full_optim_state_dict()`,但此 API 将所有非零维状态分块为`ShardedTensor`以节省内存。当使用上下文管理器`with state_dict_type(SHARDED_STATE_DICT):`派生模型`state_dict`时,应仅使用此 API。 有关详细用法,请参考`full_optim_state_dict()`。 警告 返回的状态字典包含`ShardedTensor`,不能直接被常规的`optim.load_state_dict`使用。 返回类型 [*字典*](https://docs.python.org/3/library/typing.html#typing.Dict "(in Python v3.12)")[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)")] ```py static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None) ``` 设置目标模块的所有后代 FSDP 模块的`state_dict_type`。 此上下文管理器具有与`set_state_dict_type()`相同的功能。阅读`set_state_dict_type()`的文档以获取详细信息。 示例: ```py >>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> ): >>> checkpoint = model.state_dict() ``` 参数 + **module**(*torch.nn.Module**]*) - 目标`state_dict_type`的模型`state_dict`配置。 + **optim_state_dict_config**(*可选***[*OptimStateDictConfig**]*) - 目标`state_dict_type`的优化器`state_dict`配置。 返回类型 [*生成器*](https://docs.python.org/3/library/typing.html#typing.Generator "(在 Python v3.12 中)") ```py static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False) ``` 使用此上下文管理器为 FSDP 实例公开完整参数。 在模型进行前向/反向传播之后,可以用于获取参数以进行额外处理或检查。它可以接受一个非 FSDP 模块,并将召唤所有包含的 FSDP 模块以及它们的子模块的完整参数,取决于`recurse`参数。 注意 这可以在内部 FSDP 上使用。 注意 这不能在前向或反向传递中使用。也不能在此上下文中启动前向和反向传递。 注意 参数在上下文管理器退出后将恢复为其本地分片,存储行为与前向传播相同。 注意 完整参数可以被修改,但只有对应于本地参数分片的部分将在上下文管理器退出后保留(除非`writeback=False`,在这种情况下更改将被丢弃)。在 FSDP 不对参数进行分片的情况下,目前仅当`world_size == 1`或`NO_SHARD`配置时,修改将被持久化,无论`writeback`如何。 注意 此方法适用于本身不是 FSDP 但可能包含多个独立 FSDP 单元的模块。在这种情况下,给定的参数将适用于所有包含的 FSDP 单元。 警告 请注意,`rank0_only=True`与`writeback=True`结合使用目前不受支持,将引发错误。这是因为模型参数形状在上下文中的不同排名之间会有所不同,对其进行写入可能会导致在退出上下文时跨排名之间的不一致性。 警告 请注意,`offload_to_cpu`和`rank0_only=False`会导致完整参数被冗余地复制到 CPU 内存中,对于与 GPU 位于同一台机器上的情况,这可能会带来 CPU 内存溢出的风险。建议使用`offload_to_cpu`与`rank0_only=True`。 参数 + **recurse**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*,* *可选*) - 递归召唤所有嵌套 FSDP 实例的参数(默认值:True)。 + **writeback**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*,* *可选*) - 如果为`False`,则在上下文管理器退出后对参数的修改将被丢弃;禁用此选项可能会稍微更有效(默认值:True) + **rank0_only**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*,* *可选*) - 如果为`True`,则只有全局排名 0 上的完整参数会被实现。这意味着在上下文中,只有排名 0 会有完整参数,其他排名将具有分片参数。请注意,在`rank0_only=True`与`writeback=True`一起使用时不受支持,因为模型参数形状在上下文中的不同排名之间会有所不同,对其进行写入可能会导致在退出上下文时跨排名之间的不一致性。 + **offload_to_cpu**([*布尔*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 版本)")*,* *可选*) – 如果为`True`,完整的参数会被卸载到 CPU。请注意,只有在参数被分片时(对于 world_size = 1 或`NO_SHARD`配置除外),才会发生此卸载。建议使用`offload_to_cpu`与`rank0_only=True`一起使用,以避免将模型参数的冗余副本卸载到相同的 CPU 内存中。 + **with_grads**([*布尔*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 版本)")*,* *可选*) – 如果为`True`,梯度也会与参数一起取消分片。目前,只有在将`use_orig_params=True`传递给 FSDP 构造函数并将`offload_to_cpu=False`传递给此方法时才支持此功能。(默认值:`False`) 返回类型 [*生成器*](https://docs.python.org/3/library/typing.html#typing.Generator "(在 Python v3.12 版本)") ```py class torch.distributed.fsdp.BackwardPrefetch(value) ``` 这配置了显式的向后预取,通过在向后传递中启用通信和计算重叠来提高吞吐量,但会略微增加内存使用量。 + `BACKWARD_PRE`:这会增加最多的重叠,但也会增加最多的内存使用量。这会在当前一组参数的梯度计算*之前*预取下一组参数。这会重叠*下一个全局聚集*和*当前梯度计算*,在峰值时,它会在内存中保存当前一组参数、下一组参数和当前一组梯度。 + `BACKWARD_POST`:这会减少重叠,但需要更少的内存使用量。这会在当前一组参数的梯度计算*之后*预取下一组参数。这会重叠*当前的 reduce-scatter*和*下一个梯度计算*,并在为下一组参数分配内存之前释放当前一组参数,仅在内存中保留下一组参数和当前一组梯度。 + FSDP 的`backward_prefetch`参数接受`None`,这会完全禁用向后预取。这不会重叠,也不会增加内存使用量。总的来说,我们不建议使用这个设置,因为它可能会显著降低吞吐量。 更多技术背景:对于使用 NCCL 后端的单个进程组,任何集合,即使从不同流发出,也会争夺相同的每个设备 NCCL 流,这意味着发出集合的相对顺序对于重叠很重要。两个向后预取值对应不同的发出顺序。 ```py class torch.distributed.fsdp.ShardingStrategy(value) ``` 这指定了由`FullyShardedDataParallel`用于分布式训练的分片策略。 + `FULL_SHARD`:参数、梯度和优化器状态被分片。对于参数,此策略在前向传递之前取消分片(通过全局聚集),在前向传递后重新分片,在向后计算之前取消分片,并在向后计算后重新分片。对于梯度,它在向后计算后同步和分片它们(通过 reduce-scatter)。分片的优化器状态在每个秩上本地更新。 + `SHARD_GRAD_OP`:梯度和优化器状态在计算过程中被分片,此外,参数在计算之外被分片。对于参数,此策略在前向传递之前取消分片,在前向传递后不再分片,仅在向后计算后重新分片。分片的优化器状态在每个秩上本地更新。在`no_sync()`中,参数在向后计算后不再分片。 + `NO_SHARD`:参数、梯度和优化器状态不分片,而是在各个秩之间复制,类似于 PyTorch 的`DistributedDataParallel`API。对于梯度,此策略在向后计算后同步它们(通过全局归约)。未分片的优化器状态在每个秩上本地更新。 + `HYBRID_SHARD`:在节点内应用`FULL_SHARD`,并在节点之间复制参数。这会减少通信量,因为昂贵的所有收集和减少散射仅在节点内完成,对于中等大小的模型可能更高效。 + `_HYBRID_SHARD_ZERO2`:在节点内应用`SHARD_GRAD_OP`,并在节点之间复制参数。这类似于`HYBRID_SHARD`,不同之处在于在前向传递后不释放未分片参数,从而节省了前向传递中的所有收集操作。 ```py class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(, )) ``` 这配置了 FSDP 本机混合精度训练。 变量 + **param_dtype** (*Optional***[*torch.dtype**]*) – 这指定了模型参数在前向和反向期间的数据类型,因此也是前向和反向计算的数据类型。在前向和反向之外,*分片*参数保持全精度(例如,用于优化器步骤),并且对于模型检查点,参数始终以全精度保存。(默认值:`None`) + **reduce_dtype** (*Optional***[*torch.dtype**]*) – 这指定了梯度减少的数据类型(即 reduce-scatter 或 all-reduce)。如果这是`None`,但`param_dtype`不是`None`,那么它将采用`param_dtype`的值,仍然以低精度运行梯度减少。这允许与`param_dtype`不同,例如强制梯度减少以全精度运行。(默认值:`None`) + **buffer_dtype** (*Optional***[*torch.dtype**]*) – 这指定了缓冲区的数据类型。FSDP 不对缓冲区进行分片。相反,FSDP 在第一次前向传递中将它们转换为`buffer_dtype`,并在此后保持该数据类型。对于模型检查点,缓冲区以全精度保存,除了`LOCAL_STATE_DICT`。(默认值:`None`) + **keep_low_precision_grads** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`False`,则 FSDP 在反向传递后将梯度向上转换为全精度,以准备进行优化器步骤。如果为`True`,则 FSDP 保持梯度在用于梯度减少的数据类型中,这可以节省内存,如果使用支持低精度运行的自定义优化器。(默认值:`False`) + **cast_forward_inputs** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`,则此 FSDP 模块将其前向参数和 kwargs 转换为`param_dtype`。这是为了确保前向计算所需的参数和输入数据类型匹配,许多操作都需要。当仅对一些而不是所有 FSDP 模块应用混合精度时,可能需要将其设置为`True`,在这种情况下,混合精度 FSDP 子模块需要重新转换其输入。(默认值:`False`) + **cast_root_forward_inputs** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`,则根 FSDP 模块将其前向参数和 kwargs 转换为`param_dtype`,覆盖`cast_forward_inputs`的值。对于非根 FSDP 模块,这不会做任何事情。(默认值:`True`) + **_module_classes_to_ignore** (*Sequence****Type**[*[*torch.nn.modules.module.Module**]**]*) – (Sequence[Type[nn.Module]]):这指定了在使用`auto_wrap_policy`时要忽略的模块类别:这些类别的模块将单独应用 FSDP,关闭混合精度(意味着最终的 FSDP 构造将偏离指定的策略)。如果未指定`auto_wrap_policy`,则这不会做任何事情。此 API 是实验性的,可能会更改。(默认值:`(_BatchNorm,)`) 注意 此 API 是实验性的,可能会更改。 注意 只有浮点张量会被转换为指定的数据类型。 注意 在 `summon_full_params` 中,参数被强制转换为全精度,但缓冲区不会。 注意 层归一化和批量归一化即使在其输入为低精度(如 `float16` 或 `bfloat16`)时也会累积为 `float32`。对于这些归一化模块禁用 FSDP 的混合精度只意味着仿射参数保留为 `float32`。然而,这会为这些归一化模块产生单独的全局收集和减少散播,这可能是低效的,因此如果工作负载允许,用户应该仍然将混合精度应用于这些模块。 注意 默认情况下,如果用户传递了一个带有任何 `_BatchNorm` 模块并指定了 `auto_wrap_policy` 的模型,那么批量归一化模块将单独应用 FSDP,但混合精度被禁用。请参阅 `_module_classes_to_ignore` 参数。 注意 `MixedPrecision` 默认情况下具有 `cast_root_forward_inputs=True` 和 `cast_forward_inputs=False`。对于根 FSDP 实例,其 `cast_root_forward_inputs` 优先于其 `cast_forward_inputs`。对于非根 FSDP 实例,它们的 `cast_root_forward_inputs` 值将被忽略。默认设置对于典型情况足够,其中每个 FSDP 实例具有相同的 `MixedPrecision` 配置,并且只需要在模型前向传递的开始时将输入转换为 `param_dtype`。 注意 对于具有不同 `MixedPrecision` 配置的嵌套 FSDP 实例,我们建议设置单独的 `cast_forward_inputs` 值来配置每个实例的前向传递之前是否转换输入。在这种情况下,由于转换发生在每个 FSDP 实例的前向传递之前,父 FSDP 实例应该在其 FSDP 子模块之前运行其非 FSDP 子模块,以避免由于不同的 `MixedPrecision` 配置而导致激活数据类型发生变化。 示例: ```py >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) >>> model[1] = FSDP( >>> model[1], >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>> ) >>> model = FSDP( >>> model, >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>> ) ``` 上面展示了一个工作示例。另一方面,如果 `model[1]` 被替换为 `model[0]`,意味着使用不同 `MixedPrecision` 的子模块先运行其前向传播,那么 `model[1]` 将错误地看到 `float16` 激活而不是 `bfloat16` 的激活。 ```py class torch.distributed.fsdp.CPUOffload(offload_params=False) ``` 这配置了 CPU 卸载。 变量 **offload_params** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 这指定是否在计算中未涉及时将参数卸载到 CPU。如果为 `True`,则梯度也会被卸载到 CPU,这意味着优化器步骤在 CPU 上运行。 ```py class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False) ``` `StateDictConfig` 是所有 `state_dict` 配置类的基类。用户应该实例化一个子类(例如 `FullStateDictConfig`)以配置由 FSDP 支持的相应 `state_dict` 类型的设置。 变量 **offload_to_cpu** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为 `True`,则 FSDP 将状态字典值卸载到 CPU,如果为 `False`,则 FSDP 保留在 GPU 上。(默认值:`False`) ```py class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False) ``` `FullStateDictConfig` 是一个配置类,用于与 `StateDictType.FULL_STATE_DICT` 一起使用。我们建议在保存完整状态字典时分别启用 `offload_to_cpu=True` 和 `rank0_only=True`,以节省 GPU 内存和 CPU 内存。这个配置类应该通过 `state_dict_type()` 上下文管理器来使用,如下所示: ```py >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> fsdp = FSDP(model, auto_wrap_policy=...) >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> state = fsdp.state_dict() >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP >>> if dist.get_rank() == 0: >>> # Load checkpoint only on rank 0 to avoid memory redundancy >>> state_dict = torch.load("my_checkpoint.pt") >>> model.load_state_dict(state_dict) >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument >>> # communicates loaded checkpoint states from rank 0 to rest of the world. >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) >>> # After this point, all ranks have FSDP model with loaded checkpoint. ``` 变量 **rank0_only** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为 `True`,则只有排名为 0 的进程保存完整状态字典,非零排名的进程保存一个空字典。如果为 `False`,则所有排名都保存完整状态字典。(默认值:`False`) ```py class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False) ``` `ShardedStateDictConfig` 是一个配置类,用于与 `StateDictType.SHARDED_STATE_DICT` 一起使用。 变量 **_use_dtensor** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为 `True`,则 FSDP 将状态字典值保存为 `DTensor`,如果为 `False`,则 FSDP 将它们保存为 `ShardedTensor`。(默认值:`False`) 警告 `_use_dtensor`是`ShardedStateDictConfig`的私有字段,由 FSDP 用于确定状态字典值的类型。用户不应手动修改`_use_dtensor`。 ```py class torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False) ``` ```py class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True) ``` `OptimStateDictConfig`是所有`optim_state_dict`配置类的基类。用户应该实例化一个子类(例如`FullOptimStateDictConfig`)以配置由 FSDP 支持的相应`optim_state_dict`类型的设置。 变量 **offload_to_cpu**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")) - 如果为`True`,则 FSDP 将状态字典的张量值转移到 CPU,如果为`False`,则 FSDP 将其保留在原始设备上(即 GPU,除非启用了 CPU 卸载参数)。 (默认值:`True`) ```py class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False) ``` 变量 **rank0_only**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")) - 如果为`True`,则只有排名为 0 的进程保存完整的状态字典,非零排名的进程保存空字典。如果为`False`,则所有进程都保存完整的状态字典。(默认值:`False`) ```py class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False) ``` `ShardedOptimStateDictConfig`是一个配置类,用于与`StateDictType.SHARDED_STATE_DICT`一起使用。 变量 **_use_dtensor**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")) - 如果为`True`,则 FSDP 将状态字典的值保存为`DTensor`,如果为`False`,则 FSDP 将其保存为`ShardedTensor`。(默认值:`False`) 警告 `_use_dtensor`是`ShardedOptimStateDictConfig`的私有字段,由 FSDP 用于确定状态字典值的类型。用户不应手动修改`_use_dtensor`。 ```py class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False) ``` ```py class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig) ```