提交 807dd106 编写于 作者: 文幕地方's avatar 文幕地方

pre-commit

上级 dc7bfe8a
...@@ -121,14 +121,14 @@ class LayoutXLMForSer(NLPBaseModel): ...@@ -121,14 +121,14 @@ class LayoutXLMForSer(NLPBaseModel):
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
attention_mask=x[2], attention_mask=x[2],
token_type_ids=x[3], token_type_ids=x[3],
image=x[4], image=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None) labels=None)
if not self.training: if not self.training:
return x return x
return x[0] return x[0]
......
...@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object): ...@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
def _infer(self, preds, segment_offset_ids, ocr_infos): def _infer(self, preds, segment_offset_ids, ocr_infos):
results = [] results = []
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids, ocr_infos): for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
ocr_infos):
pred = np.argmax(pred, axis=1) pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred] pred = [self.id2label_map[idx] for idx in pred]
......
...@@ -40,7 +40,6 @@ def init_args(): ...@@ -40,7 +40,6 @@ def init_args():
type=ast.literal_eval, type=ast.literal_eval,
default=None, default=None,
help='label map according to ppstructure/layout/README_ch.md') help='label map according to ppstructure/layout/README_ch.md')
# params for vqa # params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM') parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str) parser.add_argument("--ser_model_dir", type=str)
...@@ -73,7 +72,7 @@ def init_args(): ...@@ -73,7 +72,7 @@ def init_args():
"--recovery", "--recovery",
type=bool, type=bool,
default=False, default=False,
help='Whether to enable layout of recovery') help='Whether to enable layout of recovery')
return parser return parser
......
...@@ -97,8 +97,9 @@ def export_single_model(model, ...@@ -97,8 +97,9 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"), shape=[None, 1, 32, 100], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec=[ input_spec = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # input_ids shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec( paddle.static.InputSpec(
......
...@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger): ...@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
# create predictor # create predictor
predictor = inference.create_predictor(config) predictor = inference.create_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
if mode in ['ser','re']: if mode in ['ser', 're']:
input_tensor = [] input_tensor = []
for name in input_names: for name in input_names:
input_tensor.append(predictor.get_input_handle(name)) input_tensor.append(predictor.get_input_handle(name))
......
...@@ -44,7 +44,7 @@ def to_tensor(data): ...@@ -44,7 +44,7 @@ def to_tensor(data):
from collections import defaultdict from collections import defaultdict
data_dict = defaultdict(list) data_dict = defaultdict(list)
to_tensor_idxs = [] to_tensor_idxs = []
for idx, v in enumerate(data): for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs: if idx not in to_tensor_idxs:
...@@ -72,7 +72,10 @@ class SerPredictor(object): ...@@ -72,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False, use_gpu=global_config['use_gpu']) self.ocr_engine = PaddleOCR(
use_angle_cls=False,
show_log=False,
use_gpu=global_config['use_gpu'])
# create data ops # create data ops
transforms = [] transforms = []
...@@ -82,8 +85,8 @@ class SerPredictor(object): ...@@ -82,8 +85,8 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys': elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [ op[op_name]['keep_keys'] = [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels', 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'segment_offset_id', 'ocr_info', 'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities' 'entities'
] ]
...@@ -103,11 +106,9 @@ class SerPredictor(object): ...@@ -103,11 +106,9 @@ class SerPredictor(object):
preds = self.model(batch) preds = self.model(batch)
if self.algorithm in ['LayoutLMv2', 'LayoutXLM']: if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
preds = preds[0] preds = preds[0]
post_result = self.post_process_class( post_result = self.post_process_class(
preds, preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
segment_offset_ids=batch[6],
ocr_infos=batch[7])
return post_result, batch return post_result, batch
...@@ -154,4 +155,3 @@ if __name__ == '__main__': ...@@ -154,4 +155,3 @@ if __name__ == '__main__':
logger.info("process: [{}/{}], save result to {}".format( logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path)) idx, len(infer_imgs), save_img_path))
...@@ -192,6 +192,6 @@ if __name__ == '__main__': ...@@ -192,6 +192,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n") }, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result) img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res) cv2.imwrite(save_img_path, img_res)
logger.info("process: [{}/{}], save result to {}".format( logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path)) idx, len(infer_imgs), save_img_path))
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册