diff --git a/tools/program.py b/tools/program.py index 5963016b66d4d254e180b0bc9ad098e49518969f..6456aad5dcda764816e5af7b5becf30cc7192af4 100755 --- a/tools/program.py +++ b/tools/program.py @@ -159,7 +159,8 @@ def train(config, eval_class, pre_best_model_dict, logger, - vdl_writer=None): + vdl_writer=None, + scaler=None): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) log_smooth_window = config['Global']['log_smooth_window'] diff --git a/tools/train.py b/tools/train.py index b34ac9790e4ff776a79e5e9d556d2dd0e020911d..49e44112c2939f4cf0fde9aca773d508d2f95736 100755 --- a/tools/train.py +++ b/tools/train.py @@ -122,7 +122,7 @@ def main(config, device, logger, vdl_writer): # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer) + eval_class, pre_best_model_dict, logger, vdl_writer, scaler) def test_reader(config, device, logger):