提交 807dd106 编写于 作者: 文幕地方's avatar 文幕地方

pre-commit

上级 dc7bfe8a
......@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
def _infer(self, preds, segment_offset_ids, ocr_infos):
results = []
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids, ocr_infos):
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
......
......@@ -40,7 +40,6 @@ def init_args():
type=ast.literal_eval,
default=None,
help='label map according to ppstructure/layout/README_ch.md')
# params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
......
......@@ -97,8 +97,9 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec=[
input_spec = [
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec(
......
......@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
if mode in ['ser','re']:
if mode in ['ser', 're']:
input_tensor = []
for name in input_names:
input_tensor.append(predictor.get_input_handle(name))
......
......@@ -72,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False, use_gpu=global_config['use_gpu'])
self.ocr_engine = PaddleOCR(
use_angle_cls=False,
show_log=False,
use_gpu=global_config['use_gpu'])
# create data ops
transforms = []
......@@ -82,8 +85,8 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels',
'segment_offset_id', 'ocr_info',
'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
......@@ -105,9 +108,7 @@ class SerPredictor(object):
preds = preds[0]
post_result = self.post_process_class(
preds,
segment_offset_ids=batch[6],
ocr_infos=batch[7])
preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
return post_result, batch
......@@ -154,4 +155,3 @@ if __name__ == '__main__':
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册