From d9c281280a26065e57fb389af78cabb65c6667c0 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Sat, 9 Oct 2021 18:03:52 +0800 Subject: [PATCH] fix multi-inputs --- ppocr/modeling/backbones/kie_unet_sdmgr.py | 16 ++++------------ tools/infer_kie.py | 6 ++++-- tools/program.py | 6 +++--- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/ppocr/modeling/backbones/kie_unet_sdmgr.py b/ppocr/modeling/backbones/kie_unet_sdmgr.py index 62bae2ea..5003551d 100644 --- a/ppocr/modeling/backbones/kie_unet_sdmgr.py +++ b/ppocr/modeling/backbones/kie_unet_sdmgr.py @@ -167,20 +167,12 @@ class Kie_backbone(nn.Layer): gt_bboxes[i, :num, ...], dtype='float32')) return img, temp_relations, temp_texts, temp_gt_bboxes - def forward(self, inputs): - img, relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[ - 1], inputs[2], inputs[3], inputs[5], inputs[-1] + def forward(self, images, inputs): + img = images + relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[ + 1], inputs[2], inputs[4], inputs[-1] img, relations, texts, gt_bboxes = self.pre_process( img, relations, texts, gt_bboxes, tag, img_size) - # for i in range(4): - # img_t = (img[i].numpy().transpose([1, 2, 0]) * 255.0).astype('uint8') - # img_t = img_t.copy() - # gt_bboxes_t = gt_bboxes[i].cpu().numpy() - # box = gt_bboxes_t.astype(np.int32).reshape((-1, 1, 2)) - # cv2.polylines(img_t, [box], True, color=(255, 255, 0), thickness=1) - # cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t) - # # cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t * 255.0) - # exit() x = self.img_feat(img) boxes, rois_num = self.bbox2roi(gt_bboxes) feats = paddle.fluid.layers.roi_align( diff --git a/tools/infer_kie.py b/tools/infer_kie.py index 62ef6972..6be3ce14 100755 --- a/tools/infer_kie.py +++ b/tools/infer_kie.py @@ -80,7 +80,8 @@ def draw_kie_result(batch, node, idx_to_cls, count): vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 vis_img[:, :w] = img vis_img[:, w:] = pred_img - save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/" + save_kie_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/kie_results/" if not os.path.exists(save_kie_path): os.makedirs(save_kie_path) save_path = os.path.join(save_kie_path, str(count) + ".png") @@ -128,7 +129,8 @@ def main(): batch_pred[i] = paddle.to_tensor( np.expand_dims( batch[i], axis=0)) - node, edge = model(batch_pred) + + node, edge = model(batch[0], batch[1:]) node = F.softmax(node, -1) draw_kie_result(batch, node, idx_to_cls, index) logger.info("success!") diff --git a/tools/program.py b/tools/program.py index 3d0d3635..3aba5aaf 100755 --- a/tools/program.py +++ b/tools/program.py @@ -197,7 +197,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input = config['Architecture'][ - 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] + 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED", "SDMGR"] try: model_type = config['Architecture']['model_type'] except: @@ -230,7 +230,7 @@ def train(config, if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) else: - preds = model(batch) + preds = model(images) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -379,7 +379,7 @@ def eval(model, if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) else: - preds = model(batch) + preds = model(images) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start -- GitLab