提交 d9c28128 · 作者: LDOUBLEV

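Fix multi-inputs: pass the image tensor to the model separately from the remaining inputs (Kie_backbone.forward now takes `images, inputs`), and treat SDMGR as an extra-input algorithm in train/eval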

上级 4e0fcd6e
......@@ -167,20 +167,12 @@ class Kie_backbone(nn.Layer):
gt_bboxes[i, :num, ...], dtype='float32'))
return img, temp_relations, temp_texts, temp_gt_bboxes
def forward(self, inputs):
img, relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[
1], inputs[2], inputs[3], inputs[5], inputs[-1]
def forward(self, images, inputs):
img = images
relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[
1], inputs[2], inputs[4], inputs[-1]
img, relations, texts, gt_bboxes = self.pre_process(
img, relations, texts, gt_bboxes, tag, img_size)
# for i in range(4):
# img_t = (img[i].numpy().transpose([1, 2, 0]) * 255.0).astype('uint8')
# img_t = img_t.copy()
# gt_bboxes_t = gt_bboxes[i].cpu().numpy()
# box = gt_bboxes_t.astype(np.int32).reshape((-1, 1, 2))
# cv2.polylines(img_t, [box], True, color=(255, 255, 0), thickness=1)
# cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t)
# # cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t * 255.0)
# exit()
x = self.img_feat(img)
boxes, rois_num = self.bbox2roi(gt_bboxes)
feats = paddle.fluid.layers.roi_align(
......
......@@ -80,7 +80,8 @@ def draw_kie_result(batch, node, idx_to_cls, count):
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
vis_img[:, :w] = img
vis_img[:, w:] = pred_img
save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/"
save_kie_path = os.path.dirname(config['Global'][
'save_res_path']) + "/kie_results/"
if not os.path.exists(save_kie_path):
os.makedirs(save_kie_path)
save_path = os.path.join(save_kie_path, str(count) + ".png")
......@@ -128,7 +129,8 @@ def main():
batch_pred[i] = paddle.to_tensor(
np.expand_dims(
batch[i], axis=0))
node, edge = model(batch_pred)
node, edge = model(batch[0], batch[1:])
node = F.softmax(node, -1)
draw_kie_result(batch, node, idx_to_cls, index)
logger.info("success!")
......
......@@ -197,7 +197,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input = config['Architecture'][
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED", "SDMGR"]
try:
model_type = config['Architecture']['model_type']
except:
......@@ -230,7 +230,7 @@ def train(config,
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
else:
preds = model(batch)
preds = model(images)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
avg_loss.backward()
......@@ -379,7 +379,7 @@ def eval(model,
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
else:
preds = model(batch)
preds = model(images)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
total_time += time.time() - start
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册