From 8e19d1ba8ccca926e7bf18055ef3f1ee2b62ff13 Mon Sep 17 00:00:00 2001 From: wawltor Date: Fri, 24 Sep 2021 10:06:17 +0800 Subject: [PATCH] add the shape check for the matmul (#35791) * add the shape check for the matmul * remove the test case for the linear --- paddle/fluid/operators/matmul_v2_op.h | 8 ++++++++ python/paddle/fluid/tests/unittests/test_linear.py | 9 --------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 58e57c3914..dd9940db29 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -380,6 +380,14 @@ class MatMulV2Kernel : public framework::OpKernel { auto* Out = ctx.Output("Out"); bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); + PADDLE_ENFORCE_NE(framework::product(X->dims()), 0, + platform::errors::InvalidArgument( + "The Input(X) dims size must not be equal 0," + " but reviced dims size is 0. ")); + PADDLE_ENFORCE_NE(framework::product(Y->dims()), 0, + platform::errors::InvalidArgument( + "The Input(Y) dims size must not be equal 0," + " but reviced dims size is 0. ")); MatMulFunction(X, Y, Out, trans_x, trans_y, ctx); } }; diff --git a/python/paddle/fluid/tests/unittests/test_linear.py b/python/paddle/fluid/tests/unittests/test_linear.py index 59f38d7cad..9d07a80da1 100644 --- a/python/paddle/fluid/tests/unittests/test_linear.py +++ b/python/paddle/fluid/tests/unittests/test_linear.py @@ -73,15 +73,6 @@ class LinearTestCase(unittest.TestCase): np.testing.assert_array_almost_equal(res_f, res_nn) np.testing.assert_array_almost_equal(res_nn, res_np) - def test_error_dummy_input(self, place=paddle.CPUPlace()): - with self.assertRaises(RuntimeError): - x_arr = np.array([], dtype=np.float32) - x = paddle.to_tensor( - np.reshape(x_arr, (0, 4, 4, 4)), dtype='float32') - weight = paddle.zeros([4, 4, 4], dtype='float32') - bias = paddle.to_tensor([], dtype='float32') - paddle.nn.functional.linear(x, weight, bias=bias) - if __name__ == "__main__": unittest.main() -- GitLab