Unverified commit ff2142f2 authored by Z zxcd, committed by GitHub

add int32/int64 for outer/matmul Kernel. (#55584)

* add int32/int64 for outer/matmul Kernel.

* fix by comment.

* fix by comment
Parent bd73a57d
...@@ -25,6 +25,8 @@ PD_REGISTER_KERNEL(matmul,
phi::MatmulKernel,
float,
double,
int32_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......
...@@ -25,6 +25,8 @@ PD_REGISTER_KERNEL(matmul,
phi::MatmulKernel,
float,
double,
int32_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
......
...@@ -97,13 +97,14 @@ static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx,
}
template <typename Context, typename T>
-void MatMul(const Context& dev_ctx,
-            const DenseTensor& a,
-            bool trans_a,
-            const DenseTensor& b,
-            bool trans_b,
-            DenseTensor* out,
-            bool flag = false) {
+typename std::enable_if<!std::is_integral<T>::value>::type MatMul(
+    const Context& dev_ctx,
+    const DenseTensor& a,
+    bool trans_a,
+    const DenseTensor& b,
+    bool trans_b,
+    DenseTensor* out,
+    bool flag = false) {
dev_ctx.template Alloc<T>(out);
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
......
...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/autotune/cache_base.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
...@@ -1078,6 +1079,38 @@ void MatMulInt8Function(const Context& ctx, ...@@ -1078,6 +1079,38 @@ void MatMulInt8Function(const Context& ctx,
#endif #endif
} }
template <typename Context, typename T>
typename std::enable_if<std::is_integral<T>::value>::type
MatmulJudgeDtypeKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* out,
bool transpose_x,
bool transpose_y) {
auto x_tmp = phi::Cast<T, Context>(ctx, x, phi::DataType::FLOAT32);
auto y_tmp = phi::Cast<T, Context>(ctx, y, phi::DataType::FLOAT32);
DenseTensor out_tmp;
MatMulFunction<Context, float>(
ctx, x_tmp, y_tmp, x_dims, y_dims, &out_tmp, transpose_x, transpose_y);
phi::CastKernel<float>(ctx, out_tmp, x.dtype(), out);
}
template <typename Context, typename T>
typename std::enable_if<!std::is_integral<T>::value>::type
MatmulJudgeDtypeKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* out,
bool transpose_x,
bool transpose_y) {
MatMulFunction<Context, T>(
ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
}
template <typename T, typename Context>
void MatmulKernel(const Context& ctx,
const DenseTensor& x,
...@@ -1097,7 +1130,7 @@ void MatmulKernel(const Context& ctx,
" but reviced dims size is 0. "));
const std::vector<std::int64_t> x_dims = vectorize(x.dims());
const std::vector<std::int64_t> y_dims = vectorize(y.dims());
-  MatMulFunction<Context, T>(
+  MatmulJudgeDtypeKernel<Context, T>(
ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
}
......
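With this change, integer inputs take a cast-based path: MatmulJudgeDtypeKernel is selected via std::is_integral, casts int32/int64 operands to float32, runs the existing floating-point MatMulFunction, and casts the result back to the input dtype. The lines below are a minimal usage sketch, not part of this commit, assuming a Paddle build that includes the change; because the product is computed in float32, integer values outside float32's exactly representable range (roughly beyond 2^24) may lose precision.

import numpy as np
import paddle

# Hypothetical check, assuming a build with int32/int64 matmul support.
x_np = np.arange(6, dtype=np.int64).reshape(2, 3)
y_np = np.arange(12, dtype=np.int64).reshape(3, 4)
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)

out = paddle.matmul(x, y)  # internally: cast to float32 -> matmul -> cast back
print(out.dtype)           # expected: paddle.int64, matching the input dtype
np.testing.assert_array_equal(out.numpy(), x_np @ y_np)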
...@@ -2315,7 +2315,10 @@ def outer(x, y, name=None):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
-    val, name, ['float16', 'float32', 'float64'], 'inner'
+    val,
+    name,
+    ['float16', 'float32', 'float64', 'int32', 'int64'],
+    'outer',
)
__check_input(nx, ny)
......
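On the Python side, the static-graph dtype check for outer now accepts int32 and int64, and the op name passed to check_variable_and_dtype is corrected from 'inner' to 'outer'. A minimal usage sketch, not part of this commit, assuming a build with the change:

import paddle

x = paddle.to_tensor([1, 2, 3], dtype='int32')
y = paddle.to_tensor([4, 5], dtype='int32')

# Outer product of integer tensors; the dtype check previously rejected ints.
out = paddle.outer(x, y)
print(out.shape, out.dtype)  # expected: [3, 2] paddle.int32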
...@@ -712,6 +712,110 @@ class TestMatMulTypePromotion(TestComplexMatMulOp):
self.out = np.dot(self.x, self.y)
class TestInt32MatmulOp(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.python_api = paddle.tensor.matmul
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.int32
def init_input_output(self):
self.x = np.random.random((10, 10)).astype(self.dtype)
self.y = np.random.random((10, 10)).astype(self.dtype)
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.check_output(check_cinn=False)
class TestInt32MatMulOpBroadcast(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.python_api = paddle.tensor.matmul
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.int32
def init_input_output(self):
self.x = np.random.random((10, 2, 5)).astype(self.dtype)
self.y = np.random.random((5, 20)).astype(self.dtype)
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.check_output(check_cinn=False)
class TestInt64MatmulOp(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.python_api = paddle.tensor.matmul
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.int64
def init_input_output(self):
self.x = np.random.random((10, 10)).astype(self.dtype)
self.y = np.random.random((10, 10)).astype(self.dtype)
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.check_output(check_cinn=False)
class TestInt64MatMulOpBroadcast(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.python_api = paddle.tensor.matmul
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.int64
def init_input_output(self):
self.x = np.random.random((10, 2, 5)).astype(self.dtype)
self.y = np.random.random((5, 20)).astype(self.dtype)
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.check_output(check_cinn=False)
class TestMatmulop(unittest.TestCase):
def func_dygraph_matmul(self):
paddle.disable_static()
......
...@@ -74,6 +74,18 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_static_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test static computation graph: 1-d int32 array
x_data = np.random.rand(50).astype(np.int32)
y_data = np.random.rand(50).astype(np.int32)
res = self._run_static_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test static computation graph: 1-d int64 array
x_data = np.random.rand(50).astype(np.int64)
y_data = np.random.rand(50).astype(np.int64)
res = self._run_static_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test dynamic computation graph: 3-d array
x_data = np.random.rand(5, 10, 10).astype(np.float64)
y_data = np.random.rand(2, 10).astype(np.float64)
...@@ -112,6 +124,18 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_dynamic_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test dynamic computation graph: 3-d int32 array
x_data = np.random.rand(5, 10, 10).astype(np.int32)
y_data = np.random.rand(2, 10).astype(np.int32)
res = self._run_dynamic_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test dynamic computation graph: 3-d int64 array
x_data = np.random.rand(5, 10, 10).astype(np.int64)
y_data = np.random.rand(2, 10).astype(np.int64)
res = self._run_dynamic_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
class TestMultiplyError(unittest.TestCase):
def test_errors(self):
......