提交 807dd106 编写于 作者: 文幕地方's avatar 文幕地方

pre-commit

上级 dc7bfe8a
......@@ -121,14 +121,14 @@ class LayoutXLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
image=x[4],
position_ids=None,
head_mask=None,
labels=None)
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
image=x[4],
position_ids=None,
head_mask=None,
labels=None)
if not self.training:
return x
return x[0]
......
......@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
def _infer(self, preds, segment_offset_ids, ocr_infos):
results = []
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids, ocr_infos):
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
......
......@@ -40,7 +40,6 @@ def init_args():
type=ast.literal_eval,
default=None,
help='label map according to ppstructure/layout/README_ch.md')
# params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
......@@ -73,7 +72,7 @@ def init_args():
"--recovery",
type=bool,
default=False,
help='Whether to enable layout of recovery')
help='Whether to enable layout of recovery')
return parser
......
......@@ -97,8 +97,9 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec=[
input_spec = [
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec(
......
......@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
if mode in ['ser','re']:
if mode in ['ser', 're']:
input_tensor = []
for name in input_names:
input_tensor.append(predictor.get_input_handle(name))
......
......@@ -44,7 +44,7 @@ def to_tensor(data):
from collections import defaultdict
data_dict = defaultdict(list)
to_tensor_idxs = []
for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
......@@ -72,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False, use_gpu=global_config['use_gpu'])
self.ocr_engine = PaddleOCR(
use_angle_cls=False,
show_log=False,
use_gpu=global_config['use_gpu'])
# create data ops
transforms = []
......@@ -82,8 +85,8 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels',
'segment_offset_id', 'ocr_info',
'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
......@@ -103,11 +106,9 @@ class SerPredictor(object):
preds = self.model(batch)
if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
preds = preds[0]
post_result = self.post_process_class(
preds,
segment_offset_ids=batch[6],
ocr_infos=batch[7])
preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
return post_result, batch
......@@ -154,4 +155,3 @@ if __name__ == '__main__':
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
......@@ -192,6 +192,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
\ No newline at end of file
idx, len(infer_imgs), save_img_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册