未验证 提交 1691dc7a 编写于 作者: S ShenLiang 提交者: GitHub

add update (#36017)

上级 82f255d0
......@@ -329,6 +329,7 @@ class PipelineParallel(MetaParallelBase):
def _optimizer_step(self):
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
......
......@@ -48,6 +48,7 @@ class TestMPClipGrad(TestDistMPTraning):
scaled.backward() # do backward
scaler.step(optimizer) # update parameters
scaler.update()
optimizer.clear_grad()
return scaled
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册