Unverified · Commit 6319dd83 authored by Haohongxiang, committed by GitHub

fix bugs (#43115)

Parent a4bb38cb
@@ -140,17 +140,12 @@ def broadcast_dp_parameters(model, hcg):
 def fused_allreduce_gradients(parameter_list, hcg):
-    if _in_legacy_dygraph():
-        data_parallel_group = None if hcg is None else hcg.get_data_parallel_group(
-        )
-        logger.debug("dp start fuse allreduce gradients")
-        with framework.no_grad():
-            _apply_collective_grads(parameter_list, data_parallel_group)
-    elif in_dygraph_mode():
-        assert hcg is None, "It's not support to use hcg in EagerDygraph now."
-        data_parallel_group = paddle.distributed.collective._get_default_group()
-        with framework.no_grad():
-            _apply_collective_grads_eager(parameter_list, data_parallel_group)
+    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
+    logger.debug("dp start fuse allreduce gradients")
+    apply_func = _apply_collective_grads_eager if in_dygraph_mode(
+    ) else _apply_collective_grads
+    with framework.no_grad():
+        apply_func(parameter_list, data_parallel_group)


 def sharding_reduce_gradients(parameter_list, hcg):
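As the diff shows, the fix collapses the separate legacy-dygraph and eager-dygraph branches into a single path: the data-parallel group is always resolved from hcg (falling back to None), the gradient-reduction implementation is chosen once via in_dygraph_mode(), and the chosen function is called inside one no_grad scope, so the eager path no longer asserts that hcg is None. Below is a minimal, self-contained sketch of that dispatch pattern; the names no_grad, in_eager_mode, reduce_grads_legacy, and reduce_grads_eager are illustrative stand-ins, not PaddlePaddle APIs.

# Sketch of the "pick the implementation once, then apply it under a single
# no-grad scope" pattern this commit moves to. All helpers are dummies.
from contextlib import contextmanager


@contextmanager
def no_grad():
    """Stand-in for framework.no_grad(): a scope where autograd is disabled."""
    yield


def reduce_grads_legacy(parameter_list, group):
    # Stand-in for _apply_collective_grads (legacy dygraph path).
    print(f"legacy all-reduce over {len(parameter_list)} params, group={group}")


def reduce_grads_eager(parameter_list, group):
    # Stand-in for _apply_collective_grads_eager (eager dygraph path).
    print(f"eager all-reduce over {len(parameter_list)} params, group={group}")


def in_eager_mode():
    # Assumed runtime flag, analogous to in_dygraph_mode().
    return True


def fused_allreduce_gradients(parameter_list, group=None):
    # Choose the implementation once instead of duplicating the body per branch.
    apply_func = reduce_grads_eager if in_eager_mode() else reduce_grads_legacy
    with no_grad():
        apply_func(parameter_list, group)


fused_allreduce_gradients(["w", "b"], group=None)

Selecting apply_func up front keeps the group lookup and the no-grad scope in one place, so both execution modes share identical surrounding logic and only the collective-gradients call differs.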