提交 8bb9fb7e 编写于 作者: S stephon

fix some error

上级 2005cc3e
......@@ -159,7 +159,8 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
vdl_writer=None):
vdl_writer=None,
scaler=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
log_smooth_window = config['Global']['log_smooth_window']
......
......@@ -122,7 +122,7 @@ def main(config, device, logger, vdl_writer):
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer)
eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
def test_reader(config, device, logger):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册