Commit f6698a32 authored by 文幕地方

support onnx infer of SLANet

Parent e0194680
@@ -166,6 +166,7 @@ class SLAHead(nn.Layer):
         self.max_text_length = max_text_length
         self.emb = self._char_to_onehot
         self.num_embeddings = out_channels
+        self.loc_reg_num = loc_reg_num

         # structure
         self.structure_attention_cell = AttentionGRUCell(
@@ -213,16 +214,15 @@ class SLAHead(nn.Layer):
         fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
         hidden = paddle.zeros((batch_size, self.hidden_size))
-        structure_preds = []
-        loc_preds = []
+        structure_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.num_embeddings))
+        loc_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.loc_reg_num))
         if self.training and targets is not None:
             structure = targets[0]
             for i in range(self.max_text_length + 1):
                 hidden, structure_step, loc_step = self._decode(structure[:, i],
                                                                 fea, hidden)
-                structure_preds.append(structure_step)
-                loc_preds.append(loc_step)
+                structure_preds[:, i, :] = structure_step
+                loc_preds[:, i, :] = loc_step
         else:
             pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
             max_text_length = paddle.to_tensor(self.max_text_length)
             # for export
@@ -231,10 +231,8 @@ class SLAHead(nn.Layer):
                 hidden, structure_step, loc_step = self._decode(pre_chars, fea,
                                                                 hidden)
                 pre_chars = structure_step.argmax(axis=1, dtype="int32")
-                structure_preds.append(structure_step)
-                loc_preds.append(loc_step)
-        structure_preds = paddle.stack(structure_preds, axis=1)
-        loc_preds = paddle.stack(loc_preds, axis=1)
+                structure_preds[:, i, :] = structure_step
+                loc_preds[:, i, :] = loc_step
         if not self.training:
             structure_preds = F.softmax(structure_preds)
         return {'structure_probs': structure_preds, 'loc_preds': loc_preds}
@@ -68,6 +68,7 @@ def build_pre_process_list(args):

 class TableStructurer(object):
     def __init__(self, args):
+        self.use_onnx = args.use_onnx
         pre_process_list = build_pre_process_list(args)
         if args.table_algorithm not in ['TableMaster']:
             postprocess_params = {
@@ -98,13 +99,17 @@ class TableStructurer(object):
             return None, 0
         img = np.expand_dims(img, axis=0)
         img = img.copy()
-        self.input_tensor.copy_from_cpu(img)
-        self.predictor.run()
-        outputs = []
-        for output_tensor in self.output_tensors:
-            output = output_tensor.copy_to_cpu()
-            outputs.append(output)
+        if self.use_onnx:
+            input_dict = {}
+            input_dict[self.input_tensor.name] = img
+            outputs = self.predictor.run(self.output_tensors, input_dict)
+        else:
+            self.input_tensor.copy_from_cpu(img)
+            self.predictor.run()
+            outputs = []
+            for output_tensor in self.output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
         preds = {}
         preds['structure_probs'] = outputs[1]
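The ONNX branch above assumes that, when args.use_onnx is set, self.predictor behaves like an onnxruntime InferenceSession (its run takes a list of output names, or None for all outputs, plus an input-feed dict) and self.input_tensor exposes a .name attribute. That wiring is not part of this diff; the sketch below shows how such a session could be created and called with the same interface, with the model path and input size chosen purely for illustration:

```python
# Hedged sketch, assuming onnxruntime is installed and a SLANet model has
# been exported to "slanet.onnx" (hypothetical path). The session plays the
# role of self.predictor in the branch above; passing None as the output
# list asks onnxruntime to return every model output.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("slanet.onnx", providers=["CPUExecutionProvider"])
input_tensor = sess.get_inputs()[0]   # counterpart of self.input_tensor
output_tensors = None                 # counterpart of self.output_tensors

img = np.random.rand(1, 3, 488, 488).astype("float32")  # dummy preprocessed batch
input_dict = {input_tensor.name: img}
outputs = sess.run(output_tensors, input_dict)  # same call shape as in the diff
print([o.shape for o in outputs])
```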