未验证 提交 c1db7e32 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel] Fix amp bug in ModelParallel (#32579)

* fix amp bug

* fix misspelled variable name: word_size -> world_size
上级 9930a582
...@@ -67,10 +67,11 @@ class HybridParallelGradScaler: ...@@ -67,10 +67,11 @@ class HybridParallelGradScaler:
# allreduce_max found_inf in check_group # allreduce_max found_inf in check_group
if self._is_mp: if self._is_mp:
self._found_inf = paddle.cast(self._found_inf, dtype="int32") self._found_inf = paddle.cast(self._found_inf, dtype="int32")
# TODO(shenliang03) Since the minimize call in the optimizer is
# after the gradscaler, check_finite needs to synchronize global
# information. In the future, we should use check_group
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
self._found_inf, self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
op=paddle.distributed.ReduceOp.MAX,
group=self._hcg.get_check_parallel_group())
self._found_inf = paddle.cast(self._found_inf, dtype="bool") self._found_inf = paddle.cast(self._found_inf, dtype="bool")
def __getattr__(self, item): def __getattr__(self, item):
......
...@@ -77,7 +77,7 @@ class PipelineLayer(Layer): ...@@ -77,7 +77,7 @@ class PipelineLayer(Layer):
self.layers = layers self.layers = layers
self._loss_fn = loss_fn self._loss_fn = loss_fn
self._topo = topology self._topo = topology
word_size = dist.get_world_size() world_size = dist.get_world_size()
self.global_rank = dist.get_rank() self.global_rank = dist.get_rank()
if self._topo: if self._topo:
...@@ -88,11 +88,11 @@ class PipelineLayer(Layer): ...@@ -88,11 +88,11 @@ class PipelineLayer(Layer):
self._num_stages) self._num_stages)
else: else:
# construct default topology # construct default topology
if word_size % num_stages != 0: if world_size % num_stages != 0:
raise ValueError("should provide correct num_stages({}) " raise ValueError("should provide correct num_stages({}) "
"which can be divided by word_size({})".format( "which can be divided by world_size({})".
num_stages, word_size)) format(num_stages, world_size))
dp_num = word_size // num_stages dp_num = world_size // num_stages
self._topo = fleet.CommunicateTopology(["data", "pipe", "model"], self._topo = fleet.CommunicateTopology(["data", "pipe", "model"],
[dp_num, num_stages, 1]) [dp_num, num_stages, 1])
self._stage_id = self._topo.get_coord(self.global_rank).pipe self._stage_id = self._topo.get_coord(self.global_rank).pipe
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册