diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index 1542519651a5199ab1d08cc950919cb84baebc0d..2fd1b1b2a78a98dba1930378f4a06783aadd8834 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -121,14 +121,14 @@ class LayoutXLMForSer(NLPBaseModel):
 
     def forward(self, x):
         x = self.model(
-                input_ids=x[0],
-                bbox=x[1],
-                attention_mask=x[2],
-                token_type_ids=x[3],
-                image=x[4],
-                position_ids=None,
-                head_mask=None,
-                labels=None)
+            input_ids=x[0],
+            bbox=x[1],
+            attention_mask=x[2],
+            token_type_ids=x[3],
+            image=x[4],
+            position_ids=None,
+            head_mask=None,
+            labels=None)
         if not self.training:
             return x
         return x[0]
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
index 90bc52733df2d2b5020cf5756a96b84e903282e1..8a6669f71f5ae6a7a16931e565b43355de5928d9 100644
--- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
 
     def _infer(self, preds, segment_offset_ids, ocr_infos):
         results = []
 
-        for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids, ocr_infos):
+        for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
+                                                     ocr_infos):
             pred = np.argmax(pred, axis=1)
             pred = [self.id2label_map[idx] for idx in pred]
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 7ad3426740a14719083c23bf525591c176036da8..4ae56099b83a46c85ce2dc362c1c6417b324dbe1 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -40,7 +40,6 @@ def init_args():
         type=ast.literal_eval,
         default=None,
         help='label map according to ppstructure/layout/README_ch.md')
-
     # params for vqa
     parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
     parser.add_argument("--ser_model_dir", type=str)
@@ -73,7 +72,7 @@ def init_args():
         "--recovery",
         type=bool,
         default=False,
-         help='Whether to enable layout of recovery')
+        help='Whether to enable layout of recovery')
     return parser
 
 
diff --git a/tools/export_model.py b/tools/export_model.py
index 752732ff7eb2410f943212b3f8722c179df1e060..503951a8a69d4855842478bad8c35525f51b9185 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -97,8 +97,9 @@ def export_single_model(model,
                 shape=[None, 1, 32, 100], dtype="float32"),
         ]
         model = to_static(model, input_spec=other_shape)
+
     elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
-        input_spec=[
+        input_spec = [
             paddle.static.InputSpec(
                 shape=[None, 512], dtype="int64"),  # input_ids
             paddle.static.InputSpec(
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index aa5d2371b6ea62c3150b7f336dfd01ae721d444b..7eb77dec74bf283936e1143edcb5b5dfc28365bd 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
     # create predictor
     predictor = inference.create_predictor(config)
     input_names = predictor.get_input_names()
-    if mode in ['ser','re']:
+    if mode in ['ser', 're']:
         input_tensor = []
         for name in input_names:
             input_tensor.append(predictor.get_input_handle(name))
diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py
index 1e661ec2c72fa2c076310d6fc4d9d8f4c6a346b4..0173a554cace31e20ab47dbe36d132a4dbb2127b 100755
--- a/tools/infer_vqa_token_ser.py
+++ b/tools/infer_vqa_token_ser.py
@@ -44,7 +44,7 @@ def to_tensor(data):
     from collections import defaultdict
     data_dict = defaultdict(list)
     to_tensor_idxs = []
-    
+
     for idx, v in enumerate(data):
         if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
             if idx not in to_tensor_idxs:
@@ -72,7 +72,10 @@ class SerPredictor(object):
 
         from paddleocr import PaddleOCR
 
-        self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False, use_gpu=global_config['use_gpu'])
+        self.ocr_engine = PaddleOCR(
+            use_angle_cls=False,
+            show_log=False,
+            use_gpu=global_config['use_gpu'])
 
         # create data ops
         transforms = []
@@ -82,8 +85,8 @@ class SerPredictor(object):
                 op[op_name]['ocr_engine'] = self.ocr_engine
             elif op_name == 'KeepKeys':
                 op[op_name]['keep_keys'] = [
-                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels',
-                    'segment_offset_id', 'ocr_info',
+                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
+                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                     'entities'
                 ]
 
@@ -103,11 +106,9 @@ class SerPredictor(object):
             preds = self.model(batch)
             if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
                 preds = preds[0]
-    
+
         post_result = self.post_process_class(
-            preds,
-            segment_offset_ids=batch[6],
-            ocr_infos=batch[7])
+            preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
         return post_result, batch
 
 
@@ -154,4 +155,3 @@ if __name__ == '__main__':
 
             logger.info("process: [{}/{}], save result to {}".format(
                 idx, len(infer_imgs), save_img_path))
-
diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py
index d5ae634ceabd089f88a7f4d8e109029267010374..20ab1fe176c3be75f7a7b01a8d77df6419c58c75 100755
--- a/tools/infer_vqa_token_ser_re.py
+++ b/tools/infer_vqa_token_ser_re.py
@@ -192,6 +192,6 @@ if __name__ == '__main__':
             }, ensure_ascii=False) + "\n")
         img_res = draw_re_results(img_path, result)
         cv2.imwrite(save_img_path, img_res)
-    
+
         logger.info("process: [{}/{}], save result to {}".format(
-            idx, len(infer_imgs), save_img_path))
\ No newline at end of file
+            idx, len(infer_imgs), save_img_path))
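
Note on the rewrapped _infer loop in ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py: only the line wrapping changes; the decoding itself is still an argmax over per-token class logits followed by an id-to-label lookup. A minimal standalone sketch of that step, assuming a toy id2label_map and random logits (illustrative values, not taken from the repo):

    import numpy as np

    # Hypothetical label map for illustration; the real one is built by
    # VQASerTokenLayoutLMPostProcess from its class_path file.
    id2label_map = {0: "O", 1: "B-QUESTION", 2: "I-QUESTION", 3: "B-ANSWER", 4: "I-ANSWER"}

    # Stand-in for one element of `preds`: per-token logits of shape
    # (num_tokens, num_classes).
    pred = np.random.rand(6, len(id2label_map)).astype("float32")

    # Same decoding as the loop body in _infer: argmax over the class axis,
    # then map each class id to its label string.
    token_ids = np.argmax(pred, axis=1)
    token_labels = [id2label_map[int(idx)] for idx in token_ids]
    print(token_labels)  # e.g. ['B-QUESTION', 'O', ...]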
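The positional unpacking in LayoutXLMForSer.forward and the KeepKeys order in SerPredictor must agree with the InputSpec list in tools/export_model.py: x[0]=input_ids, x[1]=bbox, x[2]=attention_mask, x[3]=token_type_ids, x[4]=image. A rough sketch of a dummy batch in that order follows; the bbox and image shapes are assumptions based on the usual LayoutXLM configuration, since the hunk above only shows the input_ids spec:

    import paddle

    seq_len = 512  # matches the [None, 512] InputSpec in tools/export_model.py

    # Dummy tensors in the order LayoutXLMForSer.forward unpacks them.
    batch = [
        paddle.zeros([1, seq_len], dtype="int64"),        # x[0] input_ids
        paddle.zeros([1, seq_len, 4], dtype="int64"),     # x[1] bbox (assumed shape)
        paddle.ones([1, seq_len], dtype="int64"),         # x[2] attention_mask
        paddle.zeros([1, seq_len], dtype="int64"),        # x[3] token_type_ids
        paddle.zeros([1, 3, 224, 224], dtype="float32"),  # x[4] image (assumed shape)
    ]

    # Hypothetical usage with a SER model built via ppocr:
    # model = build_model(config['Architecture'])
    # preds = model(batch)
    # For LayoutLMv2/LayoutXLM the SER logits are preds[0] at inference time,
    # as SerPredictor.__call__ does above.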