diff --git a/tools/train.py b/tools/train.py index 2091ff48b4b83c1e3955d0b9600c60815d4d99ec..05d295aa99718c25b94a123c23d08c2904fe8c6a 100755 --- a/tools/train.py +++ b/tools/train.py @@ -97,8 +97,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - #pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) - pre_best_model_dict = {} + pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format(