提交 65d3dfc7 编写于 作者: W WenmuZhou

rnn支持导出

上级 2f9f258f
...@@ -28,8 +28,9 @@ class Im2Seq(nn.Layer): ...@@ -28,8 +28,9 @@ class Im2Seq(nn.Layer):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
x = x.reshape((B, -1, W)) assert H == 1
x = x.transpose((0, 2, 1)) # (NTC)(batch, width, channels) x = x.squeeze(axis=2)
x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
return x return x
...@@ -76,7 +77,8 @@ class SequenceEncoder(nn.Layer): ...@@ -76,7 +77,8 @@ class SequenceEncoder(nn.Layer):
'fc': EncoderWithFC, 'fc': EncoderWithFC,
'rnn': EncoderWithRNN 'rnn': EncoderWithRNN
} }
assert encoder_type in support_encoder_dict, '{} must in {}'.format(encoder_type, support_encoder_dict.keys()) assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys())
self.encoder = support_encoder_dict[encoder_type]( self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size) self.encoder_reshape.out_channels, hidden_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册