Commit fb2b7f54 authored by 文幕地方

add paddle2onnx support for tablerec-rare

Parent 321b0a77
@@ -43,7 +43,6 @@ Architecture:
   Head:
     name: TableAttentionHead
     hidden_size: 256
-    loc_type: 2
     max_text_length: *max_text_length
     loc_reg_num: &loc_reg_num 4
...
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import math
 import paddle
 import paddle.nn as nn
 from paddle import ParamAttr
@@ -42,7 +43,6 @@ class TableAttentionHead(nn.Layer):
     def __init__(self,
                  in_channels,
                  hidden_size,
-                 loc_type,
                  in_max_len=488,
                  max_text_length=800,
                  out_channels=30,
@@ -57,20 +57,16 @@ class TableAttentionHead(nn.Layer):
         self.structure_attention_cell = AttentionGRUCell(
             self.input_size, hidden_size, self.out_channels, use_gru=False)
         self.structure_generator = nn.Linear(hidden_size, self.out_channels)
-        self.loc_type = loc_type
         self.in_max_len = in_max_len
-        if self.loc_type == 1:
-            self.loc_generator = nn.Linear(hidden_size, 4)
-        else:
-            if self.in_max_len == 640:
-                self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
-            elif self.in_max_len == 800:
-                self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
-            else:
-                self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
-            self.loc_generator = nn.Linear(self.input_size + hidden_size,
-                                           loc_reg_num)
+        if self.in_max_len == 640:
+            self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
+        elif self.in_max_len == 800:
+            self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
+        else:
+            self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
+        self.loc_generator = nn.Linear(self.input_size + hidden_size,
+                                       loc_reg_num)

     def _char_to_onehot(self, input_char, onehot_dim):
         input_ont_hot = F.one_hot(input_char, onehot_dim)
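The constructor hunk above mirrors the YAML hunk at the top: `loc_type` disappears from both the config and the head's signature, so the location branch is selected only by `in_max_len` and sized by `loc_reg_num`. A minimal sketch of building the head after this commit; the import path and the concrete values below are assumptions for illustration, not taken from this diff:

# Sketch only: TableAttentionHead is now constructed without `loc_type`.
from ppocr.modeling.heads.table_att_head import TableAttentionHead  # assumed module path

head = TableAttentionHead(
    in_channels=[96],     # assumed backbone channel list; check the surrounding PaddleOCR code
    hidden_size=256,      # matches the YAML hunk above
    in_max_len=488,       # falls into the nn.Linear(256, max_text_length + 1) branch
    max_text_length=500,  # illustrative; the YAML references a *max_text_length anchor defined elsewhere
    loc_reg_num=4)        # 4 coordinates per predicted cell box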
@@ -80,16 +76,13 @@ class TableAttentionHead(nn.Layer):
         # if and else branch are both needed when you want to assign a variable
         # if you modify the var in just one branch, then the modification will not work.
         fea = inputs[-1]
-        if len(fea.shape) == 3:
-            pass
-        else:
-            last_shape = int(np.prod(fea.shape[2:]))  # gry added
-            fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
-            fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
+        last_shape = int(np.prod(fea.shape[2:]))  # gry added
+        fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
+        fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
         batch_size = fea.shape[0]

         hidden = paddle.zeros((batch_size, self.hidden_size))
-        output_hiddens = []
+        output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size))
         if self.training and targets is not None:
             structure = targets[0]
             for i in range(self.max_text_length + 1):
@@ -97,7 +90,8 @@ class TableAttentionHead(nn.Layer):
                     structure[:, i], onehot_dim=self.out_channels)
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
-                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+                output_hiddens[:, i, :] = outputs
+                # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
             output = paddle.concat(output_hiddens, axis=1)
             structure_probs = self.structure_generator(output)
             if self.loc_type == 1:
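The substantive change in this training-branch hunk is the accumulator: each step's hidden state is written into the tensor pre-allocated above as (batch, max_text_length + 1, hidden_size) instead of being appended to a Python list and concatenated afterwards. A small standalone sketch (shapes are illustrative) showing the two styles yield the same tensor:

import paddle

batch, steps, hidden = 2, 4, 8
step_outputs = [paddle.rand([batch, hidden]) for _ in range(steps)]

# old style: collect the per-step outputs in a Python list, then concat on the time axis
old = paddle.concat([paddle.unsqueeze(o, axis=1) for o in step_outputs], axis=1)

# new style: pre-allocate (batch, steps, hidden) and write each step by index
new = paddle.zeros([batch, steps, hidden])
for i, o in enumerate(step_outputs):
    new[:, i, :] = o

print(paddle.allclose(old, new))  # boolean Tensor holding True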
@@ -118,30 +112,25 @@ class TableAttentionHead(nn.Layer):
             outputs = None
             alpha = None
             max_text_length = paddle.to_tensor(self.max_text_length)
-            i = 0
-            while i < max_text_length + 1:
+            for i in range(max_text_length + 1):
                 elem_onehots = self._char_to_onehot(
                     temp_elem, onehot_dim=self.out_channels)
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
-                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+                output_hiddens[:, i, :] = outputs
+                # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
                 structure_probs_step = self.structure_generator(outputs)
                 temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
-                i += 1

-            output = paddle.concat(output_hiddens, axis=1)
+            output = output_hiddens
             structure_probs = self.structure_generator(output)
             structure_probs = F.softmax(structure_probs)
-            if self.loc_type == 1:
-                loc_preds = self.loc_generator(output)
-                loc_preds = F.sigmoid(loc_preds)
-            else:
-                loc_fea = fea.transpose([0, 2, 1])
-                loc_fea = self.loc_fea_trans(loc_fea)
-                loc_fea = loc_fea.transpose([0, 2, 1])
-                loc_concat = paddle.concat([output, loc_fea], axis=2)
-                loc_preds = self.loc_generator(loc_concat)
-                loc_preds = F.sigmoid(loc_preds)
+            loc_fea = fea.transpose([0, 2, 1])
+            loc_fea = self.loc_fea_trans(loc_fea)
+            loc_fea = loc_fea.transpose([0, 2, 1])
+            loc_concat = paddle.concat([output, loc_fea], axis=2)
+            loc_preds = self.loc_generator(loc_concat)
+            loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
...
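The inference branch gets the same accumulator rewrite, the manual `while i < max_text_length + 1` counter becomes a `for` loop, and the `loc_type` branching collapses into the single concat-based location head. Taken together, the decode loop no longer relies on list appends or a hand-rolled while counter, which is what lets Paddle's dynamic-to-static tracing, and therefore paddle2onnx, handle the TableRec-RARE head. A self-contained sketch of that export-friendly pattern follows; it uses a stand-in layer, not the PaddleOCR model, and assumes a recent Paddle with the paddle2onnx package installed so `paddle.onnx.export` can emit the ONNX file:

import paddle
import paddle.nn as nn


class FixedStepDecoder(nn.Layer):
    # Stand-in for the reworked decode loop: fixed number of steps, pre-allocated output.
    def __init__(self, hidden_size=8, steps=5):
        super().__init__()
        self.hidden_size = hidden_size
        self.steps = steps
        self.cell = nn.GRUCell(hidden_size, hidden_size)

    def forward(self, x):  # x: (batch, hidden_size)
        hidden = paddle.zeros([x.shape[0], self.hidden_size])
        outs = paddle.zeros([x.shape[0], self.steps, self.hidden_size])
        for i in range(self.steps):      # plain for loop, no while with a tensor bound
            step_out, hidden = self.cell(x, hidden)
            outs[:, i, :] = step_out     # slice assignment instead of list append
        return outs


paddle.onnx.export(
    FixedStepDecoder(),
    'fixed_step_decoder',  # illustrative prefix; writes fixed_step_decoder.onnx
    input_spec=[paddle.static.InputSpec(shape=[1, 8], dtype='float32')],
    opset_version=11)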