Unverified commit 9f4a8763 authored by Yizhou Wang, committed by GitHub

Fix broadcast error on multi-node training with ZeroStage3 and TensorParallel=2 (#2999)

* * try to fix broadcast error on multi-node training with ZeroStage3 and TensorParallel=2

* * fix format error

* * fix format issue

* * add TODO for integrated testing of TP and ZeRO 1/2/3

* fix default pg error

---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 5c6da1f0
@@ -813,7 +813,11 @@ class Init(InsertPostInitMethodToModuleSubClasses):
                 f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}")
             if get_accelerator().on_accelerator(param):
-                dist.broadcast(param, 0, self.get_dp_process_group())
+                if dist.get_world_group() == self.get_dp_process_group():
+                    dist.broadcast(param, 0, self.get_dp_process_group())
+                else:
+                    dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0),
+                                   self.get_dp_process_group())
             else:
                 if dist.get_rank() == 0:
                     logger.warn(f"param `{name}` in {module.__class__.__name__} "
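Why the original one-liner was wrong: in torch.distributed (which DeepSpeed's `deepspeed.comm` wraps), the `src` argument of `broadcast()` is a global rank even when a subgroup is passed via `group`. With TensorParallel=2, the data-parallel group is a subgroup of the world group, so `src=0` may name a process that is not even a member. A minimal sketch of the failure mode and fix in plain torch.distributed (not DeepSpeed code), assuming a hypothetical 4-rank job whose dp groups are [0, 2] and [1, 3]:

```python
# Minimal sketch of the fix above, using plain torch.distributed.
# Assumes torch >= 2.0, where dist.get_global_rank() is public API.
import torch
import torch.distributed as dist

def broadcast_from_group_rank0(param: torch.Tensor, dp_group) -> None:
    # broadcast() treats `src` as a GLOBAL rank even when `group` is given.
    # With TP=2 the dp group may be [1, 3]; global rank 0 is not a member,
    # so src=0 broadcasts from the wrong process or hangs the collective.
    if dp_group == dist.group.WORLD:
        dist.broadcast(param, src=0, group=dp_group)
    else:
        # Translate "rank 0 within dp_group" to its global rank (1 for [1, 3]).
        src = dist.get_global_rank(dp_group, 0)
        dist.broadcast(param, src=src, group=dp_group)
```

This mirrors the patched code path: the world-group case keeps the old behavior, and the subgroup case maps group rank 0 back to its global rank before broadcasting.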
@@ -21,6 +21,7 @@ pytestmark = pytest.mark.skipif(not required_maximum_torch_version(major_version
                                 reason='Megatron-LM package requires Pytorch version 1.13 or below')
 
+# TODO: integrated testing of TP and ZeRO 1/2/3
 def get_deepspeed_model(model):
     ds_config_dict = {
         "train_micro_batch_size_per_gpu": 1,
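For context on the new TODO: an integrated TP + ZeRO test would run the model above under a ZeRO-enabled config along these lines. This is a hypothetical sketch; the test's actual `ds_config_dict` is truncated above, and only the key shown there is confirmed:

```python
# Hypothetical ZeRO-3 variant of the test config. Only
# "train_micro_batch_size_per_gpu" appears in the diff above;
# "zero_optimization" is added here for illustration.
ds_config_dict = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,  # the stage under which the broadcast bug surfaced
    },
}
```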