提交 dc7bfe8a 编写于 作者: 文幕地方's avatar 文幕地方

fix

上级 ce21ad83
python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../models/ser_LayoutXLM_xfun_zh/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg
python3.7 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=models/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=models/ser_LayoutXLM_xfun_zh/
\ No newline at end of file
......@@ -16,7 +16,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
......@@ -50,18 +50,18 @@ class SerPredictor(object):
'ocr_engine': self.ocr_engine
}
}, {
'VQATokenPad':{
'max_seq_len':512,
'VQATokenPad': {
'max_seq_len': 512,
'return_attention_mask': True
}
}, {
'VQASerTokenChunk':{
'max_seq_len':512,
'VQASerTokenChunk': {
'max_seq_len': 512,
'return_attention_mask': True
}
}, {
'Resize':{
'size' : [224, 224]
'Resize': {
'size': [224, 224]
}
}, {
'NormalizeImage': {
......@@ -75,8 +75,8 @@ class SerPredictor(object):
}, {
'KeepKeys': {
'keep_keys': [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels',
'segment_offset_id', 'ocr_info',
'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
}
......@@ -86,7 +86,8 @@ class SerPredictor(object):
"class_path": args.ser_dict_path,
}
self.preprocess_op = create_operators(pre_process_list, {'infer_mode':True})
self.preprocess_op = create_operators(pre_process_list,
{'infer_mode': True})
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'ser', logger)
......@@ -115,9 +116,7 @@ class SerPredictor(object):
preds = outputs[0]
post_result = self.postprocess_op(
preds,
segment_offset_ids=[data[6]],
ocr_infos=[data[7]])
preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
elapse = time.time() - starttime
return post_result, elapse
......@@ -136,17 +135,25 @@ def main(args):
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
img = img[:,:,::-1]
img = img[:, :, ::-1]
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
ser_res, elapse = ser_predictor(img)
ser_res = ser_res[0]
res_str = '{}\t{}\n'.format(image_file,json.dumps({"ocr_info": ser_res,}, ensure_ascii=False))
res_str = '{}\t{}\n'.format(
image_file,
json.dumps(
{
"ocr_info": ser_res,
}, ensure_ascii=False))
f_w.write(res_str)
img_res = draw_ser_results(image_file, ser_res, font_path="../doc/fonts/simfang.ttf",)
img_res = draw_ser_results(
image_file,
ser_res,
font_path="../doc/fonts/simfang.ttf", )
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
......@@ -157,5 +164,6 @@ def main(args):
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册