提交 bd1820b7 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix infer rec

上级 115955f7
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
infer_mode: false infer_mode: false
use_space_char: false use_space_char: false
distributed: true distributed: true
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
Optimizer: Optimizer:
...@@ -98,7 +99,7 @@ Loss: ...@@ -98,7 +99,7 @@ Loss:
PostProcess: PostProcess:
name: DistillationCTCLabelDecode name: DistillationCTCLabelDecode
model_name: ["Student"] model_name: ["Student", "Teacher"]
key: head_out key: head_out
Metric: Metric:
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import os import os
import sys import sys
import json
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
...@@ -113,11 +114,23 @@ def main(): ...@@ -113,11 +114,23 @@ def main():
else: else:
preds = model(images) preds = model(images)
post_result = post_process_class(preds) post_result = post_process_class(preds)
for rec_reuslt in post_result: info = None
logger.info('\t result: {}'.format(rec_reuslt)) if isinstance(post_result, dict):
if len(rec_reuslt) >= 2: rec_info = dict()
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str( for key in post_result:
rec_reuslt[1]) + "\n") if len(post_result[key][0]) >= 2:
rec_info[key] = {
"label": post_result[key][0][0],
"score": post_result[key][0][1],
}
info = json.dumps(rec_info)
else:
if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1])
if info is not None:
logger.info("\t result: {}".format(info))
fout.write(file + "\t" + info)
logger.info("success!") logger.info("success!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册