提交 807dd106 编写于 作者: 文幕地方's avatar 文幕地方

pre-commit

上级 dc7bfe8a
...@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object): ...@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
def _infer(self, preds, segment_offset_ids, ocr_infos): def _infer(self, preds, segment_offset_ids, ocr_infos):
results = [] results = []
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids, ocr_infos): for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
ocr_infos):
pred = np.argmax(pred, axis=1) pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred] pred = [self.id2label_map[idx] for idx in pred]
......
...@@ -40,7 +40,6 @@ def init_args(): ...@@ -40,7 +40,6 @@ def init_args():
type=ast.literal_eval, type=ast.literal_eval,
default=None, default=None,
help='label map according to ppstructure/layout/README_ch.md') help='label map according to ppstructure/layout/README_ch.md')
# params for vqa # params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM') parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str) parser.add_argument("--ser_model_dir", type=str)
......
...@@ -97,8 +97,9 @@ def export_single_model(model, ...@@ -97,8 +97,9 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"), shape=[None, 1, 32, 100], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec=[ input_spec = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # input_ids shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec( paddle.static.InputSpec(
......
...@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger): ...@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
# create predictor # create predictor
predictor = inference.create_predictor(config) predictor = inference.create_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
if mode in ['ser','re']: if mode in ['ser', 're']:
input_tensor = [] input_tensor = []
for name in input_names: for name in input_names:
input_tensor.append(predictor.get_input_handle(name)) input_tensor.append(predictor.get_input_handle(name))
......
...@@ -72,7 +72,10 @@ class SerPredictor(object): ...@@ -72,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False, use_gpu=global_config['use_gpu']) self.ocr_engine = PaddleOCR(
use_angle_cls=False,
show_log=False,
use_gpu=global_config['use_gpu'])
# create data ops # create data ops
transforms = [] transforms = []
...@@ -82,8 +85,8 @@ class SerPredictor(object): ...@@ -82,8 +85,8 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys': elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [ op[op_name]['keep_keys'] = [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels', 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'segment_offset_id', 'ocr_info', 'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities' 'entities'
] ]
...@@ -105,9 +108,7 @@ class SerPredictor(object): ...@@ -105,9 +108,7 @@ class SerPredictor(object):
preds = preds[0] preds = preds[0]
post_result = self.post_process_class( post_result = self.post_process_class(
preds, preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
segment_offset_ids=batch[6],
ocr_infos=batch[7])
return post_result, batch return post_result, batch
...@@ -154,4 +155,3 @@ if __name__ == '__main__': ...@@ -154,4 +155,3 @@ if __name__ == '__main__':
logger.info("process: [{}/{}], save result to {}".format( logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path)) idx, len(infer_imgs), save_img_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册