From e13ec733a628006cf7295652600bbf8c1604755f Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Thu, 30 Jun 2022 15:23:31 +0800 Subject: [PATCH] unify kie and ser for vqa data format (#6704) * unify kie and ser for vqa data format * fix config and label ops * fix doc * add distort bbox --- configs/kie/kie_unet_sdmgr.yml | 4 +-- configs/vqa/ser/layoutlm.yml | 6 ++-- configs/vqa/ser/layoutlmv2.yml | 6 ++-- configs/vqa/ser/layoutxlm.yml | 6 ++-- ppocr/data/imaug/label_ops.py | 46 +++++++++++++++--------- ppocr/data/imaug/vqa/__init__.py | 7 +++- ppocr/data/imaug/vqa/augment.py | 37 +++++++++++++++++++ ppocr/utils/utility.py | 25 ++++++------- ppocr/utils/visual.py | 28 +++++++++++---- ppstructure/docs/kie.md | 2 +- ppstructure/docs/kie_en.md | 2 +- ppstructure/vqa/README.md | 4 +-- ppstructure/vqa/README_ch.md | 4 +-- ppstructure/vqa/labels/labels_ser.txt | 3 -- ppstructure/vqa/tools/trans_xfun_data.py | 18 ++++------ tools/infer_kie.py | 14 ++++---- tools/infer_vqa_token_ser.py | 30 ++++++++++++---- tools/program.py | 4 +-- 18 files changed, 161 insertions(+), 85 deletions(-) create mode 100644 ppocr/data/imaug/vqa/augment.py delete mode 100644 ppstructure/vqa/labels/labels_ser.txt diff --git a/configs/kie/kie_unet_sdmgr.yml b/configs/kie/kie_unet_sdmgr.yml index a6968aaa..da2e4fda 100644 --- a/configs/kie/kie_unet_sdmgr.yml +++ b/configs/kie/kie_unet_sdmgr.yml @@ -17,7 +17,7 @@ Global: checkpoints: save_inference_dir: use_visualdl: False - class_path: ./train_data/wildreceipt/class_list.txt + class_path: &class_path ./train_data/wildreceipt/class_list.txt infer_img: ./train_data/wildreceipt/1.txt save_res_path: ./output/sdmgr_kie/predicts_kie.txt img_scale: [ 1024, 512 ] @@ -72,6 +72,7 @@ Train: order: 'hwc' - KieLabelEncode: # Class handling label character_dict_path: ./train_data/wildreceipt/dict.txt + class_path: *class_path - KieResize: - ToCHWImage: - KeepKeys: @@ -88,7 +89,6 @@ Eval: data_dir: ./train_data/wildreceipt label_file_list: - ./train_data/wildreceipt/wildreceipt_test.txt - # - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt transforms: - DecodeImage: # load image img_mode: RGB diff --git a/configs/vqa/ser/layoutlm.yml b/configs/vqa/ser/layoutlm.yml index 87131170..47ab093e 100644 --- a/configs/vqa/ser/layoutlm.yml +++ b/configs/vqa/ser/layoutlm.yml @@ -43,7 +43,7 @@ Optimizer: PostProcess: name: VQASerTokenLayoutLMPostProcess - class_path: &class_path ppstructure/vqa/labels/labels_ser.txt + class_path: &class_path train_data/XFUND/class_list_xfun.txt Metric: name: VQASerTokenMetric @@ -54,7 +54,7 @@ Train: name: SimpleDataSet data_dir: train_data/XFUND/zh_train/image label_file_list: - - train_data/XFUND/zh_train/xfun_normalize_train.json + - train_data/XFUND/zh_train/train.json transforms: - DecodeImage: # load image img_mode: RGB @@ -89,7 +89,7 @@ Eval: name: SimpleDataSet data_dir: train_data/XFUND/zh_val/image label_file_list: - - train_data/XFUND/zh_val/xfun_normalize_val.json + - train_data/XFUND/zh_val/val.json transforms: - DecodeImage: # load image img_mode: RGB diff --git a/configs/vqa/ser/layoutlmv2.yml b/configs/vqa/ser/layoutlmv2.yml index 33406252..d6a9c03e 100644 --- a/configs/vqa/ser/layoutlmv2.yml +++ b/configs/vqa/ser/layoutlmv2.yml @@ -44,7 +44,7 @@ Optimizer: PostProcess: name: VQASerTokenLayoutLMPostProcess - class_path: &class_path ppstructure/vqa/labels/labels_ser.txt + class_path: &class_path train_data/XFUND/class_list_xfun.txt Metric: name: VQASerTokenMetric @@ -55,7 +55,7 @@ Train: name: SimpleDataSet data_dir: train_data/XFUND/zh_train/image label_file_list: - - train_data/XFUND/zh_train/xfun_normalize_train.json + - train_data/XFUND/zh_train/train.json transforms: - DecodeImage: # load image img_mode: RGB @@ -90,7 +90,7 @@ Eval: name: SimpleDataSet data_dir: train_data/XFUND/zh_val/image label_file_list: - - train_data/XFUND/zh_val/xfun_normalize_val.json + - train_data/XFUND/zh_val/val.json transforms: - DecodeImage: # load image img_mode: RGB diff --git a/configs/vqa/ser/layoutxlm.yml b/configs/vqa/ser/layoutxlm.yml index eb1cca5a..3686989c 100644 --- a/configs/vqa/ser/layoutxlm.yml +++ b/configs/vqa/ser/layoutxlm.yml @@ -11,7 +11,7 @@ Global: save_inference_dir: use_visualdl: False seed: 2022 - infer_img: doc/vqa/input/zh_val_42.jpg + infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg save_res_path: ./output/ser Architecture: @@ -54,7 +54,7 @@ Train: name: SimpleDataSet data_dir: train_data/XFUND/zh_train/image label_file_list: - - train_data/XFUND/zh_train/xfun_normalize_train.json + - train_data/XFUND/zh_train/train.json ratio_list: [ 1.0 ] transforms: - DecodeImage: # load image @@ -90,7 +90,7 @@ Eval: name: SimpleDataSet data_dir: train_data/XFUND/zh_val/image label_file_list: - - train_data/XFUND/zh_val/xfun_normalize_val.json + - train_data/XFUND/zh_val/val.json transforms: - DecodeImage: # load image img_mode: RGB diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 312d6dc9..c95b3262 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -259,15 +259,26 @@ class E2ELabelEncodeTrain(object): class KieLabelEncode(object): - def __init__(self, character_dict_path, norm=10, directed=False, **kwargs): + def __init__(self, + character_dict_path, + class_path, + norm=10, + directed=False, + **kwargs): super(KieLabelEncode, self).__init__() self.dict = dict({'': 0}) + self.label2classid_map = dict() with open(character_dict_path, 'r', encoding='utf-8') as fr: idx = 1 for line in fr: char = line.strip() self.dict[char] = idx idx += 1 + with open(class_path, "r") as fin: + lines = fin.readlines() + for idx, line in enumerate(lines): + line = line.strip("\n") + self.label2classid_map[line] = idx self.norm = norm self.directed = directed @@ -408,7 +419,7 @@ class KieLabelEncode(object): text_ind = [self.dict[c] for c in text if c in self.dict] text_inds.append(text_ind) if 'label' in ann.keys(): - labels.append(ann['label']) + labels.append(self.label2classid_map[ann['label']]) elif 'key_cls' in ann.keys(): labels.append(ann['key_cls']) else: @@ -876,15 +887,16 @@ class VQATokenLabelEncode(object): for info in ocr_info: if train_re: # for re - if len(info["text"]) == 0: + if len(info["transcription"]) == 0: empty_entity.add(info["id"]) continue id2label[info["id"]] = info["label"] relations.extend([tuple(sorted(l)) for l in info["linking"]]) # smooth_box + info["bbox"] = self.trans_poly_to_bbox(info["points"]) bbox = self._smooth_box(info["bbox"], height, width) - text = info["text"] + text = info["transcription"] encode_res = self.tokenizer.encode( text, pad_to_max_seq_len=False, return_attention_mask=True) @@ -900,7 +912,7 @@ class VQATokenLabelEncode(object): label = info['label'] gt_label = self._parse_label(label, encode_res) - # construct entities for re +# construct entities for re if train_re: if gt_label[0] != self.label2id_map["O"]: entity_id_to_index_map[info["id"]] = len(entities) @@ -944,29 +956,29 @@ class VQATokenLabelEncode(object): data['entity_id_to_index_map'] = entity_id_to_index_map return data - def _load_ocr_info(self, data): - def trans_poly_to_bbox(poly): - x1 = np.min([p[0] for p in poly]) - x2 = np.max([p[0] for p in poly]) - y1 = np.min([p[1] for p in poly]) - y2 = np.max([p[1] for p in poly]) - return [x1, y1, x2, y2] + def trans_poly_to_bbox(self, poly): + x1 = np.min([p[0] for p in poly]) + x2 = np.max([p[0] for p in poly]) + y1 = np.min([p[1] for p in poly]) + y2 = np.max([p[1] for p in poly]) + return [x1, y1, x2, y2] + def _load_ocr_info(self, data): if self.infer_mode: ocr_result = self.ocr_engine.ocr(data['image'], cls=False) ocr_info = [] for res in ocr_result: ocr_info.append({ - "text": res[1][0], - "bbox": trans_poly_to_bbox(res[0]), - "poly": res[0], + "transcription": res[1][0], + "bbox": self.trans_poly_to_bbox(res[0]), + "points": res[0], }) return ocr_info else: info = data['label'] # read text info info_dict = json.loads(info) - return info_dict["ocr_info"] + return info_dict def _smooth_box(self, bbox, height, width): bbox[0] = int(bbox[0] * 1000.0 / width) @@ -977,7 +989,7 @@ class VQATokenLabelEncode(object): def _parse_label(self, label, encode_res): gt_label = [] - if label.lower() == "other": + if label.lower() in ["other", "others", "ignore"]: gt_label.extend([0] * len(encode_res["input_ids"])) else: gt_label.append(self.label2id_map[("b-" + label).upper()]) diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py index a5025e79..bde17511 100644 --- a/ppocr/data/imaug/vqa/__init__.py +++ b/ppocr/data/imaug/vqa/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation +from .augment import DistortBBox __all__ = [ - 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation' + 'VQATokenPad', + 'VQASerTokenChunk', + 'VQAReTokenChunk', + 'VQAReTokenRelation', + 'DistortBBox', ] diff --git a/ppocr/data/imaug/vqa/augment.py b/ppocr/data/imaug/vqa/augment.py new file mode 100644 index 00000000..fcdc9685 --- /dev/null +++ b/ppocr/data/imaug/vqa/augment.py @@ -0,0 +1,37 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +import random + + +class DistortBBox: + def __init__(self, prob=0.5, max_scale=1, **kwargs): + """Random distort bbox + """ + self.prob = prob + self.max_scale = max_scale + + def __call__(self, data): + if random.random() > self.prob: + return data + bbox = np.array(data['bbox']) + rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale + bbox = np.round(bbox + rnd_scale).astype(bbox.dtype) + data['bbox'] = np.clip(data['bbox'], 0, 1000) + data['bbox'] = bbox.tolist() + sys.stdout.flush() + return data diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index 4a25ff8b..b881fcab 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -91,18 +91,19 @@ def check_and_read_gif(img_path): def load_vqa_bio_label_maps(label_map_path): with open(label_map_path, "r", encoding='utf-8') as fin: lines = fin.readlines() - lines = [line.strip() for line in lines] - if "O" not in lines: - lines.insert(0, "O") - labels = [] - for line in lines: - if line == "O": - labels.append("O") - else: - labels.append("B-" + line) - labels.append("I-" + line) - label2id_map = {label: idx for idx, label in enumerate(labels)} - id2label_map = {idx: label for idx, label in enumerate(labels)} + old_lines = [line.strip() for line in lines] + lines = ["O"] + for line in old_lines: + # "O" has already been in lines + if line.upper() in ["OTHER", "OTHERS", "IGNORE"]: + continue + lines.append(line) + labels = ["O"] + for line in lines[1:]: + labels.append("B-" + line) + labels.append("I-" + line) + label2id_map = {label.upper(): idx for idx, label in enumerate(labels)} + id2label_map = {idx: label.upper() for idx, label in enumerate(labels)} return label2id_map, id2label_map diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py index 7a8c1674..235eb572 100644 --- a/ppocr/utils/visual.py +++ b/ppocr/utils/visual.py @@ -19,7 +19,7 @@ from PIL import Image, ImageDraw, ImageFont def draw_ser_results(image, ocr_results, font_path="doc/fonts/simfang.ttf", - font_size=18): + font_size=14): np.random.seed(2021) color = (np.random.permutation(range(255)), np.random.permutation(range(255)), @@ -40,9 +40,15 @@ def draw_ser_results(image, if ocr_info["pred_id"] not in color_map: continue color = color_map[ocr_info["pred_id"]] - text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) + text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"]) - draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color) + if "bbox" in ocr_info: + # draw with ocr engine + bbox = ocr_info["bbox"] + else: + # draw with ocr groundtruth + bbox = trans_poly_to_bbox(ocr_info["points"]) + draw_box_txt(bbox, text, draw, font, font_size, color) img_new = Image.blend(image, img_new, 0.5) return np.array(img_new) @@ -62,6 +68,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color): draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) +def trans_poly_to_bbox(poly): + x1 = np.min([p[0] for p in poly]) + x2 = np.max([p[0] for p in poly]) + y1 = np.min([p[1] for p in poly]) + y2 = np.max([p[1] for p in poly]) + return [x1, y1, x2, y2] + + def draw_re_results(image, result, font_path="doc/fonts/simfang.ttf", @@ -80,10 +94,10 @@ def draw_re_results(image, color_line = (0, 255, 0) for ocr_info_head, ocr_info_tail in result: - draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font, - font_size, color_head) - draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font, - font_size, color_tail) + draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"], + draw, font, font_size, color_head) + draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"], + draw, font, font_size, color_tail) center_head = ( (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, diff --git a/ppstructure/docs/kie.md b/ppstructure/docs/kie.md index 35498b33..315dd9f7 100644 --- a/ppstructure/docs/kie.md +++ b/ppstructure/docs/kie.md @@ -16,7 +16,7 @@ SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类 训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集: ``` -wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar +wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar ``` 执行预测: diff --git a/ppstructure/docs/kie_en.md b/ppstructure/docs/kie_en.md index 1fe38b0b..7b375222 100644 --- a/ppstructure/docs/kie_en.md +++ b/ppstructure/docs/kie_en.md @@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu [Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget: ```shell -wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar +wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar ``` Download the pretrained model and predict the result: diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md index e3a10671..711ffa31 100644 --- a/ppstructure/vqa/README.md +++ b/ppstructure/vqa/README.md @@ -125,13 +125,13 @@ If you want to experience the prediction process directly, you can download the * Download the processed dataset -The download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar). +The download address of the processed XFUND Chinese dataset: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar). Download and unzip the dataset, and place the dataset in the current directory after unzipping. ```shell -wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar +wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar ```` * Convert the dataset diff --git a/ppstructure/vqa/README_ch.md b/ppstructure/vqa/README_ch.md index b677dc07..297ba64f 100644 --- a/ppstructure/vqa/README_ch.md +++ b/ppstructure/vqa/README_ch.md @@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt * 下载处理好的数据集 -处理好的XFUND中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)。 +处理好的XFUND中文数据集下载地址:[链接](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar)。 下载并解压该数据集,解压后将数据集放置在当前目录下。 ```shell -wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar +wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar ``` * 转换数据集 diff --git a/ppstructure/vqa/labels/labels_ser.txt b/ppstructure/vqa/labels/labels_ser.txt deleted file mode 100644 index 508e4811..00000000 --- a/ppstructure/vqa/labels/labels_ser.txt +++ /dev/null @@ -1,3 +0,0 @@ -QUESTION -ANSWER -HEADER diff --git a/ppstructure/vqa/tools/trans_xfun_data.py b/ppstructure/vqa/tools/trans_xfun_data.py index 93ec9816..11d221be 100644 --- a/ppstructure/vqa/tools/trans_xfun_data.py +++ b/ppstructure/vqa/tools/trans_xfun_data.py @@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None): json_info = json.loads(lines[0]) documents = json_info["documents"] - label_info = {} with open(output_file, "w", encoding='utf-8') as fout: for idx, document in enumerate(documents): + label_info = [] img_info = document["img"] document = document["document"] image_path = img_info["fname"] - label_info["height"] = img_info["height"] - label_info["width"] = img_info["width"] - - label_info["ocr_info"] = [] - for doc in document: - label_info["ocr_info"].append({ - "text": doc["text"], + x1, y1, x2, y2 = doc["box"] + points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] + label_info.append({ + "transcription": doc["text"], "label": doc["label"], - "bbox": doc["box"], + "points": points, "id": doc["id"], - "linking": doc["linking"], - "words": doc["words"] + "linking": doc["linking"] }) fout.write(image_path + "\t" + json.dumps( diff --git a/tools/infer_kie.py b/tools/infer_kie.py index 0cb0b870..346e2e0a 100755 --- a/tools/infer_kie.py +++ b/tools/infer_kie.py @@ -39,13 +39,12 @@ import time def read_class_list(filepath): - dict = {} + ret = {} with open(filepath, "r") as f: lines = f.readlines() - for line in lines: - key, value = line.split(" ") - dict[key] = value.rstrip() - return dict + for idx, line in enumerate(lines): + ret[idx] = line.strip("\n") + return ret def draw_kie_result(batch, node, idx_to_cls, count): @@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count): x_min = int(min([point[0] for point in new_box])) y_min = int(min([point[1] for point in new_box])) - pred_label = str(node_pred_label[i]) + pred_label = node_pred_label[i] if pred_label in idx_to_cls: pred_label = idx_to_cls[pred_label] pred_score = '{:.2f}'.format(node_pred_score[i]) @@ -109,8 +108,7 @@ def main(): save_res_path = config['Global']['save_res_path'] class_path = config['Global']['class_path'] idx_to_cls = read_class_list(class_path) - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) + os.makedirs(os.path.dirname(save_res_path), exist_ok=True) model.eval() diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py index 83ed72b3..39ada64a 100755 --- a/tools/infer_vqa_token_ser.py +++ b/tools/infer_vqa_token_ser.py @@ -86,15 +86,16 @@ class SerPredictor(object): ] transforms.append(op) - global_config['infer_mode'] = True + if config["Global"].get("infer_mode", None) is None: + global_config['infer_mode'] = True self.ops = create_operators(config['Eval']['dataset']['transforms'], global_config) self.model.eval() - def __call__(self, img_path): - with open(img_path, 'rb') as f: + def __call__(self, data): + with open(data["img_path"], 'rb') as f: img = f.read() - data = {'image': img} + data["image"] = img batch = transform(data, self.ops) batch = to_tensor(batch) preds = self.model(batch) @@ -112,20 +113,35 @@ if __name__ == '__main__': ser_engine = SerPredictor(config) - infer_imgs = get_image_file_list(config['Global']['infer_img']) + if config["Global"].get("infer_mode", None) is False: + data_dir = config['Eval']['dataset']['data_dir'] + with open(config['Global']['infer_img'], "rb") as f: + infer_imgs = f.readlines() + else: + infer_imgs = get_image_file_list(config['Global']['infer_img']) + with open( os.path.join(config['Global']['save_res_path'], "infer_results.txt"), "w", encoding='utf-8') as fout: - for idx, img_path in enumerate(infer_imgs): + for idx, info in enumerate(infer_imgs): + if config["Global"].get("infer_mode", None) is False: + data_line = info.decode('utf-8') + substr = data_line.strip("\n").split("\t") + img_path = os.path.join(data_dir, substr[0]) + data = {'img_path': img_path, 'label': substr[1]} + else: + img_path = info + data = {'img_path': img_path} + save_img_path = os.path.join( config['Global']['save_res_path'], os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") logger.info("process: [{}/{}], save result to {}".format( idx, len(infer_imgs), save_img_path)) - result, _ = ser_engine(img_path) + result, _ = ser_engine(data) result = result[0] fout.write(img_path + "\t" + json.dumps( { diff --git a/tools/program.py b/tools/program.py index aa3ba82c..f598feb7 100755 --- a/tools/program.py +++ b/tools/program.py @@ -576,8 +576,8 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', - 'ViTSTR', 'ABINet' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', + 'SVTR', 'ViTSTR', 'ABINet' ] if use_xpu: -- GitLab