提交 66029dd8 编写于 作者: L LDOUBLEV

fix kie infer and eval bug

上级 30e8dd8e
...@@ -167,10 +167,10 @@ class Kie_backbone(nn.Layer): ...@@ -167,10 +167,10 @@ class Kie_backbone(nn.Layer):
gt_bboxes[i, :num, ...], dtype='float32')) gt_bboxes[i, :num, ...], dtype='float32'))
return img, temp_relations, temp_texts, temp_gt_bboxes return img, temp_relations, temp_texts, temp_gt_bboxes
def forward(self, images, inputs): def forward(self, inputs):
img = images img = inputs[0]
relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[ relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
1], inputs[2], inputs[4], inputs[-1] 2], inputs[3], inputs[5], inputs[-1]
img, relations, texts, gt_bboxes = self.pre_process( img, relations, texts, gt_bboxes = self.pre_process(
img, relations, texts, gt_bboxes, tag, img_size) img, relations, texts, gt_bboxes, tag, img_size)
x = self.img_feat(img) x = self.img_feat(img)
......
...@@ -49,7 +49,7 @@ class SDMGRHead(nn.Layer): ...@@ -49,7 +49,7 @@ class SDMGRHead(nn.Layer):
self.node_cls = nn.Linear(node_embed, num_classes) self.node_cls = nn.Linear(node_embed, num_classes)
self.edge_cls = nn.Linear(edge_embed, 2) self.edge_cls = nn.Linear(edge_embed, 2)
def forward(self, input): def forward(self, input, targets):
relations, texts, x = input relations, texts, x = input
node_nums, char_nums = [], [] node_nums, char_nums = [], []
for text in texts: for text in texts:
......
...@@ -54,7 +54,7 @@ def main(): ...@@ -54,7 +54,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"] extra_input = config['Architecture']['algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
if "model_type" in config['Architecture'].keys(): if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
else: else:
...@@ -68,7 +68,7 @@ def main(): ...@@ -68,7 +68,7 @@ def main():
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
logger.info(f"extra_inputs: {extra_input}")
# start eval # start eval
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, extra_input) eval_class, model_type, extra_input)
......
...@@ -80,8 +80,7 @@ def draw_kie_result(batch, node, idx_to_cls, count): ...@@ -80,8 +80,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
vis_img[:, :w] = img vis_img[:, :w] = img
vis_img[:, w:] = pred_img vis_img[:, w:] = pred_img
save_kie_path = os.path.dirname(config['Global'][ save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/"
'save_res_path']) + "/kie_results/"
if not os.path.exists(save_kie_path): if not os.path.exists(save_kie_path):
os.makedirs(save_kie_path) os.makedirs(save_kie_path)
save_path = os.path.join(save_kie_path, str(count) + ".png") save_path = os.path.join(save_kie_path, str(count) + ".png")
...@@ -129,8 +128,7 @@ def main(): ...@@ -129,8 +128,7 @@ def main():
batch_pred[i] = paddle.to_tensor( batch_pred[i] = paddle.to_tensor(
np.expand_dims( np.expand_dims(
batch[i], axis=0)) batch[i], axis=0))
node, edge = model(batch_pred)
node, edge = model(batch[0], batch[1:])
node = F.softmax(node, -1) node = F.softmax(node, -1)
draw_kie_result(batch, node, idx_to_cls, index) draw_kie_result(batch, node, idx_to_cls, index)
logger.info("success!") logger.info("success!")
......
...@@ -196,7 +196,7 @@ def train(config, ...@@ -196,7 +196,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input = config['Architecture'][ extra_input = config['Architecture'][
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED", "SDMGR"] 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
try: try:
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
except: except:
...@@ -228,6 +228,8 @@ def train(config, ...@@ -228,6 +228,8 @@ def train(config,
model_average = True model_average = True
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
if model_type == "kie":
preds = model(batch)
else: else:
preds = model(images) preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
...@@ -249,7 +251,7 @@ def train(config, ...@@ -249,7 +251,7 @@ def train(config,
if cal_metric_during_train: # only rec and cls need if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
if model_type == 'table': if model_type in ['table', 'kie']:
eval_class(preds, batch) eval_class(preds, batch)
else: else:
post_result = post_process_class(preds, batch[1]) post_result = post_process_class(preds, batch[1])
...@@ -377,13 +379,15 @@ def eval(model, ...@@ -377,13 +379,15 @@ def eval(model,
start = time.time() start = time.time()
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
if model_type == "kie":
preds = model(batch)
else: else:
preds = model(images) preds = model(images)
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
total_time += time.time() - start total_time += time.time() - start
# Evaluate the results of the current batch # Evaluate the results of the current batch
if model_type == 'table': if model_type in ['table', 'kie']:
eval_class(preds, batch) eval_class(preds, batch)
else: else:
post_result = post_process_class(preds, batch[1]) post_result = post_process_class(preds, batch[1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册