Unverified commit b3c28179, authored by Meiyim, committed by GitHub

[fix] adamw : `exclude from weight_decay` (#614)

* [fix] adamw : `exclude from weight_decay`

* [fix] fix demo `finetune_classifier.py`
Co-authored-by: chenxuyi <work@yq01-qianmo-com-255-129-11.yq01.baidu.com>
Parent 738e3688
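Background for the fix: in Paddle's `AdamW`, `apply_decay_param_fun` is called with each parameter's name and must return `True` for parameters that should receive weight decay. The old lambda returned the exclusion-regex match directly, so decay was applied only to the parameters that were meant to be excluded (typically LayerNorm scales and bias terms). Negating the match restores the intended behaviour. Below is a minimal sketch of the corrected setup, assuming `import paddle as P` as in the demos; the exclusion regex and the tiny stand-in model are illustrative, not copied from this diff.

```python
# Minimal sketch, not the repo's code: assumes `import paddle as P` as in the
# ERNIE demos; the exclusion regex below is illustrative.
import re

import paddle as P

# Names matching this pattern (LayerNorm scales/biases, bias terms) should
# NOT receive weight decay.
param_name_to_exclue_from_weight_decay = re.compile(
    r'.*layer_norm_scale|.*layer_norm_bias|.*b_0')

model = P.nn.Linear(4, 4)  # stand-in for the ERNIE model
g_clip = P.nn.ClipGradByGlobalNorm(1.0)

opt = P.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    weight_decay=0.01,
    # The predicate receives a parameter name and must return True when decay
    # SHOULD be applied, hence the `not` in front of the exclusion match.
    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
    grad_clip=g_clip)
```

The hunks below apply this one-token change, plus the `Adam` to `AdamW` switch, across the demo scripts.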
@@ -177,15 +177,15 @@ if args.use_lr_decay:
         lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
 else:
     lr_scheduler = None
-    opt = P.optimizer.Adam(
+    opt = P.optimizer.AdamW(
         args.lr,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
 scaler = P.amp.GradScaler(enable=args.use_amp)
@@ -209,7 +209,8 @@ with LogWriter(
 lr_scheduler and lr_scheduler.step()
 if step % 10 == 0:
-    _lr = lr_scheduler.get_lr()
+    _lr = lr_scheduler.get_lr(
+    ) if args.use_lr_decay else args.lr
     if args.use_amp:
         _l = (loss / scaler._scale).numpy()
         msg = '[step-%d] train loss %.5f lr %.3e scaling %.3e' % (
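The hunk above is the companion demo fix from the commit message: with `--use_lr_decay` disabled, `lr_scheduler` is `None`, so the periodic logging can no longer call `get_lr()` and instead falls back to the constant `args.lr`. A one-line sketch of the guard, using the demo's names:

```python
# lr_scheduler is None when LR decay is disabled, so report the constant LR instead.
_lr = lr_scheduler.get_lr() if args.use_lr_decay else args.lr
```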
@@ -144,7 +144,7 @@ lr_scheduler = P.optimizer.lr.LambdaDecay(
 opt = P.optimizer.AdamW(
     learning_rate=lr_scheduler,
     parameters=model.parameters(),
-    apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
     weight_decay=args.wd,
     grad_clip=g_clip)
 scaler = P.amp.GradScaler(enable=args.use_amp)
@@ -210,7 +210,7 @@ opt = P.optimizer.AdamW(
     lr_scheduler,
     parameters=model.parameters(),
     weight_decay=args.wd,
-    apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
     grad_clip=g_clip)
 scaler = P.amp.GradScaler(enable=args.use_amp)
@@ -126,7 +126,7 @@ if not args.eval:
         lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
     scaler = P.amp.GradScaler(enable=args.use_amp)
     with LogWriter(logdir=str(create_if_not_exists(args.save_dir /
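The remaining hunks repeat the same predicate fix in the other demo scripts. Note that the non-decay branch in the first hunk also moves from `P.optimizer.Adam` to `P.optimizer.AdamW`: `apply_decay_param_fun` is an `AdamW` argument, and `AdamW` applies decoupled weight decay rather than coupled L2 regularization. Continuing the sketch above, a quick way to check which parameters the corrected predicate would decay (parameter names are whatever Paddle auto-generates):

```python
# List which parameters would be decayed under the corrected predicate.
# Only names NOT matched by the exclusion regex are decayed; for the stand-in
# Linear layer, the bias name ends in "b_0" and is therefore skipped.
for p in model.parameters():
    decayed = not param_name_to_exclue_from_weight_decay.match(p.name)
    print(p.name, 'decay' if decayed else 'no decay')
```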