Unverified commit 82367481 authored by xiemoyuan, committed by GitHub

Fix DGU bug (#5272)

Parent 09490523
@@ -128,22 +128,13 @@ def train(args, model, train_data_loader, dev_data_loader, metric, rank):
         max_train_steps=max_train_steps)
     lr_scheduler = LambdaDecay(args.learning_rate, factor_fn)
-    optimizer = AdamW(
-        learning_rate=lr_scheduler,
-        parameters=model.parameters(),
-        weight_decay=args.weight_decay,
-        apply_decay_param_fun=lambda x: x in [
-            params.name for params in model.parameters()
-            if not any(nd in params.name for nd in ['bias', 'norm'])],
-        grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm)
-    )
     optimizer = paddle.optimizer.AdamW(
         learning_rate=lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.weight_decay,
         apply_decay_param_fun=lambda x: x in [
             p.name for n, p in model.named_parameters()
-            if not any(nd in n for nd in ["bias", "norm"])
-        ])
+            if not any(nd in n for nd in ["bias", "norm"])],
+        grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm))
     loss_fn = DGULossFunction(args.task_name)
     load_ckpt(args, model, optimizer)
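
Taken together, the hunk leaves a single `paddle.optimizer.AdamW(...)` construction in which the weight-decay exclusion list is built from `model.named_parameters()`, so the `"bias"`/`"norm"` check runs against structured sublayer names, and `grad_clip` is passed so global-norm clipping is kept. The sketch below isolates that pattern; the toy model and the hyperparameter values are assumptions made for illustration, not values from the commit.

```python
# Minimal sketch of the weight-decay exclusion pattern used above.
# The model, learning rate, weight decay, and clip norm are assumed values.
import paddle
import paddle.nn as nn

# Hypothetical toy model; naming the sublayers ("fc", "norm") makes the
# structured names from named_parameters() contain "bias" / "norm".
model = nn.Sequential(("fc", nn.Linear(16, 16)), ("norm", nn.LayerNorm(16)))

# Parameters whose structured name does NOT contain "bias" or "norm"
# are the ones that should receive weight decay.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]

optimizer = paddle.optimizer.AdamW(
    learning_rate=5e-5,                       # assumed value
    parameters=model.parameters(),
    weight_decay=0.01,                        # assumed value
    # apply_decay_param_fun receives a parameter name and returns True
    # when that parameter should be decayed.
    apply_decay_param_fun=lambda x: x in decay_params,
    grad_clip=nn.ClipGradByGlobalNorm(1.0))   # assumed clip norm
```

Unlike the diff, which rebuilds the name list inside the lambda on every call, the list here is computed once up front; the two are equivalent in effect, but precomputing avoids re-walking `named_parameters()` each time `apply_decay_param_fun` is evaluated.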