提交 dc7bfe8a 编写于 作者: 文幕地方's avatar 文幕地方

fix

上级 ce21ad83
python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../models/ser_LayoutXLM_xfun_zh/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg
python3.7 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=models/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=models/ser_LayoutXLM_xfun_zh/
\ No newline at end of file
...@@ -16,7 +16,7 @@ import sys ...@@ -16,7 +16,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth' os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
...@@ -50,18 +50,18 @@ class SerPredictor(object): ...@@ -50,18 +50,18 @@ class SerPredictor(object):
'ocr_engine': self.ocr_engine 'ocr_engine': self.ocr_engine
} }
}, { }, {
'VQATokenPad':{ 'VQATokenPad': {
'max_seq_len':512, 'max_seq_len': 512,
'return_attention_mask': True 'return_attention_mask': True
} }
}, { }, {
'VQASerTokenChunk':{ 'VQASerTokenChunk': {
'max_seq_len':512, 'max_seq_len': 512,
'return_attention_mask': True 'return_attention_mask': True
} }
}, { }, {
'Resize':{ 'Resize': {
'size' : [224, 224] 'size': [224, 224]
} }
}, { }, {
'NormalizeImage': { 'NormalizeImage': {
...@@ -75,8 +75,8 @@ class SerPredictor(object): ...@@ -75,8 +75,8 @@ class SerPredictor(object):
}, { }, {
'KeepKeys': { 'KeepKeys': {
'keep_keys': [ 'keep_keys': [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels', 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'segment_offset_id', 'ocr_info', 'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities' 'entities'
] ]
} }
...@@ -86,7 +86,8 @@ class SerPredictor(object): ...@@ -86,7 +86,8 @@ class SerPredictor(object):
"class_path": args.ser_dict_path, "class_path": args.ser_dict_path,
} }
self.preprocess_op = create_operators(pre_process_list, {'infer_mode':True}) self.preprocess_op = create_operators(pre_process_list,
{'infer_mode': True})
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'ser', logger) utility.create_predictor(args, 'ser', logger)
...@@ -113,11 +114,9 @@ class SerPredictor(object): ...@@ -113,11 +114,9 @@ class SerPredictor(object):
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
preds = outputs[0] preds = outputs[0]
post_result = self.postprocess_op( post_result = self.postprocess_op(
preds, preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
segment_offset_ids=[data[6]],
ocr_infos=[data[7]])
elapse = time.time() - starttime elapse = time.time() - starttime
return post_result, elapse return post_result, elapse
...@@ -136,17 +135,25 @@ def main(args): ...@@ -136,17 +135,25 @@ def main(args):
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
img = img[:,:,::-1] img = img[:, :, ::-1]
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
ser_res, elapse = ser_predictor(img) ser_res, elapse = ser_predictor(img)
ser_res = ser_res[0] ser_res = ser_res[0]
res_str = '{}\t{}\n'.format(image_file,json.dumps({"ocr_info": ser_res,}, ensure_ascii=False)) res_str = '{}\t{}\n'.format(
image_file,
json.dumps(
{
"ocr_info": ser_res,
}, ensure_ascii=False))
f_w.write(res_str) f_w.write(res_str)
img_res = draw_ser_results(image_file, ser_res, font_path="../doc/fonts/simfang.ttf",) img_res = draw_ser_results(
image_file,
ser_res,
font_path="../doc/fonts/simfang.ttf", )
img_save_path = os.path.join(args.output, img_save_path = os.path.join(args.output,
os.path.basename(image_file)) os.path.basename(image_file))
...@@ -157,5 +164,6 @@ def main(args): ...@@ -157,5 +164,6 @@ def main(args):
count += 1 count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse)) logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__": if __name__ == "__main__":
main(parse_args()) main(parse_args())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册