提交 66029dd8 编写于 作者: L LDOUBLEV

fix kie infer and eval bug

上级 30e8dd8e
......@@ -167,10 +167,10 @@ class Kie_backbone(nn.Layer):
gt_bboxes[i, :num, ...], dtype='float32'))
return img, temp_relations, temp_texts, temp_gt_bboxes
def forward(self, images, inputs):
img = images
relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[
1], inputs[2], inputs[4], inputs[-1]
def forward(self, inputs):
img = inputs[0]
relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
2], inputs[3], inputs[5], inputs[-1]
img, relations, texts, gt_bboxes = self.pre_process(
img, relations, texts, gt_bboxes, tag, img_size)
x = self.img_feat(img)
......
......@@ -49,7 +49,7 @@ class SDMGRHead(nn.Layer):
self.node_cls = nn.Linear(node_embed, num_classes)
self.edge_cls = nn.Linear(edge_embed, 2)
def forward(self, input):
def forward(self, input, targets):
relations, texts, x = input
node_nums, char_nums = [], []
for text in texts:
......
......@@ -54,7 +54,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
extra_input = config['Architecture']['algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
......@@ -68,7 +68,7 @@ def main():
# build metric
eval_class = build_metric(config['Metric'])
logger.info(f"extra_inputs: {extra_input}")
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, extra_input)
......
......@@ -80,8 +80,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
vis_img[:, :w] = img
vis_img[:, w:] = pred_img
save_kie_path = os.path.dirname(config['Global'][
'save_res_path']) + "/kie_results/"
save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/"
if not os.path.exists(save_kie_path):
os.makedirs(save_kie_path)
save_path = os.path.join(save_kie_path, str(count) + ".png")
......@@ -129,8 +128,7 @@ def main():
batch_pred[i] = paddle.to_tensor(
np.expand_dims(
batch[i], axis=0))
node, edge = model(batch[0], batch[1:])
node, edge = model(batch_pred)
node = F.softmax(node, -1)
draw_kie_result(batch, node, idx_to_cls, index)
logger.info("success!")
......
......@@ -196,7 +196,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input = config['Architecture'][
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED", "SDMGR"]
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
try:
model_type = config['Architecture']['model_type']
except:
......@@ -228,6 +228,8 @@ def train(config,
model_average = True
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
if model_type == "kie":
preds = model(batch)
else:
preds = model(images)
loss = loss_class(preds, batch)
......@@ -249,7 +251,7 @@ def train(config,
if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch]
if model_type == 'table':
if model_type in ['table', 'kie']:
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
......@@ -377,13 +379,15 @@ def eval(model,
start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
if model_type == "kie":
preds = model(batch)
else:
preds = model(images)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
if model_type == 'table':
if model_type in ['table', 'kie']:
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册