diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index 810c2c8d3b28fd19c551fdba4efc335637e57617..de87b3d9895168657f8c9722177c026b992c2966 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -28,8 +28,9 @@ class Im2Seq(nn.Layer): def forward(self, x): B, C, H, W = x.shape - x = x.reshape((B, -1, W)) - x = x.transpose((0, 2, 1)) # (NTC)(batch, width, channels) + assert H == 1 + x = x.squeeze(axis=2) + x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels) return x @@ -76,7 +77,8 @@ class SequenceEncoder(nn.Layer): 'fc': EncoderWithFC, 'rnn': EncoderWithRNN } - assert encoder_type in support_encoder_dict, '{} must in {}'.format(encoder_type, support_encoder_dict.keys()) + assert encoder_type in support_encoder_dict, '{} must in {}'.format( + encoder_type, support_encoder_dict.keys()) self.encoder = support_encoder_dict[encoder_type]( self.encoder_reshape.out_channels, hidden_size)