未验证 提交 8e19d1ba 编写于 作者: W wawltor 提交者: GitHub

add the shape check for the matmul (#35791)

* add the shape check for the matmul

* remove the test case for the linear
上级 4e7bd9c3
......@@ -380,6 +380,14 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
auto* Out = ctx.Output<Tensor>("Out");
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
PADDLE_ENFORCE_NE(framework::product(X->dims()), 0,
platform::errors::InvalidArgument(
"The Input(X) dims size must not be equal 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_NE(framework::product(Y->dims()), 0,
platform::errors::InvalidArgument(
"The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
MatMulFunction<DeviceContext, T>(X, Y, Out, trans_x, trans_y, ctx);
}
};
......
......@@ -73,15 +73,6 @@ class LinearTestCase(unittest.TestCase):
np.testing.assert_array_almost_equal(res_f, res_nn)
np.testing.assert_array_almost_equal(res_nn, res_np)
def test_error_dummy_input(self, place=paddle.CPUPlace()):
with self.assertRaises(RuntimeError):
x_arr = np.array([], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(x_arr, (0, 4, 4, 4)), dtype='float32')
weight = paddle.zeros([4, 4, 4], dtype='float32')
bias = paddle.to_tensor([], dtype='float32')
paddle.nn.functional.linear(x, weight, bias=bias)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册