From ef0880b9ae498c5385073aec78330d0aa1c2bae5 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Sat, 28 Nov 2020 22:56:10 +0800 Subject: [PATCH] delete save_inference_mode fun, because the dev paddle has support export crnn model --- tools/program.py | 16 ---------------- tools/train.py | 1 - 2 files changed, 17 deletions(-) diff --git a/tools/program.py b/tools/program.py index c9eadea0..8e84d30e 100755 --- a/tools/program.py +++ b/tools/program.py @@ -333,22 +333,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class): return metirc -def save_inference_mode(model, config, logger): - if dist.get_rank() == 0: - model.eval() - print('infer') - save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'], - config['Architecture']['model_type']) - if config['Architecture']['model_type'] == 'rec': - input_shape = [None, 3, 32, None] - jit_model = paddle.jit.to_static( - model, input_spec=[paddle.static.InputSpec(input_shape)]) - paddle.jit.save(jit_model, save_path) - logger.info('inference model save to {}'.format(save_path)) - - model.train() - - def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) diff --git a/tools/train.py b/tools/train.py index 1cf644e6..6e44c598 100755 --- a/tools/train.py +++ b/tools/train.py @@ -89,7 +89,6 @@ def main(config, device, logger, vdl_writer): program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) - program.save_inference_mode(model, config, logger) def test_reader(config, device, logger): -- GitLab