From b3c28179aabf2d00ce8ec5d4f321e66da1e6a11b Mon Sep 17 00:00:00 2001
From: Meiyim
Date: Wed, 13 Jan 2021 19:49:59 +0800
Subject: [PATCH] [fix] adamw : `exclude from weight_decay` (#614)

* [fix] adamw : `exclude from weight_decay`

* [fix] fix demo `finetune_classifier.py`

Co-authored-by: chenxuyi
---
 demo/finetune_classifier.py             | 9 +++++----
 demo/finetune_classifier_distributed.py | 2 +-
 demo/finetune_ner.py                    | 2 +-
 demo/finetune_sentiment_analysis.py     | 2 +-
 4 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/demo/finetune_classifier.py b/demo/finetune_classifier.py
index 4e9540d..ce14d3e 100644
--- a/demo/finetune_classifier.py
+++ b/demo/finetune_classifier.py
@@ -177,15 +177,15 @@ if args.use_lr_decay:
         lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
 else:
     lr_scheduler = None
-    opt = P.optimizer.Adam(
+    opt = P.optimizer.AdamW(
         args.lr,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
 
 scaler = P.amp.GradScaler(enable=args.use_amp)
@@ -209,7 +209,8 @@ with LogWriter(
             lr_scheduler and lr_scheduler.step()
 
             if step % 10 == 0:
-                _lr = lr_scheduler.get_lr()
+                _lr = lr_scheduler.get_lr(
+                ) if args.use_lr_decay else args.lr
                 if args.use_amp:
                     _l = (loss / scaler._scale).numpy()
                     msg = '[step-%d] train loss %.5f lr %.3e scaling %.3e' % (
diff --git a/demo/finetune_classifier_distributed.py b/demo/finetune_classifier_distributed.py
index d1df867..d4b1195 100644
--- a/demo/finetune_classifier_distributed.py
+++ b/demo/finetune_classifier_distributed.py
@@ -144,7 +144,7 @@ lr_scheduler = P.optimizer.lr.LambdaDecay(
 opt = P.optimizer.AdamW(
     learning_rate=lr_scheduler,
     parameters=model.parameters(),
-    apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
     weight_decay=args.wd,
     grad_clip=g_clip)
 scaler = P.amp.GradScaler(enable=args.use_amp)
diff --git a/demo/finetune_ner.py b/demo/finetune_ner.py
index 7489f16..6929afd 100644
--- a/demo/finetune_ner.py
+++ b/demo/finetune_ner.py
@@ -210,7 +210,7 @@ opt = P.optimizer.AdamW(
     lr_scheduler,
     parameters=model.parameters(),
     weight_decay=args.wd,
-    apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
     grad_clip=g_clip)
 
 scaler = P.amp.GradScaler(enable=args.use_amp)
diff --git a/demo/finetune_sentiment_analysis.py b/demo/finetune_sentiment_analysis.py
index 015d29d..16087fa 100644
--- a/demo/finetune_sentiment_analysis.py
+++ b/demo/finetune_sentiment_analysis.py
@@ -126,7 +126,7 @@ if not args.eval:
         lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
     scaler = P.amp.GradScaler(enable=args.use_amp)
     with LogWriter(logdir=str(create_if_not_exists(args.save_dir /
--
GitLab
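
A minimal sketch of the behavior this fix depends on, assuming Paddle's documented `paddle.optimizer.AdamW` API: `apply_decay_param_fun` is called with each parameter's name and should return True only when weight decay is to be applied, so names matching the exclusion regex must produce False, which is why the `not` was added. The regex pattern and the toy model below are illustrative placeholders, not copied from the demos.

# Illustrative sketch only: shows why the lambda must negate the regex match.
# The exclusion pattern and the Linear layer are assumptions for this example,
# not the exact regex compiled in the demos.
import re

import paddle as P

param_name_to_exclue_from_weight_decay = re.compile(
    r'.*layer_norm_scale|.*layer_norm_bias|.*b_0')

model = P.nn.Linear(4, 4)
opt = P.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    weight_decay=0.01,
    # apply_decay_param_fun: True  -> weight decay IS applied to this parameter
    #                        False -> parameter is excluded from weight decay
    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
    grad_clip=P.nn.ClipGradByGlobalNorm(1.0))

# Without the `not`, the lambda returned a truthy Match object exactly for the
# parameters meant to be excluded, i.e. decay was applied to the wrong set.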