Unverified  Commit 2c628439 authored by Aswin John Mathews, committed by GitHub

Added 4-byte alignment on NCCL/RCCL (#1328)

* Added 4-byte alignment on NCCL/RCCL

* pre-commit formatting fixes

* Fix for checkpoint loading with optimizer partitioning

* Better assert print

* Added unit tests for nccl/rccl 4-byte alignment

* bug

* Updated alignment to implicit
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 3e7d06ad
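For context, the factor of 2 used throughout this patch comes from expressing a 4-byte address alignment in fp16 elements. A minimal sketch of that arithmetic, assuming only torch (the names NCCL_BYTE_ALIGNMENT and element_alignment_factor are illustrative, not part of DeepSpeed):

import torch

NCCL_BYTE_ALIGNMENT = 4  # required start-address alignment in bytes (assumption for illustration)

def element_alignment_factor(dtype):
    # elements per 4-byte boundary: 4 / sizeof(dtype)
    return NCCL_BYTE_ALIGNMENT // torch.tensor([], dtype=dtype).element_size()

assert element_alignment_factor(torch.float16) == 2  # matches nccl_start_alignment_factor below
assert element_alignment_factor(torch.float32) == 1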
......@@ -230,6 +230,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
        # number of elements per partition in each group
        self.partition_size = []

        # align nccl all-gather send buffers to 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment / sizeof(fp16) = 2

        assert (allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor}"

        self.all_reduce_print = False
        self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
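Why the new assert insists on a multiple of the factor: when a flat buffer is carved into bucket_size-element slices for all-gather, slice k begins at element offset k * bucket_size, so an odd fp16 bucket size puts odd-numbered slices on a 2-byte boundary. A hedged sketch of the invariant the assert protects (the helper name is made up for illustration):

def shard_offsets_aligned(bucket_size, world_size, alignment_factor=2):
    # slice k starts at element k * bucket_size; every start stays on a
    # 4-byte (2-element fp16) boundary iff bucket_size is a multiple of
    # alignment_factor
    return all((rank * bucket_size) % alignment_factor == 0
               for rank in range(world_size))

assert shard_offsets_aligned(1000, 4)        # even bucket size: all offsets aligned
assert not shard_offsets_aligned(1001, 4)    # odd bucket size: rank 1 starts misaligned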
......@@ -283,6 +288,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
            self.fp16_groups_flat.append(
                self.flatten_dense_tensors_aligned(
                    self.round_robin_fp16_groups[i],
                    self.nccl_start_alignment_factor *
                    dist.get_world_size(group=self.real_dp_process_group[i])).cuda(
                        torch.cuda.current_device()))

            see_memory_usage(f"After flattening and moving param group {i} to GPU",
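The alignment argument here is nccl_start_alignment_factor * world_size: the flattened group is padded until it divides evenly into world_size partitions whose start offsets are all even element indices. A simplified sketch of what flatten_dense_tensors_aligned does under that assumption (flatten_aligned is an illustrative stand-in, not DeepSpeed's implementation):

import torch
from torch._utils import _flatten_dense_tensors

def flatten_aligned(tensors, alignment):
    # pad the flattened group so its total element count is a multiple of
    # `alignment` (= nccl_start_alignment_factor * world_size in the patch)
    total = sum(t.numel() for t in tensors)
    remainder = total % alignment
    if remainder:
        tensors = list(tensors) + [
            torch.zeros(alignment - remainder, dtype=tensors[0].dtype)
        ]
    return _flatten_dense_tensors(tensors)

flat = flatten_aligned([torch.ones(5, dtype=torch.float16)], alignment=2 * 4)
assert flat.numel() == 8  # padded from 5 up to the next multiple of 8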
......@@ -303,6 +309,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
                    i)
            self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)

            # verify that data partition start locations are 4-byte aligned
            for partitioned_data in data_parallel_partitions:
                assert (partitioned_data.data_ptr() %
                        (2 * self.nccl_start_alignment_factor) == 0)

            # a partition of the fp32 master weights that will be updated by this process
            if not fp16_master_weights_and_gradients:
                self.single_partition_of_fp32_groups.append(
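Note the factor of 2 * nccl_start_alignment_factor in the assert above: the alignment factor counts fp16 elements while data_ptr() returns a byte address, so the byte requirement is sizeof(fp16) * factor = 4. A tiny self-contained check of the same property (tensor sizes are illustrative):

import torch

ALIGN_FACTOR = 2  # elements; 2 elements * 2 bytes per fp16 element = 4 bytes
flat = torch.zeros(8, dtype=torch.float16)
partition = flat.narrow(0, 4, 4)  # second of two partitions, element offset 4
assert partition.data_ptr() % (2 * ALIGN_FACTOR) == 0  # 4-byte aligned start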
......@@ -1993,6 +2004,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
            ]
            flat_merged_partitions = self.flatten_dense_tensors_aligned(
                merged_partitions,
                self.nccl_start_alignment_factor *
                dist.get_world_size(group=self.real_dp_process_group[i]))
            dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i)
            merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
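This is the checkpoint-loading fix from the commit message: the merged fp32 partitions must be re-flattened with the same alignment used at init, otherwise the equal-chunk split would place partition boundaries at different offsets than the live optimizer's. A sketch of that equal split under the padding assumption (partition_flat is an illustrative name, not DeepSpeed's get_data_parallel_partitions):

import torch

def partition_flat(flat, world_size):
    # equal chunks are safe because flat.numel() was padded at flatten time
    # to a multiple of nccl_start_alignment_factor * world_size
    partition_size = flat.numel() // world_size
    return [flat.narrow(0, rank * partition_size, partition_size)
            for rank in range(world_size)]

parts = partition_flat(torch.zeros(24, dtype=torch.float16), world_size=3)
assert all(p.data_ptr() % 4 == 0 for p in parts)  # each start is 4-byte aligned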
......
......@@ -231,3 +231,91 @@ def test_zero_to_fp32(tmpdir, zero_stage):
                                   fp32_state_dict[name].float())

    _test_zero_to_fp32()


@pytest.mark.parametrize('zero_stage, allgather_bucket_size', [(2, 1000), (2, 1001)])
def test_incorrect_allgather_bucket_size(tmpdir, zero_stage, allgather_bucket_size):
    config_dict = {
        "train_micro_batch_size_per_gpu": 2,
        "gradient_accumulation_steps": 2,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": zero_stage,
            "allgather_bucket_size": allgather_bucket_size
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
        "fp16": {
            "enabled": True,
            "initial_scale_power": 8
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 4
    model = SimpleModel(hidden_dim=hidden_dim)

    @distributed_test(world_size=[1])
    def _test_incorrect_allgather_bucket_size(args, model, hidden_dim):
        if allgather_bucket_size % 2 == 0:
            model, _, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        else:
            with pytest.raises(AssertionError) as assertinfo:
                model, _, _, _ = deepspeed.initialize(args=args,
                                                      model=model,
                                                      model_parameters=model.parameters())
            assert "allgather_bucket_size must be a multiple of nccl_start_alignment_factor" in str(
                assertinfo)

    _test_incorrect_allgather_bucket_size(args=args, model=model, hidden_dim=hidden_dim)
@pytest.mark.parametrize('zero_stage, world_size', [(2, 2), (2, 3), (2, 4)])
def test_partition_nccl_alignment(tmpdir, zero_stage, world_size):
    config_dict = {
        "train_micro_batch_size_per_gpu": 2,
        "gradient_accumulation_steps": 2,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": zero_stage
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
        "fp16": {
            "enabled": True,
            "initial_scale_power": 8
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 4
    model = SimpleModel(hidden_dim=hidden_dim)

    @distributed_test(world_size=world_size)
    def _test_partition_nccl_alignment(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())

        # get nccl all-gather send buffers alignment factor
        nccl_start_alignment_factor = model.optimizer.nccl_start_alignment_factor

        for data_parallel_partitions in model.optimizer.parallel_partitioned_fp16_groups:
            for partition_id, partitioned_data in enumerate(data_parallel_partitions):
                # verify that data partition start locations are 4-byte aligned
                assert (partitioned_data.data_ptr() %
                        (2 * nccl_start_alignment_factor) == 0)

    _test_partition_nccl_alignment(args=args, model=model, hidden_dim=hidden_dim)
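The world_size=3 parametrization is the interesting one: a flattened group rarely divides evenly by 3 * nccl_start_alignment_factor, so zeros are appended before partitioning. A back-of-the-envelope example of the padded sizes (the parameter count below is made up for illustration):

numel = 20                            # flattened fp16 parameter count (assumption)
world_size, align_factor = 3, 2
pad_unit = world_size * align_factor  # pad to a multiple of 6 elements
padded = ((numel + pad_unit - 1) // pad_unit) * pad_unit
assert padded == 24
assert padded // world_size == 8      # 8 elements per rank; even start offsets 0, 8, 16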