提交 5d4144de 编写于 作者: G gong chen

bugfix(side effect): fix adding wrong control depend between AllReduce and GetStatus.

上级 c9fba7f0
......@@ -370,7 +370,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
self.grad_reducer = F.identity
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
......@@ -428,9 +428,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
......
......@@ -220,7 +220,7 @@ class TrainOneStepWithLossScaleCell(Cell):
self.depend_parameter_use = ControlDepend(depend_mode=1)
self.allreduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
self.grad_reducer = None
self.grad_reducer = F.identity
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
if self.reducer_flag:
mean = _get_mirror_mean()
......@@ -250,9 +250,8 @@ class TrainOneStepWithLossScaleCell(Cell):
scaling_sens = sens
grads = self.grad(self.network, weights)(data, label, F.cast(scaling_sens, F.dtype(loss)))
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# get the overflow buffer
if not self.gpu_target:
self.get_status(init)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册