From 65d3dfc729b820ec04c71084a7573f20698f2a14 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Tue, 10 Nov 2020 17:18:50 +0800 Subject: [PATCH] =?UTF-8?q?rnn=E6=94=AF=E6=8C=81=E5=AF=BC=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/modeling/necks/rnn.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index 810c2c8d..de87b3d9 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -28,8 +28,9 @@ class Im2Seq(nn.Layer): def forward(self, x): B, C, H, W = x.shape - x = x.reshape((B, -1, W)) - x = x.transpose((0, 2, 1)) # (NTC)(batch, width, channels) + assert H == 1 + x = x.squeeze(axis=2) + x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels) return x @@ -76,7 +77,8 @@ class SequenceEncoder(nn.Layer): 'fc': EncoderWithFC, 'rnn': EncoderWithRNN } - assert encoder_type in support_encoder_dict, '{} must in {}'.format(encoder_type, support_encoder_dict.keys()) + assert encoder_type in support_encoder_dict, '{} must in {}'.format( + encoder_type, support_encoder_dict.keys()) self.encoder = support_encoder_dict[encoder_type]( self.encoder_reshape.out_channels, hidden_size) -- GitLab