Unverified commit b3c28179, authored by Meiyim, committed by GitHub

[fix] adamw : `exclude from weight_decay` (#614)

* [fix] adamw : `exclude from weight_decay`

* [fix] fix demo `finetune_classifier.py`
Co-authored-by: chenxuyi <work@yq01-qianmo-com-255-129-11.yq01.baidu.com>
Parent 738e3688
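Paddle's `AdamW` applies weight decay only to parameters whose names make `apply_decay_param_fun` return True, so an exclusion regex has to be negated: without the `not`, decay was applied only to the parameters that were meant to be exempt and skipped for every other weight. Below is a minimal sketch of the corrected pattern; the regex, model, and hyperparameters are illustrative stand-ins, not the repo's exact code.

```python
import re
import paddle as P

# Illustrative no-decay pattern (LayerNorm scales/biases and bias tensors),
# standing in for the repo's `param_name_to_exclue_from_weight_decay`.
param_name_to_exclue_from_weight_decay = re.compile(
    r'.*layer_norm_scale|.*layer_norm_bias|.*b_0')

model = P.nn.Linear(768, 2)              # stand-in for the ERNIE classifier
g_clip = P.nn.ClipGradByGlobalNorm(1.0)  # illustrative clip value

opt = P.optimizer.AdamW(
    learning_rate=3e-5,
    parameters=model.parameters(),
    weight_decay=0.01,
    # Decay is applied only where this function returns True, so excluded
    # parameter names must map to False -- hence the `not`.
    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
    grad_clip=g_clip)
```

The same one-character negation is applied to every `AdamW` call in the hunks below; the no-decay branch of `finetune_classifier.py` is also switched from plain `Adam` to `AdamW` so that decoupled weight decay and the exclusion function take effect regardless of whether LR decay is enabled.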
@@ -177,15 +177,15 @@ if args.use_lr_decay:
         lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
 else:
     lr_scheduler = None
-    opt = P.optimizer.Adam(
+    opt = P.optimizer.AdamW(
         args.lr,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
 
 scaler = P.amp.GradScaler(enable=args.use_amp)
@@ -209,7 +209,8 @@ with LogWriter(
                 lr_scheduler and lr_scheduler.step()
 
                 if step % 10 == 0:
-                    _lr = lr_scheduler.get_lr()
+                    _lr = lr_scheduler.get_lr(
+                    ) if args.use_lr_decay else args.lr
                     if args.use_amp:
                         _l = (loss / scaler._scale).numpy()
                         msg = '[step-%d] train loss %.5f lr %.3e scaling %.3e' % (
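The second hunk fixes the logging path of the same script: when `--use_lr_decay` is off, `lr_scheduler` is `None`, so `lr_scheduler.get_lr()` would raise an `AttributeError`; the conditional expression falls back to the constant `args.lr`. A standalone sketch of the guard, assuming `args.use_lr_decay` and `args.lr` mirror the demo's command-line flags:

```python
def current_lr(lr_scheduler, args):
    # lr_scheduler is None when LR decay is disabled, so guard before get_lr().
    return lr_scheduler.get_lr() if args.use_lr_decay else args.lr
```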
@@ -144,7 +144,7 @@ lr_scheduler = P.optimizer.lr.LambdaDecay(
 opt = P.optimizer.AdamW(
     learning_rate=lr_scheduler,
     parameters=model.parameters(),
-    apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
     weight_decay=args.wd,
     grad_clip=g_clip)
 scaler = P.amp.GradScaler(enable=args.use_amp)
@@ -210,7 +210,7 @@ opt = P.optimizer.AdamW(
     lr_scheduler,
     parameters=model.parameters(),
     weight_decay=args.wd,
-    apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
     grad_clip=g_clip)
 
 scaler = P.amp.GradScaler(enable=args.use_amp)
@@ -126,7 +126,7 @@ if not args.eval:
         lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
     scaler = P.amp.GradScaler(enable=args.use_amp)
     with LogWriter(logdir=str(create_if_not_exists(args.save_dir /