未验证 提交 a312647b 编写于 作者: L littletomatodonkey 提交者: GitHub

add save rec res (#2632)

* add save rec res

* fix hub hub

* fix yml
上级 438060d3
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: True
save_res_path: ./output/rec/predicts_chinese_common_v2.0.txt
Optimizer:
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: True
save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
Optimizer:
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_ic15.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_none_bilstm_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_none_none_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_att.txt
Optimizer:
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_none_bilstm_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_none_none_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
Optimizer:
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_tps_bilstm_ctc.txt
Optimizer:
name: Adam
......
......@@ -20,6 +20,7 @@ Global:
num_heads: 8
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_srn.txt
Optimizer:
......
......@@ -55,7 +55,6 @@ class OCRCls(hub.Module):
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
......
......@@ -73,35 +73,45 @@ def main():
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path',
"./output/rec/predicts_rec.txt")
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
if config['Architecture']['algorithm'] == "SRN":
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
others = [
paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_word_pos_list),
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
else:
preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
with open(save_res_path, "w") as fout:
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
if config['Architecture']['algorithm'] == "SRN":
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
others = [
paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_word_pos_list),
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
else:
preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
if len(rec_reuslt) >= 2:
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str(
rec_reuslt[1]) + "\n")
logger.info("success!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册