Using group_sharded_parallel, group sharded configuration can be performed on the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os', corresponding to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
...
@@ -63,6 +64,7 @@ def group_sharded_parallel(
segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the minimum size of a segmented parameter is 2**20.
sync_comm (bool, optional): Whether to use synchronous communication; only used in `p_g_os`. Defaults to False, indicating that asynchronous communication is used.
dp_group (Group, optional): The dp communication group, which supports combining stage2 or stage3 with dp hybrid communication.
exclude_layer (list, optional): Layers excluded from slicing in sharding stage3, for example, exclude_layer=["GroupNorm", id(model.gpt.linear)]. Each entry must be either a layer's class name or a specific layer's id.
Returns:
model: A group sharded wrapper for the given model.
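As a usage sketch (not part of this diff; the model, optimizer and data below are illustrative, and passing exclude_layer assumes the new argument introduced by this change), the wrapper is applied roughly as follows under a multi-card distributed launch:

```python
# Illustrative sketch only: wrap a model with parameter + gradient + optimizer
# state sharding ('p_g_os'). Run under a multi-GPU distributed launch.
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

fleet.init(is_collective=True)

model = paddle.nn.Linear(1000, 1000)
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# level selects the sharding scenario described above ('os', 'os_g', 'p_g_os');
# exclude_layer is the new argument added by this change (assumed here).
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    level="p_g_os",
    scaler=scaler,
    segment_size=2**20,
    sync_comm=False,
    exclude_layer=["GroupNorm"],
)

# Ordinary AMP-style training step with the returned wrappers.
img = paddle.randn([16, 1000])
label = paddle.randint(0, 1000, shape=[16, 1])
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
```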
...
@@ -159,6 +161,7 @@ def group_sharded_parallel(
sync_comm=sync_comm,
dp_group=dp_group,
device=device,
exclude_layer=exclude_layer,
)
else:
raise ValueError("Please enter the correct level.")
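For context, the exclude_layer list documented above mixes layer class names and layer id() values, and it is forwarded only in the 'p_g_os' branch, since that is the only level that slices parameters. A minimal, self-contained sketch of how such a list could be matched against a model's sublayers (assumed semantics inferred from the docstring, not the actual stage3 implementation):

```python
# Assumed semantics of exclude_layer (e.g. ["GroupNorm", id(model.gpt.linear)]):
# an entry matches a sublayer either by its class name or by its id().
# This is an inference from the docstring above, not the stage3 source.
import paddle


def _is_excluded(layer, exclude_layer):
    exclude_layer = exclude_layer or []
    return (
        type(layer).__name__ in exclude_layer  # match by class name, e.g. "GroupNorm"
        or id(layer) in exclude_layer          # match a specific instance by id()
    )


linear = paddle.nn.Linear(8, 8)
norm = paddle.nn.GroupNorm(num_groups=2, num_channels=8)
model = paddle.nn.Sequential(linear, norm)

exclude = ["GroupNorm", id(linear)]
for sub in model.sublayers():
    print(type(sub).__name__, "excluded:", _is_excluded(sub, exclude))
# Here both sublayers match: Linear by id(), GroupNorm by class name.
```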