From a4b9b0c609328b5d0779befa11315dc68c85eafa Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 22 Apr 2019 15:22:42 +0800 Subject: [PATCH] Fix get_attention_feeder_for_infer (#2067) --- PaddleCV/ocr_recognition/infer.py | 4 ++-- PaddleCV/ocr_recognition/utility.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/PaddleCV/ocr_recognition/infer.py b/PaddleCV/ocr_recognition/infer.py index 5c3e1f24..a3adbb0c 100755 --- a/PaddleCV/ocr_recognition/infer.py +++ b/PaddleCV/ocr_recognition/infer.py @@ -31,7 +31,7 @@ def inference(args): """OCR inference""" if args.model == "crnn_ctc": infer = ctc_infer - get_feeder_data = get_ctc_feeder_data + get_feeder_data = get_ctc_feeder_for_infer else: infer = attention_infer get_feeder_data = get_attention_feeder_for_infer @@ -78,7 +78,7 @@ def inference(args): batch_times = [] iters = 0 for data in infer_reader(): - feed_dict = get_feeder_data(data, place, need_label=False) + feed_dict = get_feeder_data(data, place) if args.iterations > 0 and iters == args.iterations + args.skip_batch_num: break if iters < args.skip_batch_num: diff --git a/PaddleCV/ocr_recognition/utility.py b/PaddleCV/ocr_recognition/utility.py index fb8d066c..3d0adb77 100755 --- a/PaddleCV/ocr_recognition/utility.py +++ b/PaddleCV/ocr_recognition/utility.py @@ -83,7 +83,8 @@ def get_ctc_feeder_data(data, place, need_label=True): pixel_tensor = core.LoDTensor() pixel_data = None pixel_data = np.concatenate( - list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32") + list(map(lambda x: x[0][np.newaxis, :], data)), + axis=0).astype("float32") pixel_tensor.set(pixel_data, place) label_tensor = to_lodtensor(list(map(lambda x: x[1], data)), place) if need_label: @@ -92,11 +93,16 @@ def get_ctc_feeder_data(data, place, need_label=True): return {"pixel": pixel_tensor} +def get_ctc_feeder_for_infer(data, place): + return get_ctc_feeder_data(data, place, need_label=False) + + def get_attention_feeder_data(data, place, need_label=True): pixel_tensor = core.LoDTensor() pixel_data = None pixel_data = np.concatenate( - list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32") + list(map(lambda x: x[0][np.newaxis, :], data)), + axis=0).astype("float32") pixel_tensor.set(pixel_data, place) label_in_tensor = to_lodtensor(list(map(lambda x: x[1], data)), place) label_out_tensor = to_lodtensor(list(map(lambda x: x[2], data)), place) @@ -127,7 +133,8 @@ def get_attention_feeder_for_infer(data, place): pixel_tensor = core.LoDTensor() pixel_data = None pixel_data = np.concatenate( - list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32") + list(map(lambda x: x[0][np.newaxis, :], data)), + axis=0).astype("float32") pixel_tensor.set(pixel_data, place) return { "pixel": pixel_tensor, -- GitLab