Commit f313a6d8 authored by tianyi1997, committed by HydrogenSulfate

Support training without amp

Parent 4553d22c
@@ -98,7 +98,6 @@ def train_epoch_metabin(engine, epoch_id, print_batch_step):
                            for key, value in mtest_loss_dict.items()}
         }
         # step lr (by iter)
-        # the last lr_sch is cyclic_lr
         for i in range(len(engine.lr_sch)):
             if not getattr(engine.lr_sch[i], "by_epoch", False):
                 engine.lr_sch[i].step()
@@ -117,7 +116,6 @@ def train_epoch_metabin(engine, epoch_id, print_batch_step):
         tic = time.time()
     # step lr(by epoch)
-    # the last lr_sch is cyclic_lr
     for i in range(len(engine.lr_sch)):
         if getattr(engine.lr_sch[i], "by_epoch", False) and \
                 type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
             engine.lr_sch[i].step()
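For context on the two loops above (not part of this diff's change): `engine.lr_sch` holds a list of schedulers, a `by_epoch` attribute decides whether each one is stepped per iteration or per epoch, and `ReduceOnPlateau` is excluded because its `step()` expects a metric. A hedged sketch of that dispatch; the scheduler instances and the manually set `by_epoch` flag are illustrative assumptions, not PaddleClas internals:

    import paddle

    iter_sch = paddle.optimizer.lr.CyclicLR(
        base_learning_rate=0.001,
        max_learning_rate=0.01,
        step_size_up=100)  # intended to step every iteration
    epoch_sch = paddle.optimizer.lr.StepDecay(learning_rate=0.01, step_size=30)
    epoch_sch.by_epoch = True  # mark as a per-epoch scheduler

    lr_sch = [iter_sch, epoch_sch]

    # inside the batch loop: step only per-iteration schedulers
    for sch in lr_sch:
        if not getattr(sch, "by_epoch", False):
            sch.step()

    # at the end of an epoch: step only per-epoch schedulers
    for sch in lr_sch:
        if getattr(sch, "by_epoch", False):
            sch.step()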
@@ -191,10 +189,16 @@ def get_meta_data(meta_dataloader_iter, num_domain):
 def forward(engine, batch, loss_func):
     batch_info = defaultdict()
     batch_info = {"label": batch[1], "domain": batch[2]}
-    amp_level = engine.config["AMP"].get("level", "O1").upper()
-    with paddle.amp.auto_cast(
-            custom_black_list={"flatten_contiguous_range", "greater_than"},
-            level=amp_level):
-        out = engine.model(batch[0], batch[1])
-        loss_dict = loss_func(out, batch_info)
+    if engine.amp:
+        amp_level = engine.config["AMP"].get("level", "O1").upper()
+        with paddle.amp.auto_cast(
+                custom_black_list={
+                    "flatten_contiguous_range", "greater_than"
+                },
+                level=amp_level):
+            out = engine.model(batch[0], batch[1])
+            loss_dict = loss_func(out, batch_info)
+    else:
+        out = engine.model(batch[0], batch[1])
+        loss_dict = loss_func(out, batch_info)
     return out, loss_dict
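In words: when `engine.amp` is set, the forward pass runs under `paddle.amp.auto_cast`, which executes most ops in FP16 at level O1 while keeping the black-listed ops (`flatten_contiguous_range`, `greater_than`) in FP32; otherwise the model runs a plain FP32 forward. A minimal runnable sketch of the two paths, where the model and the `use_amp` flag are stand-ins rather than PaddleClas objects:

    import paddle

    model = paddle.nn.Linear(8, 2)   # stand-in for engine.model
    x = paddle.randn([4, 8])
    use_amp = True                   # mirrors the engine.amp flag this commit adds

    if use_amp:
        # Black-listed ops stay in FP32 even inside auto_cast.
        with paddle.amp.auto_cast(
                custom_black_list={"flatten_contiguous_range", "greater_than"},
                level="O1"):
            out = model(x)
    else:
        out = model(x)  # plain FP32 forward when AMP is disabled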
@@ -202,9 +206,13 @@ def forward(engine, batch, loss_func):
 
 def backward(engine, loss, optimizer):
     optimizer.clear_grad()
-    scaled = engine.scaler.scale(loss)
-    scaled.backward()
-    engine.scaler.minimize(optimizer, scaled)
+    if engine.amp:
+        scaled = engine.scaler.scale(loss)
+        scaled.backward()
+        engine.scaler.minimize(optimizer, scaled)
+    else:
+        loss.backward()
+        optimizer.step()
     for name, layer in engine.model.backbone.named_sublayers():
         if "gate" == name.split('.')[-1]:
             layer.clip_gate()
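The backward path mirrors the forward change: under AMP the loss is multiplied by the GradScaler's loss scale before `backward()` so that small FP16 gradients do not underflow, and `scaler.minimize` unscales the gradients before stepping the optimizer; without AMP it is an ordinary `backward()` plus `optimizer.step()`. A minimal sketch with stand-in objects, not the PaddleClas engine:

    import paddle

    model = paddle.nn.Linear(8, 2)
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
    use_amp = True  # mirrors engine.amp

    loss = model(paddle.randn([4, 8])).mean()
    optimizer.clear_grad()
    if use_amp:
        scaled = scaler.scale(loss)         # scale loss to avoid FP16 underflow
        scaled.backward()
        scaler.minimize(optimizer, scaled)  # unscale grads, then step
    else:
        loss.backward()
        optimizer.step()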