diff --git a/ppocr/modeling/backbones/kie_unet_sdmgr.py b/ppocr/modeling/backbones/kie_unet_sdmgr.py index 5003551de87d916119a7222b964af9de21392843..545e4e7511e58c3d8220e9ec0be35474deba8806 100644 --- a/ppocr/modeling/backbones/kie_unet_sdmgr.py +++ b/ppocr/modeling/backbones/kie_unet_sdmgr.py @@ -167,10 +167,10 @@ class Kie_backbone(nn.Layer): gt_bboxes[i, :num, ...], dtype='float32')) return img, temp_relations, temp_texts, temp_gt_bboxes - def forward(self, images, inputs): - img = images - relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[ - 1], inputs[2], inputs[4], inputs[-1] + def forward(self, inputs): + img = inputs[0] + relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[ + 2], inputs[3], inputs[5], inputs[-1] img, relations, texts, gt_bboxes = self.pre_process( img, relations, texts, gt_bboxes, tag, img_size) x = self.img_feat(img) diff --git a/ppocr/modeling/heads/kie_sdmgr_head.py b/ppocr/modeling/heads/kie_sdmgr_head.py index b1352d980ea479f6f23f17d6aaebae5c264bdffd..46ac0ed8dcaccb7628ef87fbe851a2b6acd60d55 100644 --- a/ppocr/modeling/heads/kie_sdmgr_head.py +++ b/ppocr/modeling/heads/kie_sdmgr_head.py @@ -49,7 +49,7 @@ class SDMGRHead(nn.Layer): self.node_cls = nn.Linear(node_embed, num_classes) self.edge_cls = nn.Linear(edge_embed, 2) - def forward(self, input): + def forward(self, input, targets): relations, texts, x = input node_nums, char_nums = [], [] for text in texts: diff --git a/tools/eval.py b/tools/eval.py index 28247bc57450aaf067fcb405674098eacb990166..8eccfd6541aafc1ec0e976a3784b41b90cead1e3 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -54,7 +54,7 @@ def main(): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"] + extra_input = config['Architecture']['algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] if "model_type" in config['Architecture'].keys(): model_type = config['Architecture']['model_type'] else: @@ -68,7 +68,7 @@ def main(): # build metric eval_class = build_metric(config['Metric']) - + logger.info(f"extra_inputs: {extra_input}") # start eval metric = program.eval(model, valid_dataloader, post_process_class, eval_class, model_type, extra_input) diff --git a/tools/infer_kie.py b/tools/infer_kie.py index 6be3ce14f0d5df09bdc072a2f93ce128ddc3119f..62ef697240ffe89fcb858c5308bd010105dde2ab 100755 --- a/tools/infer_kie.py +++ b/tools/infer_kie.py @@ -80,8 +80,7 @@ def draw_kie_result(batch, node, idx_to_cls, count): vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 vis_img[:, :w] = img vis_img[:, w:] = pred_img - save_kie_path = os.path.dirname(config['Global'][ - 'save_res_path']) + "/kie_results/" + save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/" if not os.path.exists(save_kie_path): os.makedirs(save_kie_path) save_path = os.path.join(save_kie_path, str(count) + ".png") @@ -129,8 +128,7 @@ def main(): batch_pred[i] = paddle.to_tensor( np.expand_dims( batch[i], axis=0)) - - node, edge = model(batch[0], batch[1:]) + node, edge = model(batch_pred) node = F.softmax(node, -1) draw_kie_result(batch, node, idx_to_cls, index) logger.info("success!") diff --git a/tools/program.py b/tools/program.py index b49df092cb95ef556cdfcda0439f30cda56e4f88..ddf39e65c34012ae36efd2752946f737f365b1c1 100755 --- a/tools/program.py +++ b/tools/program.py @@ -196,7 +196,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input = config['Architecture'][ - 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED", "SDMGR"] + 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] try: model_type = config['Architecture']['model_type'] except: @@ -228,6 +228,8 @@ def train(config, model_average = True if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) + if model_type == "kie": + preds = model(batch) else: preds = model(images) loss = loss_class(preds, batch) @@ -249,7 +251,7 @@ def train(config, if cal_metric_during_train: # only rec and cls need batch = [item.numpy() for item in batch] - if model_type == 'table': + if model_type in ['table', 'kie']: eval_class(preds, batch) else: post_result = post_process_class(preds, batch[1]) @@ -377,13 +379,15 @@ def eval(model, start = time.time() if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) + if model_type == "kie": + preds = model(batch) else: preds = model(images) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start # Evaluate the results of the current batch - if model_type == 'table': + if model_type in ['table', 'kie']: eval_class(preds, batch) else: post_result = post_process_class(preds, batch[1])