From c342b7a0135f838f6ee5b3727f37418d7535a52b Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Sat, 9 Oct 2021 15:48:16 +0800 Subject: [PATCH] fix infer --- tools/infer_kie.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tools/infer_kie.py b/tools/infer_kie.py index 216df723..62ef6972 100755 --- a/tools/infer_kie.py +++ b/tools/infer_kie.py @@ -47,9 +47,9 @@ def read_class_list(filepath): return dict -def draw_kie_result(batch, node, idx_to_cls): - img = batch[-2] - boxes = batch[-1] +def draw_kie_result(batch, node, idx_to_cls, count): + img = batch[6].copy() + boxes = batch[7] h, w = img.shape[:2] pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1) @@ -77,11 +77,15 @@ def draw_kie_result(batch, node, idx_to_cls): text = pred_label + '(' + pred_score + ')' cv2.putText(pred_img, text, (x_min * 2, y_min), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) - vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 vis_img[:, :w] = img vis_img[:, w:] = pred_img - return vis_img + save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/" + if not os.path.exists(save_kie_path): + os.makedirs(save_kie_path) + save_path = os.path.join(save_kie_path, str(count) + ".png") + cv2.imwrite(save_path, vis_img) + logger.info("The Kie Image saved in {}".format(save_path)) def main(): @@ -89,7 +93,6 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) # create data ops @@ -97,6 +100,8 @@ def main(): for op in config['Eval']['dataset']['transforms']: transforms.append(op) + data_dir = config['Eval']['dataset']['data_dir'] + ops = create_operators(transforms, global_config) save_res_path = config['Global']['save_res_path'] @@ -109,11 +114,10 @@ def main(): with open(save_res_path, "wb") as fout: with open(config['Global']['infer_img'], "rb") as f: lines = f.readlines() - for data_line in lines: + for index, data_line in enumerate(lines): data_line = data_line.decode('utf-8') substr = data_line.strip("\n").split("\t") - img_path, label = "/Users/hongyongjie/project/PaddleOCR/train_data/wildreceipt/" + substr[ - 0], substr[1] + img_path, label = data_dir + "/" + substr[0], substr[1] data = {'img_path': img_path, 'label': label} with open(data['img_path'], 'rb') as f: img = f.read() @@ -126,9 +130,7 @@ def main(): batch[i], axis=0)) node, edge = model(batch_pred) node = F.softmax(node, -1) - img = draw_kie_result(batch, node, idx_to_cls) - cv2.imwrite('1.png', img) - exit() + draw_kie_result(batch, node, idx_to_cls, index) logger.info("success!") -- GitLab