未验证 提交 ff2142f2 编写于 作者: Z zxcd 提交者: GitHub

add int32/int64 for outer/matmul Kernel. (#55584)

* add int32/int64 for outer/matmul Kernel.

* fix by comment.

* fix by comment
上级 bd73a57d
......@@ -25,6 +25,8 @@ PD_REGISTER_KERNEL(matmul,
phi::MatmulKernel,
float,
double,
int32_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......
......@@ -25,6 +25,8 @@ PD_REGISTER_KERNEL(matmul,
phi::MatmulKernel,
float,
double,
int32_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
......
......@@ -97,7 +97,8 @@ static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx,
}
template <typename Context, typename T>
void MatMul(const Context& dev_ctx,
typename std::enable_if<!std::is_integral<T>::value>::type MatMul(
const Context& dev_ctx,
const DenseTensor& a,
bool trans_a,
const DenseTensor& b,
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/autotune/cache_base.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
......@@ -1078,6 +1079,38 @@ void MatMulInt8Function(const Context& ctx,
#endif
}
template <typename Context, typename T>
typename std::enable_if<std::is_integral<T>::value>::type
MatmulJudgeDtypeKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* out,
bool transpose_x,
bool transpose_y) {
auto x_tmp = phi::Cast<T, Context>(ctx, x, phi::DataType::FLOAT32);
auto y_tmp = phi::Cast<T, Context>(ctx, y, phi::DataType::FLOAT32);
DenseTensor out_tmp;
MatMulFunction<Context, float>(
ctx, x_tmp, y_tmp, x_dims, y_dims, &out_tmp, transpose_x, transpose_y);
phi::CastKernel<float>(ctx, out_tmp, x.dtype(), out);
}
template <typename Context, typename T>
typename std::enable_if<!std::is_integral<T>::value>::type
MatmulJudgeDtypeKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims,
DenseTensor* out,
bool transpose_x,
bool transpose_y) {
MatMulFunction<Context, T>(
ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
}
template <typename T, typename Context>
void MatmulKernel(const Context& ctx,
const DenseTensor& x,
......@@ -1097,7 +1130,7 @@ void MatmulKernel(const Context& ctx,
" but reviced dims size is 0. "));
const std::vector<std::int64_t> x_dims = vectorize(x.dims());
const std::vector<std::int64_t> y_dims = vectorize(y.dims());
MatMulFunction<Context, T>(
MatmulJudgeDtypeKernel<Context, T>(
ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
}
......
......@@ -2315,7 +2315,10 @@ def outer(x, y, name=None):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'inner'
val,
name,
['float16', 'float32', 'float64', 'int32', 'int64'],
'outer',
)
__check_input(nx, ny)
......
......@@ -712,6 +712,110 @@ class TestMatMulTypePromotion(TestComplexMatMulOp):
self.out = np.dot(self.x, self.y)
class TestInt32MatmulOp(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.python_api = paddle.tensor.matmul
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.int32
def init_input_output(self):
self.x = np.random.random((10, 10)).astype(self.dtype)
self.y = np.random.random((10, 10)).astype(self.dtype)
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.check_output(check_cinn=False)
class TestInt32MatMulOpBroadcast(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.python_api = paddle.tensor.matmul
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.int32
def init_input_output(self):
self.x = np.random.random((10, 2, 5)).astype(self.dtype)
self.y = np.random.random((5, 20)).astype(self.dtype)
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.check_output(check_cinn=False)
class TestInt64MatmulOp(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.python_api = paddle.tensor.matmul
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.int64
def init_input_output(self):
self.x = np.random.random((10, 10)).astype(self.dtype)
self.y = np.random.random((10, 10)).astype(self.dtype)
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.check_output(check_cinn=False)
class TestInt64MatMulOpBroadcast(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.python_api = paddle.tensor.matmul
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.int64
def init_input_output(self):
self.x = np.random.random((10, 2, 5)).astype(self.dtype)
self.y = np.random.random((5, 20)).astype(self.dtype)
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.check_output(check_cinn=False)
class TestMatmulop(unittest.TestCase):
def func_dygraph_matmul(self):
paddle.disable_static()
......
......@@ -74,6 +74,18 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_static_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test static computation graph: 1-d int32 array
x_data = np.random.rand(50).astype(np.int32)
y_data = np.random.rand(50).astype(np.int32)
res = self._run_static_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test static computation graph: 1-d int64 array
x_data = np.random.rand(50).astype(np.int64)
y_data = np.random.rand(50).astype(np.int64)
res = self._run_static_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test dynamic computation graph: 3-d array
x_data = np.random.rand(5, 10, 10).astype(np.float64)
y_data = np.random.rand(2, 10).astype(np.float64)
......@@ -112,6 +124,18 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_dynamic_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test dynamic computation graph: 3-d int32 array
x_data = np.random.rand(5, 10, 10).astype(np.int32)
y_data = np.random.rand(2, 10).astype(np.int32)
res = self._run_dynamic_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
# test dynamic computation graph: 3-d int64 array
x_data = np.random.rand(5, 10, 10).astype(np.int64)
y_data = np.random.rand(2, 10).astype(np.int64)
res = self._run_dynamic_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
class TestMultiplyError(unittest.TestCase):
def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册