diff --git a/tools/export_model.py b/tools/export_model.py index f587b2bb363e01ab4c0b2429fc95f243085649d1..37e7ba61314e8e74cc27f33b3e8349bbc8531b9a 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -47,23 +47,25 @@ def main(): char_num = len(getattr(post_process_class, 'character')) config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - init_model(config, model, logger) + # init_model(config, model, logger) model.eval() save_path = '{}/inference'.format(config['Global']['save_inference_dir']) if config['Architecture']['algorithm'] == "SRN": + max_text_length = config['Architecture']['Head']['max_text_length'] other_shape = [ paddle.static.InputSpec( shape=[None, 1, 64, 256], dtype='float32'), [ paddle.static.InputSpec( shape=[None, 256, 1], dtype="int64"), paddle.static.InputSpec( - shape=[None, 25, 1], - dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, 25, 25], dtype="int64"), + shape=[None, max_text_length, 1], dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, 25, 25], dtype="int64") + shape=[None, 8, max_text_length, max_text_length], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 8, max_text_length, max_text_length], + dtype="int64") ] ] model = to_static(model, input_spec=other_shape)