From 8bb9fb7e3d18a23f2e86373c58406a777b04c821 Mon Sep 17 00:00:00 2001 From: stephon Date: Fri, 15 Oct 2021 08:34:27 +0000 Subject: [PATCH] fix some error --- tools/program.py | 3 ++- tools/train.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/program.py b/tools/program.py index 5963016b..6456aad5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -159,7 +159,8 @@ def train(config, eval_class, pre_best_model_dict, logger, - vdl_writer=None): + vdl_writer=None, + scaler=None): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) log_smooth_window = config['Global']['log_smooth_window'] diff --git a/tools/train.py b/tools/train.py index b34ac979..49e44112 100755 --- a/tools/train.py +++ b/tools/train.py @@ -122,7 +122,7 @@ def main(config, device, logger, vdl_writer): # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer) + eval_class, pre_best_model_dict, logger, vdl_writer, scaler) def test_reader(config, device, logger): -- GitLab