diff --git a/tools/train.py b/tools/train.py index a5f765f066bfefcf419cb78518a4b58d870c326c..3456e2aec4dd6f50547bbda927971a6b309cc663 100644 --- a/tools/train.py +++ b/tools/train.py @@ -134,8 +134,7 @@ def main(args): model_path = os.path.join(config.model_save_dir, config.ARCHITECTURE["name"]) - save_model(train_prog, model_path, - "best_model_in_epoch_" + str(epoch_id)) + save_model(train_prog, model_path, "best_model") # 3. save the persistable model if epoch_id % config.save_interval == 0: