diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml index 8e8acd8b107384dd79ed28b8d3588ff0e76b3679..016788ea72be3d9b4536c0c410354449c7de84ae 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml @@ -19,6 +19,7 @@ Global: infer_mode: false use_space_char: false distributed: true + save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt Optimizer: @@ -98,7 +99,7 @@ Loss: PostProcess: name: DistillationCTCLabelDecode - model_name: ["Student"] + model_name: ["Student", "Teacher"] key: head_out Metric: diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 6bd8f1429072e533e8c449c8c8a439ed51f521a3..6894207d4bb7eaa2aa84f4a0a30ee878d389b5cc 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -20,6 +20,7 @@ import numpy as np import os import sys +import json __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) @@ -113,11 +114,23 @@ def main(): else: preds = model(images) post_result = post_process_class(preds) - for rec_reuslt in post_result: - logger.info('\t result: {}'.format(rec_reuslt)) - if len(rec_reuslt) >= 2: - fout.write(file + "\t" + rec_reuslt[0] + "\t" + str( - rec_reuslt[1]) + "\n") + info = None + if isinstance(post_result, dict): + rec_info = dict() + for key in post_result: + if len(post_result[key][0]) >= 2: + rec_info[key] = { + "label": post_result[key][0][0], + "score": post_result[key][0][1], + } + info = json.dumps(rec_info) + else: + if len(post_result[0]) >= 2: + info = post_result[0][0] + "\t" + str(post_result[0][1]) + + if info is not None: + logger.info("\t result: {}".format(info)) + fout.write(file + "\t" + info) logger.info("success!")