提交 a1e5f270 编写于 作者: Z zxcd

Move scaler.unscale_ below the grad_clip check (unscale gradients only when gradient clipping is applied, after gradient accumulation completes).

上级 7399d560
...@@ -82,7 +82,6 @@ class U2Trainer(Trainer): ...@@ -82,7 +82,6 @@ class U2Trainer(Trainer):
with context(): with context():
if scaler: if scaler:
scaler.scale(loss).backward() scaler.scale(loss).backward()
scaler.unscale_(self.optimizer)
else: else:
loss.backward() loss.backward()
layer_tools.print_grads(self.model, print_func=None) layer_tools.print_grads(self.model, print_func=None)
...@@ -91,6 +90,8 @@ class U2Trainer(Trainer): ...@@ -91,6 +90,8 @@ class U2Trainer(Trainer):
if (batch_index + 1) % train_conf.accum_grad == 0: if (batch_index + 1) % train_conf.accum_grad == 0:
# do global grad clip # do global grad clip
if train_conf.global_grad_clip != 0: if train_conf.global_grad_clip != 0:
if scaler:
scaler.unscale_(self.optimizer)
# need paddlepaddle==develop or paddlepaddle>=2.5 # need paddlepaddle==develop or paddlepaddle>=2.5
clip_grad_norm_(self.model.parameters(), clip_grad_norm_(self.model.parameters(),
train_conf.global_grad_clip) train_conf.global_grad_clip)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册