提交 bd1820b7 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix infer rec

上级 115955f7
......@@ -19,6 +19,7 @@ Global:
infer_mode: false
use_space_char: false
distributed: true
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
Optimizer:
......@@ -98,7 +99,7 @@ Loss:
PostProcess:
name: DistillationCTCLabelDecode
model_name: ["Student"]
model_name: ["Student", "Teacher"]
key: head_out
Metric:
......
......@@ -20,6 +20,7 @@ import numpy as np
import os
import sys
import json
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
......@@ -113,11 +114,23 @@ def main():
else:
preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
if len(rec_reuslt) >= 2:
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str(
rec_reuslt[1]) + "\n")
info = None
if isinstance(post_result, dict):
rec_info = dict()
for key in post_result:
if len(post_result[key][0]) >= 2:
rec_info[key] = {
"label": post_result[key][0][0],
"score": post_result[key][0][1],
}
info = json.dumps(rec_info)
else:
if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1])
if info is not None:
logger.info("\t result: {}".format(info))
fout.write(file + "\t" + info)
logger.info("success!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册