From 4d44b23043063202852d157deb4206c2f91dcf29 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Thu, 12 Nov 2020 12:06:46 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AF=86=E5=88=AB=E6=A8=A1=E5=9E=8B=E5=AF=BC?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/program.py | 16 +++++++++++++++- tools/train.py | 7 ++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/tools/program.py b/tools/program.py index 8bae0fd5..c2b9306c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -323,6 +323,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class): return metirc +def save_inference_mode(model, config, logger): + model.eval() + save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'], + config['Architecture']['model_type']) + if config['Architecture']['model_type'] == 'rec': + input_shape = [None, 3, 32, None] + jit_model = paddle.jit.to_static( + model, input_spec=[paddle.static.InputSpec(input_shape)]) + paddle.jit.save(jit_model, save_path) + logger.info('inference model save to {}'.format(save_path)) + + model.train() + + def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -334,7 +348,7 @@ def preprocess(): alg = config['Architecture']['algorithm'] assert alg in [ - 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN' + 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' diff --git a/tools/train.py b/tools/train.py index c1622379..1cf644e6 100755 --- a/tools/train.py +++ b/tools/train.py @@ -89,6 +89,7 @@ def main(config, device, logger, vdl_writer): program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) + program.save_inference_mode(model, config, logger) def test_reader(config, device, logger): @@ -102,8 +103,8 @@ def test_reader(config, device, logger): if count % 1 == 0: batch_time = time.time() - starttime starttime = time.time() - logger.info("reader: {}, {}, {}".format(count, - len(data), batch_time)) + logger.info("reader: {}, {}, {}".format( + count, len(data[0]), batch_time)) except Exception as e: logger.info(e) logger.info("finish reader: {}, Success!".format(count)) @@ -112,4 +113,4 @@ def test_reader(config, device, logger): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess() main(config, device, logger, vdl_writer) -# test_reader(config, device, logger) + # test_reader(config, device, logger) -- GitLab