未验证 提交 1755a154 编写于 作者: 张春乔 提交者: GitHub

fix div 0 error in conv1_transpose (#50000)

上级 52672ea5
...@@ -37,6 +37,11 @@ struct ConcatFunctor<phi::CPUContext, T> { ...@@ -37,6 +37,11 @@ struct ConcatFunctor<phi::CPUContext, T> {
} }
int64_t out_rows = rows, out_cols = 0; int64_t out_rows = rows, out_cols = 0;
PADDLE_ENFORCE_NE(
rows,
0,
phi::errors::InvalidArgument("The input size should not be 0."));
std::vector<int64_t> input_cols(input.size()); std::vector<int64_t> input_cols(input.size());
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
int64_t t_cols = input[i].numel() / rows; int64_t t_cols = input[i].numel() / rows;
......
...@@ -82,5 +82,17 @@ class TestFunctionalConv1DErrorCase2(TestFunctionalConv1DError): ...@@ -82,5 +82,17 @@ class TestFunctionalConv1DErrorCase2(TestFunctionalConv1DError):
self.data_format = "NCL" self.data_format = "NCL"
class TestFunctionalConv1DErrorCase3(TestFunctionalConv1DError):
def setUp(self):
self.input = np.random.randn(6, 0, 6)
self.filter = np.random.randn(6, 0, 0)
self.bias = None
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.data_format = "NCL"
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册