未验证 提交 e13ec733 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

unify kie and ser for vqa data format (#6704)

* unify kie and ser for vqa data format

* fix config and label ops

* fix doc

* add distort bbox
上级 466214f9
...@@ -17,7 +17,7 @@ Global: ...@@ -17,7 +17,7 @@ Global:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
class_path: ./train_data/wildreceipt/class_list.txt class_path: &class_path ./train_data/wildreceipt/class_list.txt
infer_img: ./train_data/wildreceipt/1.txt infer_img: ./train_data/wildreceipt/1.txt
save_res_path: ./output/sdmgr_kie/predicts_kie.txt save_res_path: ./output/sdmgr_kie/predicts_kie.txt
img_scale: [ 1024, 512 ] img_scale: [ 1024, 512 ]
...@@ -72,6 +72,7 @@ Train: ...@@ -72,6 +72,7 @@ Train:
order: 'hwc' order: 'hwc'
- KieLabelEncode: # Class handling label - KieLabelEncode: # Class handling label
character_dict_path: ./train_data/wildreceipt/dict.txt character_dict_path: ./train_data/wildreceipt/dict.txt
class_path: *class_path
- KieResize: - KieResize:
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
...@@ -88,7 +89,6 @@ Eval: ...@@ -88,7 +89,6 @@ Eval:
data_dir: ./train_data/wildreceipt data_dir: ./train_data/wildreceipt
label_file_list: label_file_list:
- ./train_data/wildreceipt/wildreceipt_test.txt - ./train_data/wildreceipt/wildreceipt_test.txt
# - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
......
...@@ -43,7 +43,7 @@ Optimizer: ...@@ -43,7 +43,7 @@ Optimizer:
PostProcess: PostProcess:
name: VQASerTokenLayoutLMPostProcess name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric: Metric:
name: VQASerTokenMetric name: VQASerTokenMetric
...@@ -54,7 +54,7 @@ Train: ...@@ -54,7 +54,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -89,7 +89,7 @@ Eval: ...@@ -89,7 +89,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
......
...@@ -44,7 +44,7 @@ Optimizer: ...@@ -44,7 +44,7 @@ Optimizer:
PostProcess: PostProcess:
name: VQASerTokenLayoutLMPostProcess name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric: Metric:
name: VQASerTokenMetric name: VQASerTokenMetric
...@@ -55,7 +55,7 @@ Train: ...@@ -55,7 +55,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
......
...@@ -11,7 +11,7 @@ Global: ...@@ -11,7 +11,7 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: doc/vqa/input/zh_val_42.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser save_res_path: ./output/ser
Architecture: Architecture:
...@@ -54,7 +54,7 @@ Train: ...@@ -54,7 +54,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ] ratio_list: [ 1.0 ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
......
...@@ -259,15 +259,26 @@ class E2ELabelEncodeTrain(object): ...@@ -259,15 +259,26 @@ class E2ELabelEncodeTrain(object):
class KieLabelEncode(object): class KieLabelEncode(object):
def __init__(self, character_dict_path, norm=10, directed=False, **kwargs): def __init__(self,
character_dict_path,
class_path,
norm=10,
directed=False,
**kwargs):
super(KieLabelEncode, self).__init__() super(KieLabelEncode, self).__init__()
self.dict = dict({'': 0}) self.dict = dict({'': 0})
self.label2classid_map = dict()
with open(character_dict_path, 'r', encoding='utf-8') as fr: with open(character_dict_path, 'r', encoding='utf-8') as fr:
idx = 1 idx = 1
for line in fr: for line in fr:
char = line.strip() char = line.strip()
self.dict[char] = idx self.dict[char] = idx
idx += 1 idx += 1
with open(class_path, "r") as fin:
lines = fin.readlines()
for idx, line in enumerate(lines):
line = line.strip("\n")
self.label2classid_map[line] = idx
self.norm = norm self.norm = norm
self.directed = directed self.directed = directed
...@@ -408,7 +419,7 @@ class KieLabelEncode(object): ...@@ -408,7 +419,7 @@ class KieLabelEncode(object):
text_ind = [self.dict[c] for c in text if c in self.dict] text_ind = [self.dict[c] for c in text if c in self.dict]
text_inds.append(text_ind) text_inds.append(text_ind)
if 'label' in ann.keys(): if 'label' in ann.keys():
labels.append(ann['label']) labels.append(self.label2classid_map[ann['label']])
elif 'key_cls' in ann.keys(): elif 'key_cls' in ann.keys():
labels.append(ann['key_cls']) labels.append(ann['key_cls'])
else: else:
...@@ -876,15 +887,16 @@ class VQATokenLabelEncode(object): ...@@ -876,15 +887,16 @@ class VQATokenLabelEncode(object):
for info in ocr_info: for info in ocr_info:
if train_re: if train_re:
# for re # for re
if len(info["text"]) == 0: if len(info["transcription"]) == 0:
empty_entity.add(info["id"]) empty_entity.add(info["id"])
continue continue
id2label[info["id"]] = info["label"] id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]]) relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box # smooth_box
info["bbox"] = self.trans_poly_to_bbox(info["points"])
bbox = self._smooth_box(info["bbox"], height, width) bbox = self._smooth_box(info["bbox"], height, width)
text = info["text"] text = info["transcription"]
encode_res = self.tokenizer.encode( encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True) text, pad_to_max_seq_len=False, return_attention_mask=True)
...@@ -900,7 +912,7 @@ class VQATokenLabelEncode(object): ...@@ -900,7 +912,7 @@ class VQATokenLabelEncode(object):
label = info['label'] label = info['label']
gt_label = self._parse_label(label, encode_res) gt_label = self._parse_label(label, encode_res)
# construct entities for re # construct entities for re
if train_re: if train_re:
if gt_label[0] != self.label2id_map["O"]: if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities) entity_id_to_index_map[info["id"]] = len(entities)
...@@ -944,29 +956,29 @@ class VQATokenLabelEncode(object): ...@@ -944,29 +956,29 @@ class VQATokenLabelEncode(object):
data['entity_id_to_index_map'] = entity_id_to_index_map data['entity_id_to_index_map'] = entity_id_to_index_map
return data return data
def _load_ocr_info(self, data): def trans_poly_to_bbox(self, poly):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly]) x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly]) x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly]) y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly]) y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2] return [x1, y1, x2, y2]
def _load_ocr_info(self, data):
if self.infer_mode: if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False) ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = [] ocr_info = []
for res in ocr_result: for res in ocr_result:
ocr_info.append({ ocr_info.append({
"text": res[1][0], "transcription": res[1][0],
"bbox": trans_poly_to_bbox(res[0]), "bbox": self.trans_poly_to_bbox(res[0]),
"poly": res[0], "points": res[0],
}) })
return ocr_info return ocr_info
else: else:
info = data['label'] info = data['label']
# read text info # read text info
info_dict = json.loads(info) info_dict = json.loads(info)
return info_dict["ocr_info"] return info_dict
def _smooth_box(self, bbox, height, width): def _smooth_box(self, bbox, height, width):
bbox[0] = int(bbox[0] * 1000.0 / width) bbox[0] = int(bbox[0] * 1000.0 / width)
...@@ -977,7 +989,7 @@ class VQATokenLabelEncode(object): ...@@ -977,7 +989,7 @@ class VQATokenLabelEncode(object):
def _parse_label(self, label, encode_res): def _parse_label(self, label, encode_res):
gt_label = [] gt_label = []
if label.lower() == "other": if label.lower() in ["other", "others", "ignore"]:
gt_label.extend([0] * len(encode_res["input_ids"])) gt_label.extend([0] * len(encode_res["input_ids"]))
else: else:
gt_label.append(self.label2id_map[("b-" + label).upper()]) gt_label.append(self.label2id_map[("b-" + label).upper()])
......
...@@ -13,7 +13,12 @@ ...@@ -13,7 +13,12 @@
# limitations under the License. # limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
from .augment import DistortBBox
__all__ = [ __all__ = [
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation' 'VQATokenPad',
'VQASerTokenChunk',
'VQAReTokenChunk',
'VQAReTokenRelation',
'DistortBBox',
] ]
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import random
class DistortBBox:
def __init__(self, prob=0.5, max_scale=1, **kwargs):
"""Random distort bbox
"""
self.prob = prob
self.max_scale = max_scale
def __call__(self, data):
if random.random() > self.prob:
return data
bbox = np.array(data['bbox'])
rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
data['bbox'] = np.clip(data['bbox'], 0, 1000)
data['bbox'] = bbox.tolist()
sys.stdout.flush()
return data
...@@ -91,18 +91,19 @@ def check_and_read_gif(img_path): ...@@ -91,18 +91,19 @@ def check_and_read_gif(img_path):
def load_vqa_bio_label_maps(label_map_path): def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin: with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines() lines = fin.readlines()
lines = [line.strip() for line in lines] old_lines = [line.strip() for line in lines]
if "O" not in lines: lines = ["O"]
lines.insert(0, "O") for line in old_lines:
labels = [] # "O" has already been in lines
for line in lines: if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
if line == "O": continue
labels.append("O") lines.append(line)
else: labels = ["O"]
for line in lines[1:]:
labels.append("B-" + line) labels.append("B-" + line)
labels.append("I-" + line) labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)} label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)} id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
return label2id_map, id2label_map return label2id_map, id2label_map
......
...@@ -19,7 +19,7 @@ from PIL import Image, ImageDraw, ImageFont ...@@ -19,7 +19,7 @@ from PIL import Image, ImageDraw, ImageFont
def draw_ser_results(image, def draw_ser_results(image,
ocr_results, ocr_results,
font_path="doc/fonts/simfang.ttf", font_path="doc/fonts/simfang.ttf",
font_size=18): font_size=14):
np.random.seed(2021) np.random.seed(2021)
color = (np.random.permutation(range(255)), color = (np.random.permutation(range(255)),
np.random.permutation(range(255)), np.random.permutation(range(255)),
...@@ -40,9 +40,15 @@ def draw_ser_results(image, ...@@ -40,9 +40,15 @@ def draw_ser_results(image,
if ocr_info["pred_id"] not in color_map: if ocr_info["pred_id"] not in color_map:
continue continue
color = color_map[ocr_info["pred_id"]] color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color) if "bbox" in ocr_info:
# draw with ocr engine
bbox = ocr_info["bbox"]
else:
# draw with ocr groundtruth
bbox = trans_poly_to_bbox(ocr_info["points"])
draw_box_txt(bbox, text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5) img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new) return np.array(img_new)
...@@ -62,6 +68,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color): ...@@ -62,6 +68,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color):
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def draw_re_results(image, def draw_re_results(image,
result, result,
font_path="doc/fonts/simfang.ttf", font_path="doc/fonts/simfang.ttf",
...@@ -80,10 +94,10 @@ def draw_re_results(image, ...@@ -80,10 +94,10 @@ def draw_re_results(image,
color_line = (0, 255, 0) color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result: for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font, draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"],
font_size, color_head) draw, font, font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font, draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"],
font_size, color_tail) draw, font, font_size, color_tail)
center_head = ( center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
......
...@@ -16,7 +16,7 @@ SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类 ...@@ -16,7 +16,7 @@ SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类
训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集: 训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集:
``` ```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
``` ```
执行预测: 执行预测:
......
...@@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu ...@@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu
[Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget: [Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget:
```shell ```shell
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
``` ```
Download the pretrained model and predict the result: Download the pretrained model and predict the result:
......
...@@ -125,13 +125,13 @@ If you want to experience the prediction process directly, you can download the ...@@ -125,13 +125,13 @@ If you want to experience the prediction process directly, you can download the
* Download the processed dataset * Download the processed dataset
The download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar). The download address of the processed XFUND Chinese dataset: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar).
Download and unzip the dataset, and place the dataset in the current directory after unzipping. Download and unzip the dataset, and place the dataset in the current directory after unzipping.
```shell ```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
```` ````
* Convert the dataset * Convert the dataset
......
...@@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt ...@@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt
* 下载处理好的数据集 * 下载处理好的数据集
处理好的XFUND中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar) 处理好的XFUND中文数据集下载地址:[链接](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar)
下载并解压该数据集,解压后将数据集放置在当前目录下。 下载并解压该数据集,解压后将数据集放置在当前目录下。
```shell ```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
``` ```
* 转换数据集 * 转换数据集
......
...@@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None): ...@@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None):
json_info = json.loads(lines[0]) json_info = json.loads(lines[0])
documents = json_info["documents"] documents = json_info["documents"]
label_info = {}
with open(output_file, "w", encoding='utf-8') as fout: with open(output_file, "w", encoding='utf-8') as fout:
for idx, document in enumerate(documents): for idx, document in enumerate(documents):
label_info = []
img_info = document["img"] img_info = document["img"]
document = document["document"] document = document["document"]
image_path = img_info["fname"] image_path = img_info["fname"]
label_info["height"] = img_info["height"]
label_info["width"] = img_info["width"]
label_info["ocr_info"] = []
for doc in document: for doc in document:
label_info["ocr_info"].append({ x1, y1, x2, y2 = doc["box"]
"text": doc["text"], points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
label_info.append({
"transcription": doc["text"],
"label": doc["label"], "label": doc["label"],
"bbox": doc["box"], "points": points,
"id": doc["id"], "id": doc["id"],
"linking": doc["linking"], "linking": doc["linking"]
"words": doc["words"]
}) })
fout.write(image_path + "\t" + json.dumps( fout.write(image_path + "\t" + json.dumps(
......
...@@ -39,13 +39,12 @@ import time ...@@ -39,13 +39,12 @@ import time
def read_class_list(filepath): def read_class_list(filepath):
dict = {} ret = {}
with open(filepath, "r") as f: with open(filepath, "r") as f:
lines = f.readlines() lines = f.readlines()
for line in lines: for idx, line in enumerate(lines):
key, value = line.split(" ") ret[idx] = line.strip("\n")
dict[key] = value.rstrip() return ret
return dict
def draw_kie_result(batch, node, idx_to_cls, count): def draw_kie_result(batch, node, idx_to_cls, count):
...@@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count): ...@@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
x_min = int(min([point[0] for point in new_box])) x_min = int(min([point[0] for point in new_box]))
y_min = int(min([point[1] for point in new_box])) y_min = int(min([point[1] for point in new_box]))
pred_label = str(node_pred_label[i]) pred_label = node_pred_label[i]
if pred_label in idx_to_cls: if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label] pred_label = idx_to_cls[pred_label]
pred_score = '{:.2f}'.format(node_pred_score[i]) pred_score = '{:.2f}'.format(node_pred_score[i])
...@@ -109,8 +108,7 @@ def main(): ...@@ -109,8 +108,7 @@ def main():
save_res_path = config['Global']['save_res_path'] save_res_path = config['Global']['save_res_path']
class_path = config['Global']['class_path'] class_path = config['Global']['class_path']
idx_to_cls = read_class_list(class_path) idx_to_cls = read_class_list(class_path)
if not os.path.exists(os.path.dirname(save_res_path)): os.makedirs(os.path.dirname(save_res_path), exist_ok=True)
os.makedirs(os.path.dirname(save_res_path))
model.eval() model.eval()
......
...@@ -86,15 +86,16 @@ class SerPredictor(object): ...@@ -86,15 +86,16 @@ class SerPredictor(object):
] ]
transforms.append(op) transforms.append(op)
if config["Global"].get("infer_mode", None) is None:
global_config['infer_mode'] = True global_config['infer_mode'] = True
self.ops = create_operators(config['Eval']['dataset']['transforms'], self.ops = create_operators(config['Eval']['dataset']['transforms'],
global_config) global_config)
self.model.eval() self.model.eval()
def __call__(self, img_path): def __call__(self, data):
with open(img_path, 'rb') as f: with open(data["img_path"], 'rb') as f:
img = f.read() img = f.read()
data = {'image': img} data["image"] = img
batch = transform(data, self.ops) batch = transform(data, self.ops)
batch = to_tensor(batch) batch = to_tensor(batch)
preds = self.model(batch) preds = self.model(batch)
...@@ -112,20 +113,35 @@ if __name__ == '__main__': ...@@ -112,20 +113,35 @@ if __name__ == '__main__':
ser_engine = SerPredictor(config) ser_engine = SerPredictor(config)
if config["Global"].get("infer_mode", None) is False:
data_dir = config['Eval']['dataset']['data_dir']
with open(config['Global']['infer_img'], "rb") as f:
infer_imgs = f.readlines()
else:
infer_imgs = get_image_file_list(config['Global']['infer_img']) infer_imgs = get_image_file_list(config['Global']['infer_img'])
with open( with open(
os.path.join(config['Global']['save_res_path'], os.path.join(config['Global']['save_res_path'],
"infer_results.txt"), "infer_results.txt"),
"w", "w",
encoding='utf-8') as fout: encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs): for idx, info in enumerate(infer_imgs):
if config["Global"].get("infer_mode", None) is False:
data_line = info.decode('utf-8')
substr = data_line.strip("\n").split("\t")
img_path = os.path.join(data_dir, substr[0])
data = {'img_path': img_path, 'label': substr[1]}
else:
img_path = info
data = {'img_path': img_path}
save_img_path = os.path.join( save_img_path = os.path.join(
config['Global']['save_res_path'], config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
logger.info("process: [{}/{}], save result to {}".format( logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path)) idx, len(infer_imgs), save_img_path))
result, _ = ser_engine(img_path) result, _ = ser_engine(data)
result = result[0] result = result[0]
fout.write(img_path + "\t" + json.dumps( fout.write(img_path + "\t" + json.dumps(
{ {
......
...@@ -576,8 +576,8 @@ def preprocess(is_train=False): ...@@ -576,8 +576,8 @@ def preprocess(is_train=False):
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'ViTSTR', 'ABINet' 'SVTR', 'ViTSTR', 'ABINet'
] ]
if use_xpu: if use_xpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册