提交 b8997058 编写于 作者: H Hui Zhang

size to shape

上级 2ba520d9
...@@ -128,7 +128,7 @@ class Conv2dSubsampling4(BaseSubsampling): ...@@ -128,7 +128,7 @@ class Conv2dSubsampling4(BaseSubsampling):
""" """
x = x.unsqueeze(1) # (b, c=1, t, f) x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x) x = self.conv(x)
b, c, t, f = paddle.shape(x) b, c, t, f = x.shape
x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset) x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
...@@ -181,7 +181,7 @@ class Conv2dSubsampling6(BaseSubsampling): ...@@ -181,7 +181,7 @@ class Conv2dSubsampling6(BaseSubsampling):
""" """
x = x.unsqueeze(1) # (b, c, t, f) x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x) x = self.conv(x)
b, c, t, f = paddle.shape(x) b, c, t, f = x.shape
x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
x, pos_emb = self.pos_enc(x, offset) x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3] return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册