From 89e031f0e7491f66e7fa783fab14ba9e302e6ca2 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Mon, 9 Nov 2020 16:40:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0infer=5Fdet=E5=92=8Cinfer=5Fr?= =?UTF-8?q?ec?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/infer_det.py | 23 +++++++---------------- tools/infer_rec.py | 21 ++++++--------------- 2 files changed, 13 insertions(+), 31 deletions(-) diff --git a/tools/infer_det.py b/tools/infer_det.py index 8e6b6b21..d1b1b752 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -29,12 +29,11 @@ import cv2 import json import paddle -from ppocr.utils.logging import get_logger from ppocr.data import create_operators, transform -from ppocr.modeling import build_model +from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.utils.save_load import init_model -from ppocr.utils.utility import print_dict, get_image_file_list +from ppocr.utils.utility import get_image_file_list import tools.program as program @@ -67,11 +66,11 @@ def main(): # create data ops transforms = [] - for op in config['EVAL']['dataset']['transforms']: + for op in config['Eval']['dataset']['transforms']: op_name = list(op)[0] if 'Label' in op_name: continue - elif op_name == 'keepKeys': + elif op_name == 'KeepKeys': op[op_name]['keep_keys'] = ['image', 'shape'] transforms.append(op) @@ -92,8 +91,7 @@ def main(): images = np.expand_dims(batch[0], axis=0) shape_list = np.expand_dims(batch[1], axis=0) - images = paddle.to_variable(images) - print(images.shape) + images = paddle.to_tensor(images) preds = model(images) post_result = post_process_class(preds, shape_list) boxes = post_result[0]['points'] @@ -109,14 +107,7 @@ def main(): draw_det_res(boxes, config, src_img, file) logger.info("success!") - # save inference model - # paddle.jit.save(model, 'output/model') - if __name__ == '__main__': - place, config = program.preprocess() - paddle.disable_static(place) - - logger = get_logger() - print_dict(config, logger) - main() + config, device, logger, vdl_writer = program.preprocess() + main() \ No newline at end of file diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 239d2dcb..e3e85b5d 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -27,12 +27,11 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) import paddle -from ppocr.utils.logging import get_logger from ppocr.data import create_operators, transform -from ppocr.modeling import build_model +from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.utils.save_load import init_model -from ppocr.utils.utility import print_dict, get_image_file_list +from ppocr.utils.utility import get_image_file_list import tools.program as program @@ -54,13 +53,13 @@ def main(): # create data ops transforms = [] - for op in config['EVAL']['dataset']['transforms']: + for op in config['Eval']['dataset']['transforms']: op_name = list(op)[0] if 'Label' in op_name: continue elif op_name in ['RecResizeImg']: op[op_name]['infer_mode'] = True - elif op_name == 'keepKeys': + elif op_name == 'KeepKeys': op[op_name]['keep_keys'] = ['image'] transforms.append(op) global_config['infer_mode'] = True @@ -75,22 +74,14 @@ def main(): batch = transform(data, ops) images = np.expand_dims(batch[0], axis=0) - images = paddle.to_variable(images) + images = paddle.to_tensor(images) preds = model(images) post_result = post_process_class(preds) for rec_reuslt in post_result: logger.info('\t result: {}'.format(rec_reuslt)) logger.info("success!") - # save inference model - # currently, paddle.jit.to_static not support rnn - # paddle.jit.save(model, 'output/rec/model') - if __name__ == '__main__': - place, config = program.preprocess() - paddle.disable_static(place) - - logger = get_logger() - print_dict(config, logger) + config, device, logger, vdl_writer = program.preprocess() main() -- GitLab