提交 162f013e 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: minimize() dont support parameter_list of type dict

there are diffs that step()+update() and minimize().
this will be fixed in https://github.com/PaddlePaddle/Paddle/pull/53773.
上级 8b218b01
......@@ -61,7 +61,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
if (iter_id + 1) % engine.update_freq == 0:
for i in range(len(engine.optimizer)):
# optimizer.step() with auto amp
engine.scaler.minimize(engine.optimizer[i], scaled)
engine.scaler.step(engine.optimizer[i])
engine.scaler.update()
if (iter_id + 1) % engine.update_freq == 0:
# clear grad
......
......@@ -75,7 +75,9 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step):
scaled.backward()
for i in range(len(engine.optimizer)):
engine.scaler.minimize(engine.optimizer[i], scaled)
# optimizer.step() with auto amp
engine.scaler.step(engine.optimizer[i])
engine.scaler.update()
# step lr(by step)
for i in range(len(engine.lr_sch)):
......
......@@ -201,7 +201,11 @@ def backward(engine, loss, optimizer):
optimizer.clear_grad()
scaled = engine.scaler.scale(loss)
scaled.backward()
engine.scaler.minimize(optimizer, scaled)
# optimizer.step() with auto amp
engine.scaler.step(optimizer)
engine.scaler.update()
for name, layer in engine.model.backbone.named_sublayers():
if "gate" == name.split('.')[-1]:
layer.clip_gate()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册