未验证 提交 7a0fdeb9 编写于 作者: 张春乔 提交者: GitHub

fix div 0 error in conv1/2/3 (#49999)

上级 36e5de81
...@@ -467,6 +467,13 @@ void ConvInferMeta(const MetaTensor& input, ...@@ -467,6 +467,13 @@ void ConvInferMeta(const MetaTensor& input,
const bool channel_last = (config.is_run_mkldnn_kernel == false) && const bool channel_last = (config.is_run_mkldnn_kernel == false) &&
(data_format == "NHWC" || data_format == "NDHWC"); (data_format == "NHWC" || data_format == "NDHWC");
for (int i = 0; i < 2; ++i) {
PADDLE_ENFORCE_NE(in_dims[i],
0,
phi::errors::InvalidArgument(
"The size of Op(Conv) inputs should not be 0."));
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5, in_dims.size() == 4 || in_dims.size() == 5,
true, true,
......
...@@ -70,5 +70,17 @@ class TestFunctionalConv1DErrorCase1(TestFunctionalConv1DError): ...@@ -70,5 +70,17 @@ class TestFunctionalConv1DErrorCase1(TestFunctionalConv1DError):
self.data_format = "NCL" self.data_format = "NCL"
class TestFunctionalConv1DErrorCase2(TestFunctionalConv1DError):
def setUp(self):
self.input = np.random.randn(0, 0, 0)
self.filter = np.random.randn(1, 0, 0)
self.bias = None
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.data_format = "NCL"
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -569,6 +569,20 @@ class TestFunctionalConv2DErrorCase13(TestFunctionalConv2DErrorCase12): ...@@ -569,6 +569,20 @@ class TestFunctionalConv2DErrorCase13(TestFunctionalConv2DErrorCase12):
self.data_format = "NCHW" self.data_format = "NCHW"
class TestFunctionalConv2DErrorCase14(TestFunctionalConv2DErrorCase12):
def setUp(self):
self.input = np.random.randn(0, 0, 0, 0)
self.filter = np.random.randn(1, 0, 0, 0)
self.num_filters = 0
self.filter_size = 0
self.bias = None
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.data_format = "NCHW"
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -544,6 +544,20 @@ class TestFunctionalConv3DErrorCase12(TestFunctionalConv3DErrorCase11): ...@@ -544,6 +544,20 @@ class TestFunctionalConv3DErrorCase12(TestFunctionalConv3DErrorCase11):
self.data_format = "NCDHW" self.data_format = "NCDHW"
class TestFunctionalConv3DErrorCase13(TestFunctionalConv3DErrorCase11):
def setUp(self):
self.input = np.random.randn(0, 0, 0, 0, 0)
self.filter = np.random.randn(1, 0, 0, 0, 0)
self.num_filters = 1
self.filter_size = 1
self.bias = None
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.data_format = "NCDHW"
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册