未验证 提交 1691dc7a 编写于 作者: S ShenLiang 提交者: GitHub

add update (#36017)

上级 82f255d0
...@@ -329,6 +329,7 @@ class PipelineParallel(MetaParallelBase): ...@@ -329,6 +329,7 @@ class PipelineParallel(MetaParallelBase):
def _optimizer_step(self): def _optimizer_step(self):
if self.scaler: if self.scaler:
self.scaler.step(self.optimizer) self.scaler.step(self.optimizer)
self.scaler.update()
else: else:
self.optimizer.step() self.optimizer.step()
......
...@@ -48,6 +48,7 @@ class TestMPClipGrad(TestDistMPTraning): ...@@ -48,6 +48,7 @@ class TestMPClipGrad(TestDistMPTraning):
scaled.backward() # do backward scaled.backward() # do backward
scaler.step(optimizer) # update parameters scaler.step(optimizer) # update parameters
scaler.update()
optimizer.clear_grad() optimizer.clear_grad()
return scaled return scaled
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册