From 807dd106361ce97c0f0de59fda6ede5d15413c9d Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Fri, 1 Jul 2022 09:59:11 +0000 Subject: [PATCH] pre-commit --- ppocr/modeling/backbones/vqa_layoutlm.py | 16 ++++++++-------- .../vqa_token_ser_layoutlm_postprocess.py | 3 ++- ppstructure/utility.py | 3 +-- tools/export_model.py | 3 ++- tools/infer/utility.py | 2 +- tools/infer_vqa_token_ser.py | 18 +++++++++--------- tools/infer_vqa_token_ser_re.py | 4 ++-- 7 files changed, 25 insertions(+), 24 deletions(-) diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py index 15425196..2fd1b1b2 100644 --- a/ppocr/modeling/backbones/vqa_layoutlm.py +++ b/ppocr/modeling/backbones/vqa_layoutlm.py @@ -121,14 +121,14 @@ class LayoutXLMForSer(NLPBaseModel): def forward(self, x): x = self.model( - input_ids=x[0], - bbox=x[1], - attention_mask=x[2], - token_type_ids=x[3], - image=x[4], - position_ids=None, - head_mask=None, - labels=None) + input_ids=x[0], + bbox=x[1], + attention_mask=x[2], + token_type_ids=x[3], + image=x[4], + position_ids=None, + head_mask=None, + labels=None) if not self.training: return x return x[0] diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py index 90bc5273..8a6669f7 100644 --- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py +++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py @@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object): def _infer(self, preds, segment_offset_ids, ocr_infos): results = [] - for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids, ocr_infos): + for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids, + ocr_infos): pred = np.argmax(pred, axis=1) pred = [self.id2label_map[idx] for idx in pred] diff --git a/ppstructure/utility.py b/ppstructure/utility.py index 7ad34267..4ae56099 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -40,7 +40,6 @@ def init_args(): type=ast.literal_eval, default=None, help='label map according to ppstructure/layout/README_ch.md') - # params for vqa parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM') parser.add_argument("--ser_model_dir", type=str) @@ -73,7 +72,7 @@ def init_args(): "--recovery", type=bool, default=False, - help='Whether to enable layout of recovery') + help='Whether to enable layout of recovery') return parser diff --git a/tools/export_model.py b/tools/export_model.py index 752732ff..503951a8 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -97,8 +97,9 @@ def export_single_model(model, shape=[None, 1, 32, 100], dtype="float32"), ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: - input_spec=[ + input_spec = [ paddle.static.InputSpec( shape=[None, 512], dtype="int64"), # input_ids paddle.static.InputSpec( diff --git a/tools/infer/utility.py b/tools/infer/utility.py index aa5d2371..7eb77dec 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -318,7 +318,7 @@ def create_predictor(args, mode, logger): # create predictor predictor = inference.create_predictor(config) input_names = predictor.get_input_names() - if mode in ['ser','re']: + if mode in ['ser', 're']: input_tensor = [] for name in input_names: input_tensor.append(predictor.get_input_handle(name)) diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py index 1e661ec2..0173a554 100755 --- a/tools/infer_vqa_token_ser.py +++ b/tools/infer_vqa_token_ser.py @@ -44,7 +44,7 @@ def to_tensor(data): from collections import defaultdict data_dict = defaultdict(list) to_tensor_idxs = [] - + for idx, v in enumerate(data): if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): if idx not in to_tensor_idxs: @@ -72,7 +72,10 @@ class SerPredictor(object): from paddleocr import PaddleOCR - self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False, use_gpu=global_config['use_gpu']) + self.ocr_engine = PaddleOCR( + use_angle_cls=False, + show_log=False, + use_gpu=global_config['use_gpu']) # create data ops transforms = [] @@ -82,8 +85,8 @@ class SerPredictor(object): op[op_name]['ocr_engine'] = self.ocr_engine elif op_name == 'KeepKeys': op[op_name]['keep_keys'] = [ - 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels', - 'segment_offset_id', 'ocr_info', + 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', + 'image', 'labels', 'segment_offset_id', 'ocr_info', 'entities' ] @@ -103,11 +106,9 @@ class SerPredictor(object): preds = self.model(batch) if self.algorithm in ['LayoutLMv2', 'LayoutXLM']: preds = preds[0] - + post_result = self.post_process_class( - preds, - segment_offset_ids=batch[6], - ocr_infos=batch[7]) + preds, segment_offset_ids=batch[6], ocr_infos=batch[7]) return post_result, batch @@ -154,4 +155,3 @@ if __name__ == '__main__': logger.info("process: [{}/{}], save result to {}".format( idx, len(infer_imgs), save_img_path)) - diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py index d5ae634c..20ab1fe1 100755 --- a/tools/infer_vqa_token_ser_re.py +++ b/tools/infer_vqa_token_ser_re.py @@ -192,6 +192,6 @@ if __name__ == '__main__': }, ensure_ascii=False) + "\n") img_res = draw_re_results(img_path, result) cv2.imwrite(save_img_path, img_res) - + logger.info("process: [{}/{}], save result to {}".format( - idx, len(infer_imgs), save_img_path)) \ No newline at end of file + idx, len(infer_imgs), save_img_path)) -- GitLab