Unverified commit 7bf7e6e0, authored by zhangyikun02, committed by GitHub

optimize for argsort with xpu, test=kunlun (#48440)

Parent: 7d6263e6
@@ -20,6 +20,149 @@

namespace phi {
template <typename T, typename TID>
static inline void xpu_argsort(xpu::Context* ctx,
                               const T* input_data,
                               T* output_data,
                               TID* indices_data,
                               int m,
                               int n,
                               bool descending) {
  int ret =
      xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending);
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sort");
}

template <typename T>
static inline void xpu_transpose(xpu::Context* ctx,
                                 const T* x,
                                 T* y,
                                 const std::vector<int>& xshape,
                                 const std::vector<int>& permute) {
  int ret = xpu::transpose(ctx, x, y, xshape, permute);
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "transpose");
}

template <typename TX, typename TY>
static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) {
  int ret = xpu::cast(ctx, x, y, len);
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
}
template <typename T,
          bool VALUE_NEED_CAST = false,
          bool INDEX_NEED_CAST = false>
struct XPUArgsort {
  void operator()(xpu::Context* ctx,
                  const T* input_data,
                  T* output_data,
                  int64_t* indices_data,
                  const std::vector<int>& data_shape,
                  const std::vector<int>& permute,
                  bool descending) {
    xpu::ctx_guard RAII_GUARD(ctx);
    int m = data_shape[0] * data_shape[2];
    int n = data_shape[1];
    int len = data_shape[0] * data_shape[1] * data_shape[2];
    std::vector<int> trans_data_shape{
        data_shape[0], data_shape[2], data_shape[1]};

    T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
    T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
    int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);

    xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
    xpu_argsort(ctx,
                input_data_trans,
                output_data_trans,
                indices_data_trans,
                m,
                n,
                descending);
    xpu_transpose(
        ctx, output_data_trans, output_data, trans_data_shape, permute);
    xpu_transpose(
        ctx, indices_data_trans, indices_data, trans_data_shape, permute);
  }
};
template <typename T>
struct XPUArgsort<T, false, true> {
  void operator()(xpu::Context* ctx,
                  const T* input_data,
                  T* output_data,
                  int64_t* indices_data,
                  const std::vector<int>& data_shape,
                  const std::vector<int>& permute,
                  bool descending) {
    xpu::ctx_guard RAII_GUARD(ctx);
    int m = data_shape[0] * data_shape[2];
    int n = data_shape[1];
    int len = data_shape[0] * data_shape[1] * data_shape[2];
    std::vector<int> trans_data_shape{
        data_shape[0], data_shape[2], data_shape[1]};

    T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
    T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
    int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
    int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);

    xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
    xpu_argsort(ctx,
                input_data_trans,
                output_data_trans,
                indices_data_trans,
                m,
                n,
                descending);
    xpu_transpose(
        ctx, output_data_trans, output_data, trans_data_shape, permute);
    xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
    xpu_transpose(
        ctx, cast_data_int64, indices_data, trans_data_shape, permute);
  }
};
template <>
struct XPUArgsort<int64_t, true, true> {
  void operator()(xpu::Context* ctx,
                  const int64_t* input_data,
                  int64_t* output_data,
                  int64_t* indices_data,
                  const std::vector<int>& data_shape,
                  const std::vector<int>& permute,
                  bool descending) {
    xpu::ctx_guard RAII_GUARD(ctx);
    int m = data_shape[0] * data_shape[2];
    int n = data_shape[1];
    int len = data_shape[0] * data_shape[1] * data_shape[2];
    std::vector<int> trans_data_shape{
        data_shape[0], data_shape[2], data_shape[1]};

    int* input_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
    int* output_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
    int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
    int* cast_data_int = RAII_GUARD.alloc_l3_or_gm<int>(len);
    int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);

    xpu_cast(ctx, input_data, cast_data_int, len);
    xpu_transpose(ctx, cast_data_int, input_data_trans, data_shape, permute);
    xpu_argsort(ctx,
                input_data_trans,
                output_data_trans,
                indices_data_trans,
                m,
                n,
                descending);
    xpu_cast(ctx, output_data_trans, cast_data_int64, len);
    xpu_transpose(ctx, cast_data_int64, output_data, trans_data_shape, permute);
    xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
    xpu_transpose(
        ctx, cast_data_int64, indices_data, trans_data_shape, permute);
  }
};
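The three functor variants trade element width against the sort lengths the underlying xpu::sort primitive handles well: the primary template sorts with int64 indices directly, XPUArgsort&lt;T, false, true&gt; sorts with int32 indices and casts them back to int64, and the int64_t specialization additionally routes the values themselves through int32. The kernel below picks a variant from the sort length n; here is a minimal Python mirror of that selection rule (an illustrative sketch only — the function name is ours, not part of the commit):

def pick_argsort_variant(n, t_is_int64):
    # Mirrors the int64_need_cast / index_need_cast flags set in ArgsortKernel.
    if t_is_int64:
        if 10240 < n <= 16384:
            return "XPUArgsort<T, true, true>"   # cast values and indices via int32
        if 8192 < n <= 10240:
            return "XPUArgsort<T, false, true>"  # cast indices via int32
    elif 10240 < n <= 16384:
        return "XPUArgsort<T, false, true>"
    return "XPUArgsort<T, false, false>"         # sort directly with int64 indices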
template <typename T, typename Context>
void ArgsortKernel(const Context& dev_ctx,
                   const DenseTensor& input,

@@ -35,63 +178,51 @@ void ArgsortKernel(const Context& dev_ctx,

  auto output_data = dev_ctx.template Alloc<T>(output);
  auto indices_data = dev_ctx.template Alloc<int64_t>(indices);

  int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis));
  int len_after =
      phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size()));
  std::vector<int> permute_vec{0, 2, 1};
  std::vector<int> data_shape{len_before, n, len_after};

  bool int64_need_cast = false;
  bool index_need_cast = false;
  if (std::is_same<T, int64_t>::value) {
    if ((n > 10240) && (n <= 16384)) {
      int64_need_cast = true;
    }
    if ((n > 8192) && (n <= 10240)) {
      index_need_cast = true;
    }
  } else {
    if ((n > 10240) && (n <= 16384)) {
      index_need_cast = true;
    }
  }

  if (int64_need_cast) {
    XPUArgsort<T, true, true>()(dev_ctx.x_context(),
                                input_data,
                                output_data,
                                indices_data,
                                data_shape,
                                permute_vec,
                                descending);
  } else if (index_need_cast) {
    XPUArgsort<T, false, true>()(dev_ctx.x_context(),
                                 input_data,
                                 output_data,
                                 indices_data,
                                 data_shape,
                                 permute_vec,
                                 descending);
  } else {
    XPUArgsort<T, false, false>()(dev_ctx.x_context(),
                                  input_data,
                                  output_data,
                                  indices_data,
                                  data_shape,
                                  permute_vec,
                                  descending);
  }
}
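For reference, a short usage sketch that drives the new code path from Python (assumes a Paddle build with XPU support and an available XPU device; the shape puts the sorted axis in the 10240 < n <= 16384 range):

import numpy as np
import paddle

paddle.set_device("xpu")  # assumes an XPU device is available
x = paddle.to_tensor(np.random.rand(2, 10241).astype("float32"))
indices = paddle.argsort(x, axis=-1, descending=True)  # indices come back as int64
values = paddle.sort(x, axis=-1, descending=True)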
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -100,9 +100,92 @@ class XPUTestArgsortOp(XPUOpTestWrapper):

            self.check_grad_with_place(self.place, {'X'}, 'Out')
class XPUTestArgsortOp_LargeN(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'argsort'
        self.use_dynamic_create_class = False

    class TestArgsortOpCase1(XPUOpTest):
        def setUp(self):
            self.set_xpu()
            self.op_type = "argsort"
            self.place = paddle.XPUPlace(0)
            self.dtype = self.in_type
            self.axis = -1 if not hasattr(self, 'init_axis') else self.init_axis
            self.init_test_case()
            self.descending = (
                False
                if not hasattr(self, 'init_descending')
                else self.init_descending
            )

            np.random.seed(100)
            if self.dtype == np.float32:
                self.x = np.random.random(self.input_shape).astype(self.dtype)
            else:
                self.x = np.random.choice(
                    1000000, self.input_shape, replace=False
                ).astype(self.dtype)
            self.inputs = {"X": self.x}
            self.attrs = {"axis": self.axis, "descending": self.descending}
            self.get_output()
            self.outputs = {"Out": self.sorted_x, "Indices": self.indices}

        def get_output(self):
            if self.descending:
                self.indices = np.flip(
                    np.argsort(self.x, kind='heapsort', axis=self.axis),
                    self.axis,
                )
                self.sorted_x = np.flip(
                    np.sort(self.x, kind='heapsort', axis=self.axis), self.axis
                )
            else:
                self.indices = np.argsort(
                    self.x, kind='heapsort', axis=self.axis
                )
                self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis)

        def set_xpu(self):
            self.__class__.use_xpu = True

        def init_test_case(self):
            self.input_shape = [2, 8732]  # test for 8192 < n <= 10240

        def test_check_output(self):
            self.check_output_with_place(self.place)

        def test_check_grad(self):
            self.check_grad_with_place(self.place, {'X'}, 'Out')

    class TestArgsortOpCase2(TestArgsortOpCase1):
        def init_test_case(self):
            self.input_shape = [2, 10241]  # test for 10240 < n <= 16384

    class TestArgsortOpCase3(TestArgsortOpCase1):
        def init_test_case(self):
            self.input_shape = [
                2,
                8732,
                1,
            ]  # test for 8192 < n <= 10240 + need_transpose
            self.axis = 1

    class TestArgsortOpCase4(TestArgsortOpCase1):
        def init_test_case(self):
            self.input_shape = [
                2,
                10241,
                1,
            ]  # test for 10240 < n <= 16384 + need_transpose
            self.axis = 1
support_types = get_xpu_op_support_types('argsort')
for stype in support_types:
    create_test_class(globals(), XPUTestArgsortOp, stype)
    create_test_class(globals(), XPUTestArgsortOp_LargeN, stype)

if __name__ == '__main__':
    unittest.main()
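The reference outputs in get_output above emulate a descending sort by flipping the ascending heapsort result along the sorted axis; a quick standalone check of that identity:

import numpy as np

x = np.array([3.0, 1.0, 2.0])
idx = np.flip(np.argsort(x, kind="heapsort"))  # descending argsort via flip
print(idx)     # [0 2 1]
print(x[idx])  # [3. 2. 1.]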
@@ -457,6 +457,170 @@ class XPUTestPad3dOp(XPUOpTestWrapper):

        np.testing.assert_allclose(y2.numpy(), np_out2, rtol=1e-05)
        np.testing.assert_allclose(y3.numpy(), np_out3, rtol=1e-05)
class TestPad1dAPI(unittest.TestCase):
    def _get_numpy_out(
        self, input_data, pad, mode, value=0.0, data_format="NCL"
    ):
        if data_format == "NCL":
            pad = [
                (0, 0),
                (0, 0),
                (pad[0], pad[1]),
            ]
        else:
            pad = [
                (0, 0),
                (pad[0], pad[1]),
                (0, 0),
            ]

        if mode == "constant":
            out = np.pad(input_data, pad, mode=mode, constant_values=value)
        elif mode == "reflect":
            out = np.pad(input_data, pad, mode=mode)
        elif mode == "replicate":
            out = np.pad(input_data, pad, mode="edge")
        elif mode == "circular":
            out = np.pad(input_data, pad, mode="wrap")

        return out

    def setUp(self):
        self.places = [paddle.XPUPlace(0)]
        self.dtype = self.in_type

    def test_class(self):
        paddle.disable_static()
        for place in self.places:
            input_shape = (3, 4, 5)
            pad = [1, 2]
            pad_int = 1
            value = 100
            input_data = np.random.rand(*input_shape).astype(self.dtype)

            pad_reflection = nn.Pad1D(padding=pad, mode="reflect")
            pad_replication = nn.Pad1D(padding=pad, mode="replicate")
            pad_constant = nn.Pad1D(
                padding=pad, mode="constant", value=value
            )
            pad_constant_int = nn.Pad1D(
                padding=pad_int, mode="constant", value=value
            )
            pad_circular = nn.Pad1D(padding=pad, mode="circular")

            data = paddle.to_tensor(input_data)

            output = pad_reflection(data)
            np_out = self._get_numpy_out(
                input_data, pad, "reflect", data_format="NCL"
            )
            np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)

            output = pad_replication(data)
            np_out = self._get_numpy_out(
                input_data, pad, "replicate", data_format="NCL"
            )
            np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)

            output = pad_constant(data)
            np_out = self._get_numpy_out(
                input_data, pad, "constant", value=value, data_format="NCL"
            )
            np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)

            output = pad_constant_int(data)
            np_out = self._get_numpy_out(
                input_data,
                [pad_int] * 2,
                "constant",
                value=value,
                data_format="NCL",
            )
            np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
class TestPad2dAPI(unittest.TestCase):
    def _get_numpy_out(
        self, input_data, pad, mode, value=0.0, data_format="NCHW"
    ):
        if data_format == "NCHW":
            pad = [
                (0, 0),
                (0, 0),
                (pad[2], pad[3]),
                (pad[0], pad[1]),
            ]
        else:
            pad = [
                (0, 0),
                (pad[2], pad[3]),
                (pad[0], pad[1]),
                (0, 0),
            ]

        if mode == "constant":
            out = np.pad(input_data, pad, mode=mode, constant_values=value)
        elif mode == "reflect":
            out = np.pad(input_data, pad, mode=mode)
        elif mode == "replicate":
            out = np.pad(input_data, pad, mode="edge")
        elif mode == "circular":
            out = np.pad(input_data, pad, mode="wrap")

        return out

    def setUp(self):
        self.places = [paddle.XPUPlace(0)]
        self.dtype = self.in_type

    def test_class(self):
        paddle.disable_static()
        for place in self.places:
            input_shape = (3, 4, 5, 6)
            pad = [1, 2, 2, 1]
            pad_int = 1
            value = 100
            input_data = np.random.rand(*input_shape).astype(self.dtype)

            pad_reflection = nn.Pad2D(padding=pad, mode="reflect")
            pad_replication = nn.Pad2D(padding=pad, mode="replicate")
            pad_constant = nn.Pad2D(
                padding=pad, mode="constant", value=value
            )
            pad_constant_int = nn.Pad2D(
                padding=pad_int, mode="constant", value=value
            )
            pad_circular = nn.Pad2D(padding=pad, mode="circular")

            data = paddle.to_tensor(input_data)

            output = pad_reflection(data)
            np_out = self._get_numpy_out(
                input_data, pad, "reflect", data_format="NCHW"
            )
            np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)

            output = pad_replication(data)
            np_out = self._get_numpy_out(
                input_data, pad, "replicate", data_format="NCHW"
            )
            np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)

            output = pad_constant(data)
            np_out = self._get_numpy_out(
                input_data, pad, "constant", value=value, data_format="NCHW"
            )
            np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)

            output = pad_constant_int(data)
            np_out = self._get_numpy_out(
                input_data,
                [pad_int] * 4,
                "constant",
                value=value,
                data_format="NCHW",
            )
            np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
class TestPad3dAPI(unittest.TestCase):
    def _get_numpy_out(
        self, input_data, pad, mode, value=0.0, data_format="NCDHW"
    ):
......
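The _get_numpy_out helpers above translate Paddle's flat padding list, which runs from the last axis inward ([left, right] for Pad1D, [left, right, top, bottom] for Pad2D), into the per-axis (before, after) pairs that np.pad expects. A tiny check of the NCHW mapping (illustrative only):

import numpy as np

x = np.arange(6.0).reshape(1, 1, 2, 3)  # NCHW
pad = [1, 2, 2, 1]                      # [left, right, top, bottom]
np_pad = [(0, 0), (0, 0), (pad[2], pad[3]), (pad[0], pad[1])]
out = np.pad(x, np_pad, mode="constant", constant_values=0.0)
print(out.shape)  # (1, 1, 5, 6): H + top + bottom, W + left + right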