diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index 004ec5641c2c4c33037e5a9af38ad255331a67f5..582be7c43683cc00af825994d3c05b3ad79d0882 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -21,18 +21,6 @@ from paddle import nn from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr -class EncoderWithReshape(nn.Layer): - def __init__(self, in_channels, **kwargs): - super().__init__() - self.out_channels = in_channels - - def forward(self, x): - B, C, H, W = x.shape - x = x.reshape((B, C, -1)) - x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels) - return x - - class Im2Seq(nn.Layer): def __init__(self, in_channels, **kwargs): super().__init__() @@ -40,9 +28,8 @@ class Im2Seq(nn.Layer): def forward(self, x): B, C, H, W = x.shape - assert H == 1 - x = x.transpose((0, 2, 3, 1)) - x = x.reshape((-1, C)) + x = x.reshape((B, -1, W)) + x = x.transpose((0, 2, 1)) # (NTC)(batch, width, channels) return x @@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer): def __init__(self, in_channels, hidden_size): super(EncoderWithRNN, self).__init__() self.out_channels = hidden_size * 2 - # self.lstm1_fw = nn.LSTMCell( - # in_channels, - # hidden_size, - # weight_ih_attr=ParamAttr(name='lstm_st1_fc1_w'), - # bias_ih_attr=ParamAttr(name='lstm_st1_fc1_b'), - # weight_hh_attr=ParamAttr(name='lstm_st1_out1_w'), - # bias_hh_attr=ParamAttr(name='lstm_st1_out1_b'), - # ) - # self.lstm1_bw = nn.LSTMCell( - # in_channels, - # hidden_size, - # weight_ih_attr=ParamAttr(name='lstm_st1_fc2_w'), - # bias_ih_attr=ParamAttr(name='lstm_st1_fc2_b'), - # weight_hh_attr=ParamAttr(name='lstm_st1_out2_w'), - # bias_hh_attr=ParamAttr(name='lstm_st1_out2_b'), - # ) - # self.lstm2_fw = nn.LSTMCell( - # hidden_size, - # hidden_size, - # weight_ih_attr=ParamAttr(name='lstm_st2_fc1_w'), - # bias_ih_attr=ParamAttr(name='lstm_st2_fc1_b'), - # weight_hh_attr=ParamAttr(name='lstm_st2_out1_w'), - # bias_hh_attr=ParamAttr(name='lstm_st2_out1_b'), - # ) - # self.lstm2_bw = nn.LSTMCell( - # hidden_size, - # hidden_size, - # weight_ih_attr=ParamAttr(name='lstm_st2_fc2_w'), - # bias_ih_attr=ParamAttr(name='lstm_st2_fc2_b'), - # weight_hh_attr=ParamAttr(name='lstm_st2_out2_w'), - # bias_hh_attr=ParamAttr(name='lstm_st2_out2_b'), - # ) self.lstm = nn.LSTM( in_channels, hidden_size, direction='bidirectional', num_layers=2) def forward(self, x): - # fw_x, _ = self.lstm1_fw(x) - # fw_x, _ = self.lstm2_fw(fw_x) - # - # # bw - # bw_x, _ = self.lstm1_bw(x) - # bw_x, _ = self.lstm2_bw(bw_x) - # x = paddle.concat([fw_x, bw_x], axis=2) x, _ = self.lstm(x) return x @@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer): class SequenceEncoder(nn.Layer): def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): super(SequenceEncoder, self).__init__() - self.encoder_reshape = EncoderWithReshape(in_channels) + self.encoder_reshape = Im2Seq(in_channels) self.out_channels = self.encoder_reshape.out_channels if encoder_type == 'reshape': self.only_reshape = True else: support_encoder_dict = { - 'reshape': EncoderWithReshape, + 'reshape': Im2Seq, 'fc': EncoderWithFC, 'rnn': EncoderWithRNN }