Patch summary (free-form text placed ahead of the first "diff --git" header;
it is not part of any hunk).

1. paddle/cinn/hlir/pe/transform.cc, hunk inside GetMatmulNewShapes:
   replaces `out_shape = {1};` with `out_shape = {};` in the branch that
   precedes the "vector * matrix" case, so that [m] * [m] yields an empty
   output shape, as the comment added by the hunk states.
2. test/cinn/ops/test_zero_dim_tensor.py: adds TestMatmulOp ahead of the
   `if __name__ == "__main__":` guard. The test feeds two length-10 float32
   vectors to paddle.matmul and to NetBuilder.matmul, asserts that the CINN
   result shape equals (), and then calls check_outputs_and_grads(). The
   class is skipped when is_compiled_with_cuda() is false.

NOTE(review): line breaks and leading indentation were reconstructed from a
whitespace-collapsed copy of this patch -- verify the context lines against
the target files before applying.
NOTE(review): template arguments look stripped in the C++ text (for example
`std::vector> GetMatmulNewShapes(` and `std::vector{...}`); presumably an
extraction artifact -- confirm against the upstream file.

diff --git a/paddle/cinn/hlir/pe/transform.cc b/paddle/cinn/hlir/pe/transform.cc
index 38e3ba3a39541e915270ea761d23a68126bbe79e..5c02b4a8493135b59390ffe3dddca4378db852c4 100644
--- a/paddle/cinn/hlir/pe/transform.cc
+++ b/paddle/cinn/hlir/pe/transform.cc
@@ -90,7 +90,8 @@ std::vector> GetMatmulNewShapes(
                           : std::vector{1, x_shape[0]};
     new_y_shape = trans_y ? std::vector{1, y_shape[0]}
                           : std::vector{y_shape[0], 1};
-    out_shape = {1};
+    // [m] * [m] -> [], which aligns with Paddle's matmul
+    out_shape = {};
   } else if (x_dim == 1) {
     // vector * matrix
     int y_K = trans_y ? y_shape[max_dim - 1] : y_shape[max_dim - 2];
diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py
index 4c36905b6b245ac970c32d7dcf831407c95d01d1..3ba7ac3bc7591b06414f653b05924161e3b15084 100644
--- a/test/cinn/ops/test_zero_dim_tensor.py
+++ b/test/cinn/ops/test_zero_dim_tensor.py
@@ -824,5 +824,50 @@ class TestSqueezeOp2D(TestSqueezeOp):
         self.target_shape = ()
 
 
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestMatmulOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, [10]).astype(self.dtype),
+            "y": np.random.randint(-10, 10, [10]).astype(self.dtype),
+        }
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        y = paddle.to_tensor(self.inputs["y"], stop_gradient=False)
+        out = paddle.matmul(x, y)
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("matmul_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        y = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["y"].shape, "y"
+        )
+        out = builder.matmul(x, y)
+
+        prog = builder.build()
+        res = self.get_cinn_output(
+            prog, target, [x, y], [self.inputs["x"], self.inputs["y"]], [out]
+        )
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
 if __name__ == "__main__":
     unittest.main()