未验证 提交 a4b9b0c6 编写于 作者: W whs 提交者: GitHub

Fix get_attention_feeder_for_infer (#2067)

上级 dd6e2970
......@@ -31,7 +31,7 @@ def inference(args):
"""OCR inference"""
if args.model == "crnn_ctc":
infer = ctc_infer
get_feeder_data = get_ctc_feeder_data
get_feeder_data = get_ctc_feeder_for_infer
else:
infer = attention_infer
get_feeder_data = get_attention_feeder_for_infer
......@@ -78,7 +78,7 @@ def inference(args):
batch_times = []
iters = 0
for data in infer_reader():
feed_dict = get_feeder_data(data, place, need_label=False)
feed_dict = get_feeder_data(data, place)
if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
break
if iters < args.skip_batch_num:
......
......@@ -83,7 +83,8 @@ def get_ctc_feeder_data(data, place, need_label=True):
pixel_tensor = core.LoDTensor()
pixel_data = None
pixel_data = np.concatenate(
list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32")
list(map(lambda x: x[0][np.newaxis, :], data)),
axis=0).astype("float32")
pixel_tensor.set(pixel_data, place)
label_tensor = to_lodtensor(list(map(lambda x: x[1], data)), place)
if need_label:
......@@ -92,11 +93,16 @@ def get_ctc_feeder_data(data, place, need_label=True):
return {"pixel": pixel_tensor}
def get_ctc_feeder_for_infer(data, place):
return get_ctc_feeder_data(data, place, need_label=False)
def get_attention_feeder_data(data, place, need_label=True):
pixel_tensor = core.LoDTensor()
pixel_data = None
pixel_data = np.concatenate(
list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32")
list(map(lambda x: x[0][np.newaxis, :], data)),
axis=0).astype("float32")
pixel_tensor.set(pixel_data, place)
label_in_tensor = to_lodtensor(list(map(lambda x: x[1], data)), place)
label_out_tensor = to_lodtensor(list(map(lambda x: x[2], data)), place)
......@@ -127,7 +133,8 @@ def get_attention_feeder_for_infer(data, place):
pixel_tensor = core.LoDTensor()
pixel_data = None
pixel_data = np.concatenate(
list(map(lambda x: x[0][np.newaxis, :], data)), axis=0).astype("float32")
list(map(lambda x: x[0][np.newaxis, :], data)),
axis=0).astype("float32")
pixel_tensor.set(pixel_data, place)
return {
"pixel": pixel_tensor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册