diff --git a/tools/program.py b/tools/program.py index 8bae0fd5d16f4b17520c1162f6cd9bd54f032a73..c2b9306c305c5dcc817acb6829fd562be10b8d51 100755 --- a/tools/program.py +++ b/tools/program.py @@ -323,6 +323,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class): return metirc +def save_inference_mode(model, config, logger): + model.eval() + save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'], + config['Architecture']['model_type']) + if config['Architecture']['model_type'] == 'rec': + input_shape = [None, 3, 32, None] + jit_model = paddle.jit.to_static( + model, input_spec=[paddle.static.InputSpec(input_shape)]) + paddle.jit.save(jit_model, save_path) + logger.info('inference model save to {}'.format(save_path)) + + model.train() + + def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -334,7 +348,7 @@ def preprocess(): alg = config['Architecture']['algorithm'] assert alg in [ - 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN' + 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' diff --git a/tools/train.py b/tools/train.py index c16223796f8785d1928ba1df57602967e0518e31..1cf644e6fd4b61d7925c6d9dda79855c7a72e886 100755 --- a/tools/train.py +++ b/tools/train.py @@ -89,6 +89,7 @@ def main(config, device, logger, vdl_writer): program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) + program.save_inference_mode(model, config, logger) def test_reader(config, device, logger): @@ -102,8 +103,8 @@ def test_reader(config, device, logger): if count % 1 == 0: batch_time = time.time() - starttime starttime = time.time() - logger.info("reader: {}, {}, {}".format(count, - len(data), batch_time)) + logger.info("reader: {}, {}, {}".format( + count, len(data[0]), batch_time)) except Exception as e: logger.info(e) logger.info("finish reader: {}, Success!".format(count)) @@ -112,4 +113,4 @@ def test_reader(config, device, logger): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess() main(config, device, logger, vdl_writer) -# test_reader(config, device, logger) + # test_reader(config, device, logger)