From 66029dd8c660594bd0031872dfa3c39321ff284c Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Mon, 11 Oct 2021 02:35:26 +0000 Subject: [PATCH] fix kie infer and eval bug --- ppocr/modeling/backbones/kie_unet_sdmgr.py | 8 ++++---- ppocr/modeling/heads/kie_sdmgr_head.py | 2 +- tools/eval.py | 4 ++-- tools/infer_kie.py | 6 ++---- tools/program.py | 10 +++++++--- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/ppocr/modeling/backbones/kie_unet_sdmgr.py b/ppocr/modeling/backbones/kie_unet_sdmgr.py index 5003551d..545e4e75 100644 --- a/ppocr/modeling/backbones/kie_unet_sdmgr.py +++ b/ppocr/modeling/backbones/kie_unet_sdmgr.py @@ -167,10 +167,10 @@ class Kie_backbone(nn.Layer): gt_bboxes[i, :num, ...], dtype='float32')) return img, temp_relations, temp_texts, temp_gt_bboxes - def forward(self, images, inputs): - img = images - relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[ - 1], inputs[2], inputs[4], inputs[-1] + def forward(self, inputs): + img = inputs[0] + relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[ + 2], inputs[3], inputs[5], inputs[-1] img, relations, texts, gt_bboxes = self.pre_process( img, relations, texts, gt_bboxes, tag, img_size) x = self.img_feat(img) diff --git a/ppocr/modeling/heads/kie_sdmgr_head.py b/ppocr/modeling/heads/kie_sdmgr_head.py index b1352d98..46ac0ed8 100644 --- a/ppocr/modeling/heads/kie_sdmgr_head.py +++ b/ppocr/modeling/heads/kie_sdmgr_head.py @@ -49,7 +49,7 @@ class SDMGRHead(nn.Layer): self.node_cls = nn.Linear(node_embed, num_classes) self.edge_cls = nn.Linear(edge_embed, 2) - def forward(self, input): + def forward(self, input, targets): relations, texts, x = input node_nums, char_nums = [], [] for text in texts: diff --git a/tools/eval.py b/tools/eval.py index 28247bc5..8eccfd65 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -54,7 +54,7 @@ def main(): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"] + extra_input = config['Architecture']['algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] if "model_type" in config['Architecture'].keys(): model_type = config['Architecture']['model_type'] else: @@ -68,7 +68,7 @@ def main(): # build metric eval_class = build_metric(config['Metric']) - + logger.info(f"extra_inputs: {extra_input}") # start eval metric = program.eval(model, valid_dataloader, post_process_class, eval_class, model_type, extra_input) diff --git a/tools/infer_kie.py b/tools/infer_kie.py index 6be3ce14..62ef6972 100755 --- a/tools/infer_kie.py +++ b/tools/infer_kie.py @@ -80,8 +80,7 @@ def draw_kie_result(batch, node, idx_to_cls, count): vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 vis_img[:, :w] = img vis_img[:, w:] = pred_img - save_kie_path = os.path.dirname(config['Global'][ - 'save_res_path']) + "/kie_results/" + save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/" if not os.path.exists(save_kie_path): os.makedirs(save_kie_path) save_path = os.path.join(save_kie_path, str(count) + ".png") @@ -129,8 +128,7 @@ def main(): batch_pred[i] = paddle.to_tensor( np.expand_dims( batch[i], axis=0)) - - node, edge = model(batch[0], batch[1:]) + node, edge = model(batch_pred) node = F.softmax(node, -1) draw_kie_result(batch, node, idx_to_cls, index) logger.info("success!") diff --git a/tools/program.py b/tools/program.py index b49df092..ddf39e65 100755 --- a/tools/program.py +++ b/tools/program.py @@ -196,7 +196,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input = config['Architecture'][ - 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED", "SDMGR"] + 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] try: model_type = config['Architecture']['model_type'] except: @@ -228,6 +228,8 @@ def train(config, model_average = True if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) + if model_type == "kie": + preds = model(batch) else: preds = model(images) loss = loss_class(preds, batch) @@ -249,7 +251,7 @@ def train(config, if cal_metric_during_train: # only rec and cls need batch = [item.numpy() for item in batch] - if model_type == 'table': + if model_type in ['table', 'kie']: eval_class(preds, batch) else: post_result = post_process_class(preds, batch[1]) @@ -377,13 +379,15 @@ def eval(model, start = time.time() if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) + if model_type == "kie": + preds = model(batch) else: preds = model(images) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start # Evaluate the results of the current batch - if model_type == 'table': + if model_type in ['table', 'kie']: eval_class(preds, batch) else: post_result = post_process_class(preds, batch[1]) -- GitLab