diff --git a/tools/program.py b/tools/program.py index c9eadea0b48643c89fbe612610a3cf2bd15c9eb5..8e84d30e64fa19a99fea205bca2d08c490b6fd7e 100755 --- a/tools/program.py +++ b/tools/program.py @@ -333,22 +333,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class): return metirc -def save_inference_mode(model, config, logger): - if dist.get_rank() == 0: - model.eval() - print('infer') - save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'], - config['Architecture']['model_type']) - if config['Architecture']['model_type'] == 'rec': - input_shape = [None, 3, 32, None] - jit_model = paddle.jit.to_static( - model, input_spec=[paddle.static.InputSpec(input_shape)]) - paddle.jit.save(jit_model, save_path) - logger.info('inference model save to {}'.format(save_path)) - - model.train() - - def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) diff --git a/tools/train.py b/tools/train.py index 1cf644e6fd4b61d7925c6d9dda79855c7a72e886..6e44c5982ec5595c9202d83b14c058a7579c6a27 100755 --- a/tools/train.py +++ b/tools/train.py @@ -89,7 +89,6 @@ def main(config, device, logger, vdl_writer): program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) - program.save_inference_mode(model, config, logger) def test_reader(config, device, logger):