提交 8bb9fb7e 编写于 作者: S stephon

fix some error

上级 2005cc3e
...@@ -159,7 +159,8 @@ def train(config, ...@@ -159,7 +159,8 @@ def train(config,
eval_class, eval_class,
pre_best_model_dict, pre_best_model_dict,
logger, logger,
vdl_writer=None): vdl_writer=None,
scaler=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train', cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False) False)
log_smooth_window = config['Global']['log_smooth_window'] log_smooth_window = config['Global']['log_smooth_window']
......
...@@ -122,7 +122,7 @@ def main(config, device, logger, vdl_writer): ...@@ -122,7 +122,7 @@ def main(config, device, logger, vdl_writer):
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer) eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
def test_reader(config, device, logger): def test_reader(config, device, logger):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册