未验证 提交 5c4adfae 编写于 作者: W wangxinxin08 提交者: GitHub

check weight shape of conv1d_transpose (#49417)

* check weight shape of conv1d_transpose

* add unittest case
上级 3e8cec85
......@@ -70,5 +70,17 @@ class TestFunctionalConv1DErrorCase1(TestFunctionalConv1DError):
self.data_format = "NCL"
class TestFunctionalConv1DErrorCase2(TestFunctionalConv1DError):
def setUp(self):
self.input = np.random.randn(1, 3, 3)
self.filter = np.random.randn(3)
self.bias = None
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.data_format = "NCL"
if __name__ == "__main__":
unittest.main()
......@@ -984,6 +984,13 @@ def conv1d_transpose(
)
)
if len(weight.shape) != 3:
raise ValueError(
'Input weight should be 3D tensor, but received weight with the shape of {}'.format(
weight.shape
)
)
op_type = 'conv2d_transpose'
num_filters = weight.shape[1]
if (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册