From dc7bfe8a8441167cf6303879800eee7e160e2679 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Fri, 1 Jul 2022 09:42:27 +0000 Subject: [PATCH] fix --- ppstructure/infer.sh | 4 --- ppstructure/vqa/predict_vqa_token_ser.py | 42 ++++++++++++++---------- 2 files changed, 25 insertions(+), 21 deletions(-) delete mode 100644 ppstructure/infer.sh diff --git a/ppstructure/infer.sh b/ppstructure/infer.sh deleted file mode 100644 index a08cbadf..00000000 --- a/ppstructure/infer.sh +++ /dev/null @@ -1,4 +0,0 @@ -python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../models/ser_LayoutXLM_xfun_zh/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg - - -python3.7 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=models/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=models/ser_LayoutXLM_xfun_zh/ \ No newline at end of file diff --git a/ppstructure/vqa/predict_vqa_token_ser.py b/ppstructure/vqa/predict_vqa_token_ser.py index f55c8757..de0bbfe7 100644 --- a/ppstructure/vqa/predict_vqa_token_ser.py +++ b/ppstructure/vqa/predict_vqa_token_ser.py @@ -16,7 +16,7 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -50,18 +50,18 @@ class SerPredictor(object): 'ocr_engine': self.ocr_engine } }, { - 'VQATokenPad':{ - 'max_seq_len':512, + 'VQATokenPad': { + 'max_seq_len': 512, 'return_attention_mask': True } }, { - 'VQASerTokenChunk':{ - 'max_seq_len':512, + 'VQASerTokenChunk': { + 'max_seq_len': 512, 'return_attention_mask': True } }, { - 'Resize':{ - 'size' : [224, 224] + 'Resize': { + 'size': [224, 224] } }, { 'NormalizeImage': { @@ -75,8 +75,8 @@ class SerPredictor(object): }, { 'KeepKeys': { 'keep_keys': [ - 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels', - 'segment_offset_id', 'ocr_info', + 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', + 'image', 'labels', 'segment_offset_id', 'ocr_info', 'entities' ] } @@ -86,7 +86,8 @@ class SerPredictor(object): "class_path": args.ser_dict_path, } - self.preprocess_op = create_operators(pre_process_list, {'infer_mode':True}) + self.preprocess_op = create_operators(pre_process_list, + {'infer_mode': True}) self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'ser', logger) @@ -113,11 +114,9 @@ class SerPredictor(object): output = output_tensor.copy_to_cpu() outputs.append(output) preds = outputs[0] - + post_result = self.postprocess_op( - preds, - segment_offset_ids=[data[6]], - ocr_infos=[data[7]]) + preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]]) elapse = time.time() - starttime return post_result, elapse @@ -136,17 +135,25 @@ def main(args): img, flag = check_and_read_gif(image_file) if not flag: img = cv2.imread(image_file) - img = img[:,:,::-1] + img = img[:, :, ::-1] if img is None: logger.info("error in loading image:{}".format(image_file)) continue ser_res, elapse = ser_predictor(img) ser_res = ser_res[0] - res_str = '{}\t{}\n'.format(image_file,json.dumps({"ocr_info": ser_res,}, ensure_ascii=False)) + res_str = '{}\t{}\n'.format( + image_file, + json.dumps( + { + "ocr_info": ser_res, + }, ensure_ascii=False)) f_w.write(res_str) - img_res = draw_ser_results(image_file, ser_res, font_path="../doc/fonts/simfang.ttf",) + img_res = draw_ser_results( + image_file, + ser_res, + font_path="../doc/fonts/simfang.ttf", ) img_save_path = os.path.join(args.output, os.path.basename(image_file)) @@ -157,5 +164,6 @@ def main(args): count += 1 logger.info("Predict time of {}: {}".format(image_file, elapse)) + if __name__ == "__main__": main(parse_args()) -- GitLab