From ff2142f25811d2e83721033ce83a8347a799b709 Mon Sep 17 00:00:00 2001
From: zxcd <228587199@qq.com>
Date: Thu, 27 Jul 2023 18:13:47 +0800
Subject: [PATCH] add int32/int64 for outer/matmul Kernel. (#55584)

* add int32/int64 for outer/matmul Kernel.

* fix by comment.

* fix by comment
---
 paddle/phi/kernels/cpu/matmul_kernel.cc       |   2 +
 paddle/phi/kernels/gpu/matmul_kernel.cu       |   2 +
 .../kernels/impl/matmul_grad_kernel_impl.h    |  15 +-
 paddle/phi/kernels/impl/matmul_kernel_impl.h  |  35 +++++-
 python/paddle/tensor/math.py                  |   5 +-
 test/legacy_test/test_matmul_v2_op.py         | 104 ++++++++++++++++++
 test/legacy_test/test_outer.py                |  24 ++++
 7 files changed, 178 insertions(+), 9 deletions(-)

diff --git a/paddle/phi/kernels/cpu/matmul_kernel.cc b/paddle/phi/kernels/cpu/matmul_kernel.cc
index c75a50130db..af5d11839e0 100644
--- a/paddle/phi/kernels/cpu/matmul_kernel.cc
+++ b/paddle/phi/kernels/cpu/matmul_kernel.cc
@@ -25,6 +25,8 @@ PD_REGISTER_KERNEL(matmul,
                    phi::MatmulKernel,
                    float,
                    double,
+                   int32_t,
+                   int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 
diff --git a/paddle/phi/kernels/gpu/matmul_kernel.cu b/paddle/phi/kernels/gpu/matmul_kernel.cu
index c5271a4eeec..71095bf783b 100644
--- a/paddle/phi/kernels/gpu/matmul_kernel.cu
+++ b/paddle/phi/kernels/gpu/matmul_kernel.cu
@@ -25,6 +25,8 @@ PD_REGISTER_KERNEL(matmul,
                    phi::MatmulKernel,
                    float,
                    double,
+                   int32_t,
+                   int64_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
diff --git a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
index 885827a36be..899ee5f3a49 100644
--- a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
@@ -97,13 +97,14 @@ static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx,
 }
 
 template <typename Context, typename T>
-void MatMul(const Context& dev_ctx,
-            const DenseTensor& a,
-            bool trans_a,
-            const DenseTensor& b,
-            bool trans_b,
-            DenseTensor* out,
-            bool flag = false) {
+typename std::enable_if<!std::is_integral<T>::value>::type MatMul(
+    const Context& dev_ctx,
+    const DenseTensor& a,
+    bool trans_a,
+    const DenseTensor& b,
+    bool trans_b,
+    DenseTensor* out,
+    bool flag = false) {
   dev_ctx.template Alloc<T>(out);
   auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
   auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h
index a77fbd96131..e680e164e62 100644
--- a/paddle/phi/kernels/impl/matmul_kernel_impl.h
+++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/autotune/cache_base.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
@@ -1078,6 +1079,38 @@ void MatMulInt8Function(const Context& ctx,
 #endif
 }
 
+template <typename Context, typename T>
+typename std::enable_if<std::is_integral<T>::value>::type
+MatmulJudgeDtypeKernel(const Context& ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const std::vector<std::int64_t>& x_dims,
+                       const std::vector<std::int64_t>& y_dims,
+                       DenseTensor* out,
+                       bool transpose_x,
+                       bool transpose_y) {
+  auto x_tmp = phi::Cast<T, Context>(ctx, x, phi::DataType::FLOAT32);
+  auto y_tmp = phi::Cast<T, Context>(ctx, y, phi::DataType::FLOAT32);
+  DenseTensor out_tmp;
+  MatMulFunction<Context, float>(
+      ctx, x_tmp, y_tmp, x_dims, y_dims, &out_tmp, transpose_x, transpose_y);
+  phi::CastKernel<float, Context>(ctx, out_tmp, x.dtype(), out);
+}
+
+template <typename Context, typename T>
+typename std::enable_if<!std::is_integral<T>::value>::type
+MatmulJudgeDtypeKernel(const Context& ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const std::vector<std::int64_t>& x_dims,
+                       const std::vector<std::int64_t>& y_dims,
+                       DenseTensor* out,
+                       bool transpose_x,
+                       bool transpose_y) {
+  MatMulFunction<Context, T>(
+      ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
+}
+
 template <typename Context, typename T>
 void MatmulKernel(const Context& ctx,
                   const DenseTensor& x,
@@ -1097,7 +1130,7 @@ void MatmulKernel(const Context& ctx,
                           " but reviced dims size is 0. "));
   const std::vector<std::int64_t> x_dims = vectorize(x.dims());
   const std::vector<std::int64_t> y_dims = vectorize(y.dims());
-  MatMulFunction<Context, T>(
+  MatmulJudgeDtypeKernel<Context, T>(
       ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
 }
 
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 157349e8ff0..2dcfcb5b08f 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -2315,7 +2315,10 @@ def outer(x, y, name=None):
             var_names = {'x': x, 'y': y}
             for name, val in var_names.items():
                 check_variable_and_dtype(
-                    val, name, ['float16', 'float32', 'float64'], 'inner'
+                    val,
+                    name,
+                    ['float16', 'float32', 'float64', 'int32', 'int64'],
+                    'outer',
                 )
 
         __check_input(nx, ny)
diff --git a/test/legacy_test/test_matmul_v2_op.py b/test/legacy_test/test_matmul_v2_op.py
index 7869042f507..fa939fd29d4 100644
--- a/test/legacy_test/test_matmul_v2_op.py
+++ b/test/legacy_test/test_matmul_v2_op.py
@@ -712,6 +712,110 @@ class TestMatMulTypePromotion(TestComplexMatMulOp):
         self.out = np.dot(self.x, self.y)
 
 
+class TestInt32MatmulOp(OpTest):
+    def setUp(self):
+        self.op_type = "matmul_v2"
+        self.python_api = paddle.tensor.matmul
+        self.init_base_dtype()
+        self.init_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.int32
+
+    def init_input_output(self):
+        self.x = np.random.random((10, 10)).astype(self.dtype)
+        self.y = np.random.random((10, 10)).astype(self.dtype)
+        self.out = np.matmul(self.x, self.y)
+
+    def test_check_output(self):
+        self.check_output(check_cinn=False)
+
+
+class TestInt32MatMulOpBroadcast(OpTest):
+    def setUp(self):
+        self.op_type = "matmul_v2"
+        self.python_api = paddle.tensor.matmul
+        self.init_base_dtype()
+        self.init_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.int32
+
+    def init_input_output(self):
+        self.x = np.random.random((10, 2, 5)).astype(self.dtype)
+        self.y = np.random.random((5, 20)).astype(self.dtype)
+        self.out = np.matmul(self.x, self.y)
+
+    def test_check_output(self):
+        self.check_output(check_cinn=False)
+
+
+class TestInt64MatmulOp(OpTest):
+    def setUp(self):
+        self.op_type = "matmul_v2"
+        self.python_api = paddle.tensor.matmul
+        self.init_base_dtype()
+        self.init_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.int64
+
+    def init_input_output(self):
+        self.x = np.random.random((10, 10)).astype(self.dtype)
+        self.y = np.random.random((10, 10)).astype(self.dtype)
+        self.out = np.matmul(self.x, self.y)
+
+    def test_check_output(self):
+        self.check_output(check_cinn=False)
+
+
+class TestInt64MatMulOpBroadcast(OpTest):
+    def setUp(self):
+        self.op_type = "matmul_v2"
+        self.python_api = paddle.tensor.matmul
+        self.init_base_dtype()
+        self.init_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.int64
+
+    def init_input_output(self):
+        self.x = np.random.random((10, 2, 5)).astype(self.dtype)
+        self.y = np.random.random((5, 20)).astype(self.dtype)
+        self.out = np.matmul(self.x, self.y)
+
+    def test_check_output(self):
+        self.check_output(check_cinn=False)
+
+
 class TestMatmulop(unittest.TestCase):
     def func_dygraph_matmul(self):
         paddle.disable_static()
diff --git a/test/legacy_test/test_outer.py b/test/legacy_test/test_outer.py
index 3bbe20b7b5b..5ce564509d4 100644
--- a/test/legacy_test/test_outer.py
+++ b/test/legacy_test/test_outer.py
@@ -74,6 +74,18 @@ class TestMultiplyApi(unittest.TestCase):
         res = self._run_static_graph_case(x_data, y_data)
         np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
 
+        # test static computation graph: 1-d int32 array
+        x_data = np.random.rand(50).astype(np.int32)
+        y_data = np.random.rand(50).astype(np.int32)
+        res = self._run_static_graph_case(x_data, y_data)
+        np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
+
+        # test static computation graph: 1-d int64 array
+        x_data = np.random.rand(50).astype(np.int64)
+        y_data = np.random.rand(50).astype(np.int64)
+        res = self._run_static_graph_case(x_data, y_data)
+        np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
+
         # test dynamic computation graph: 3-d array
         x_data = np.random.rand(5, 10, 10).astype(np.float64)
         y_data = np.random.rand(2, 10).astype(np.float64)
@@ -112,6 +124,18 @@ class TestMultiplyApi(unittest.TestCase):
         res = self._run_dynamic_graph_case(x_data, y_data)
         np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
 
+        # test dynamic computation graph: 3-d int32 array
+        x_data = np.random.rand(5, 10, 10).astype(np.int32)
+        y_data = np.random.rand(2, 10).astype(np.int32)
+        res = self._run_dynamic_graph_case(x_data, y_data)
+        np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
+
+        # test dynamic computation graph: 3-d int64 array
+        x_data = np.random.rand(5, 10, 10).astype(np.int64)
+        y_data = np.random.rand(2, 10).astype(np.int64)
+        res = self._run_dynamic_graph_case(x_data, y_data)
+        np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
+
 
 class TestMultiplyError(unittest.TestCase):
     def test_errors(self):
--
GitLab
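
Usage note (not part of the patch): a minimal sketch of what this change enables. It assumes a Paddle build that includes the patch; the tensor values and shapes below are illustrative.

# Sketch: int32/int64 inputs are now accepted by paddle.matmul and paddle.outer.
import numpy as np
import paddle

# int64 matmul: previously rejected by the dtype check, now dispatched through
# MatmulJudgeDtypeKernel (inputs are cast to float32, multiplied, and the
# result is cast back to the input dtype).
a = paddle.to_tensor(np.arange(6).reshape(2, 3), dtype='int64')
b = paddle.to_tensor(np.arange(12).reshape(3, 4), dtype='int64')
c = paddle.matmul(a, b)   # dtype stays paddle.int64

# int32 outer product: outer() now lists 'int32'/'int64' as allowed dtypes.
x = paddle.to_tensor([1, 2, 3], dtype='int32')
y = paddle.to_tensor([4, 5], dtype='int32')
z = paddle.outer(x, y)    # shape [3, 2], dtype paddle.int32

print(c.numpy(), z.numpy())

Because the integer path computes in float32 internally, very large int64 values can lose precision; the kernels are intended for ordinary integer workloads rather than exact big-integer arithmetic.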