提交 ca9ea622 编写于 作者: W WenmuZhou

添加im2seq实现

上级 bdad0cef
...@@ -21,18 +21,6 @@ from paddle import nn ...@@ -21,18 +21,6 @@ from paddle import nn
from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
class EncoderWithReshape(nn.Layer):
    """Turn a 4-D feature map into a sequence of per-position feature vectors.

    The two trailing spatial axes are flattened into a single sequence axis,
    and the channel axis is moved last, so an input of shape (B, C, H, W)
    becomes (B, H * W, C). The layer holds no parameters.
    """

    def __init__(self, in_channels, **kwargs):
        """Record the channel count; extra keyword arguments are ignored."""
        super().__init__()
        # The channel dimension passes through unchanged.
        self.out_channels = in_channels

    def forward(self, x):
        """Return ``x`` rearranged from (B, C, H, W) to (B, H * W, C)."""
        # Unpacking four values also enforces that the input is 4-D.
        batch, channels, _height, _width = x.shape
        # Collapse both spatial axes into one sequence axis: (B, C, H * W).
        flattened = x.reshape((batch, channels, -1))
        # Swap to sequence-major layout: (B, H * W, C).
        return flattened.transpose([0, 2, 1])
class Im2Seq(nn.Layer): class Im2Seq(nn.Layer):
def __init__(self, in_channels, **kwargs): def __init__(self, in_channels, **kwargs):
super().__init__() super().__init__()
...@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer): ...@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
assert H == 1 x = x.reshape((B, -1, W))
x = x.transpose((0, 2, 3, 1)) x = x.transpose((0, 2, 1)) # (NTC)(batch, width, channels)
x = x.reshape((-1, C))
return x return x
...@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer): ...@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
def __init__(self, in_channels, hidden_size): def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN, self).__init__() super(EncoderWithRNN, self).__init__()
self.out_channels = hidden_size * 2 self.out_channels = hidden_size * 2
# self.lstm1_fw = nn.LSTMCell(
# in_channels,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st1_fc1_w'),
# bias_ih_attr=ParamAttr(name='lstm_st1_fc1_b'),
# weight_hh_attr=ParamAttr(name='lstm_st1_out1_w'),
# bias_hh_attr=ParamAttr(name='lstm_st1_out1_b'),
# )
# self.lstm1_bw = nn.LSTMCell(
# in_channels,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st1_fc2_w'),
# bias_ih_attr=ParamAttr(name='lstm_st1_fc2_b'),
# weight_hh_attr=ParamAttr(name='lstm_st1_out2_w'),
# bias_hh_attr=ParamAttr(name='lstm_st1_out2_b'),
# )
# self.lstm2_fw = nn.LSTMCell(
# hidden_size,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st2_fc1_w'),
# bias_ih_attr=ParamAttr(name='lstm_st2_fc1_b'),
# weight_hh_attr=ParamAttr(name='lstm_st2_out1_w'),
# bias_hh_attr=ParamAttr(name='lstm_st2_out1_b'),
# )
# self.lstm2_bw = nn.LSTMCell(
# hidden_size,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st2_fc2_w'),
# bias_ih_attr=ParamAttr(name='lstm_st2_fc2_b'),
# weight_hh_attr=ParamAttr(name='lstm_st2_out2_w'),
# bias_hh_attr=ParamAttr(name='lstm_st2_out2_b'),
# )
self.lstm = nn.LSTM( self.lstm = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2) in_channels, hidden_size, direction='bidirectional', num_layers=2)
def forward(self, x): def forward(self, x):
# fw_x, _ = self.lstm1_fw(x)
# fw_x, _ = self.lstm2_fw(fw_x)
#
# # bw
# bw_x, _ = self.lstm1_bw(x)
# bw_x, _ = self.lstm2_bw(bw_x)
# x = paddle.concat([fw_x, bw_x], axis=2)
x, _ = self.lstm(x) x, _ = self.lstm(x)
return x return x
...@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer): ...@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
class SequenceEncoder(nn.Layer): class SequenceEncoder(nn.Layer):
def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
super(SequenceEncoder, self).__init__() super(SequenceEncoder, self).__init__()
self.encoder_reshape = EncoderWithReshape(in_channels) self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels self.out_channels = self.encoder_reshape.out_channels
if encoder_type == 'reshape': if encoder_type == 'reshape':
self.only_reshape = True self.only_reshape = True
else: else:
support_encoder_dict = { support_encoder_dict = {
'reshape': EncoderWithReshape, 'reshape': Im2Seq,
'fc': EncoderWithFC, 'fc': EncoderWithFC,
'rnn': EncoderWithRNN 'rnn': EncoderWithRNN
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册