提交 89e031f0 编写于 作者: W WenmuZhou

添加infer_det和infer_rec

上级 b28ea0a9
......@@ -29,12 +29,11 @@ import cv2
import json
import paddle
from ppocr.utils.logging import get_logger
from ppocr.data import create_operators, transform
from ppocr.modeling import build_model
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import print_dict, get_image_file_list
from ppocr.utils.utility import get_image_file_list
import tools.program as program
......@@ -67,11 +66,11 @@ def main():
# create data ops
transforms = []
for op in config['EVAL']['dataset']['transforms']:
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
continue
elif op_name == 'keepKeys':
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image', 'shape']
transforms.append(op)
......@@ -92,8 +91,7 @@ def main():
images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
images = paddle.to_variable(images)
print(images.shape)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
boxes = post_result[0]['points']
......@@ -109,14 +107,7 @@ def main():
draw_det_res(boxes, config, src_img, file)
logger.info("success!")
# save inference model
# paddle.jit.save(model, 'output/model')
if __name__ == '__main__':
place, config = program.preprocess()
paddle.disable_static(place)
logger = get_logger()
print_dict(config, logger)
config, device, logger, vdl_writer = program.preprocess()
main()
\ No newline at end of file
......@@ -27,12 +27,11 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
import paddle
from ppocr.utils.logging import get_logger
from ppocr.data import create_operators, transform
from ppocr.modeling import build_model
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import print_dict, get_image_file_list
from ppocr.utils.utility import get_image_file_list
import tools.program as program
......@@ -54,13 +53,13 @@ def main():
# create data ops
transforms = []
for op in config['EVAL']['dataset']['transforms']:
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
continue
elif op_name in ['RecResizeImg']:
op[op_name]['infer_mode'] = True
elif op_name == 'keepKeys':
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
global_config['infer_mode'] = True
......@@ -75,22 +74,14 @@ def main():
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_variable(images)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
logger.info("success!")
# save inference model
# currently, paddle.jit.to_static not support rnn
# paddle.jit.save(model, 'output/rec/model')
if __name__ == '__main__':
place, config = program.preprocess()
paddle.disable_static(place)
logger = get_logger()
print_dict(config, logger)
config, device, logger, vdl_writer = program.preprocess()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册