From 7bf7e6e0f97b40e739858b10e353a3a9998458d8 Mon Sep 17 00:00:00 2001
From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com>
Date: Wed, 30 Nov 2022 10:29:37 +0800
Subject: [PATCH] optimize for argsort with xpu, test=kunlun (#48440)

---
 paddle/phi/kernels/xpu/argsort_kernel.cc      | 221 ++++++++++++++----
 .../unittests/xpu/test_argsort_op_xpu.py      |  85 ++++++-
 .../tests/unittests/xpu/test_pad3d_op_xpu.py  | 164 +++++++++++++
 3 files changed, 424 insertions(+), 46 deletions(-)

diff --git a/paddle/phi/kernels/xpu/argsort_kernel.cc b/paddle/phi/kernels/xpu/argsort_kernel.cc
index 9a1cdd763b9..0a71ec71463 100644
--- a/paddle/phi/kernels/xpu/argsort_kernel.cc
+++ b/paddle/phi/kernels/xpu/argsort_kernel.cc
@@ -20,6 +20,149 @@ namespace phi {
 
+template <typename T, typename TID>
+static inline void xpu_argsort(xpu::Context* ctx,
+                               const T* input_data,
+                               T* output_data,
+                               TID* indices_data,
+                               int m,
+                               int n,
+                               bool descending) {
+  int ret =
+      xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending);
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sort");
+}
+
+template <typename T>
+static inline void xpu_transpose(xpu::Context* ctx,
+                                 const T* x,
+                                 T* y,
+                                 const std::vector<int>& xshape,
+                                 const std::vector<int>& permute) {
+  int ret = xpu::transpose(ctx, x, y, xshape, permute);
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "transpose");
+}
+
+template <typename TX, typename TY>
+static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) {
+  int ret = xpu::cast(ctx, x, y, len);
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "cast");
+}
+
+template <typename T,
+          bool VALUE_NEED_CAST = false,
+          bool INDEX_NEED_CAST = false>
+struct XPUArgsort {
+  void operator()(xpu::Context* ctx,
+                  const T* input_data,
+                  T* output_data,
+                  int64_t* indices_data,
+                  const std::vector<int>& data_shape,
+                  const std::vector<int>& permute,
+                  bool descending) {
+    xpu::ctx_guard RAII_GUARD(ctx);
+    int m = data_shape[0] * data_shape[2];
+    int n = data_shape[1];
+    int len = data_shape[0] * data_shape[1] * data_shape[2];
+    std::vector<int> trans_data_shape{
+        data_shape[0], data_shape[2], data_shape[1]};
+
+    T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
+
+    xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
+    xpu_argsort(ctx,
+                input_data_trans,
+                output_data_trans,
+                indices_data_trans,
+                m,
+                n,
+                descending);
+    xpu_transpose(
+        ctx, output_data_trans, output_data, trans_data_shape, permute);
+    xpu_transpose(
+        ctx, indices_data_trans, indices_data, trans_data_shape, permute);
+  }
+};
+
+template <typename T>
+struct XPUArgsort<T, false, true> {
+  void operator()(xpu::Context* ctx,
+                  const T* input_data,
+                  T* output_data,
+                  int64_t* indices_data,
+                  const std::vector<int>& data_shape,
+                  const std::vector<int>& permute,
+                  bool descending) {
+    xpu::ctx_guard RAII_GUARD(ctx);
+    int m = data_shape[0] * data_shape[2];
+    int n = data_shape[1];
+    int len = data_shape[0] * data_shape[1] * data_shape[2];
+    std::vector<int> trans_data_shape{
+        data_shape[0], data_shape[2], data_shape[1]};
+
+    T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
+
+    xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
+    xpu_argsort(ctx,
+                input_data_trans,
+                output_data_trans,
+                indices_data_trans,
+                m,
+                n,
+                descending);
+    xpu_transpose(
+        ctx, output_data_trans, output_data, trans_data_shape, permute);
+    xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
+    xpu_transpose(
+        ctx, cast_data_int64, indices_data, trans_data_shape, permute);
+  }
+};
+
+template <>
+struct XPUArgsort<int64_t, true, true> {
+  void operator()(xpu::Context* ctx,
+                  const int64_t* input_data,
+                  int64_t* output_data,
+                  int64_t* indices_data,
+                  const std::vector<int>& data_shape,
+                  const std::vector<int>& permute,
+                  bool descending) {
+    xpu::ctx_guard RAII_GUARD(ctx);
+    int m = data_shape[0] * data_shape[2];
+    int n = data_shape[1];
+    int len = data_shape[0] * data_shape[1] * data_shape[2];
+    std::vector<int> trans_data_shape{
+        data_shape[0], data_shape[2], data_shape[1]};
+
+    int* input_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int* output_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int* cast_data_int = RAII_GUARD.alloc_l3_or_gm<int>(len);
+    int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
+
+    xpu_cast(ctx, input_data, cast_data_int, len);
+    xpu_transpose(ctx, cast_data_int, input_data_trans, data_shape, permute);
+    xpu_argsort(ctx,
+                input_data_trans,
+                output_data_trans,
+                indices_data_trans,
+                m,
+                n,
+                descending);
+
+    xpu_cast(ctx, output_data_trans, cast_data_int64, len);
+    xpu_transpose(ctx, cast_data_int64, output_data, trans_data_shape, permute);
+    xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
+    xpu_transpose(
+        ctx, cast_data_int64, indices_data, trans_data_shape, permute);
+  }
+};
+
 template <typename T, typename Context>
 void ArgsortKernel(const Context& dev_ctx,
                    const DenseTensor& input,
@@ -35,63 +178,51 @@ void ArgsortKernel(const Context& dev_ctx,
   auto output_data = dev_ctx.template Alloc<T>(output);
   auto indices_data = dev_ctx.template Alloc<int64_t>(indices);
 
-  bool is_need_transpose = true;
-  if (axis == -1 || axis + 1 == in_dims.size()) {
-    is_need_transpose = false;
-  }
   int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis));
   int len_after =
       phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size()));
-  int m = len_before * len_after;
-  int len = m * n;
   std::vector<int> permute_vec{0, 2, 1};
   std::vector<int> data_shape{len_before, n, len_after};
-  std::vector<int> data_shape_trans{len_before, len_after, n};
 
-  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-  if (is_need_transpose) {
-    T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(input_data_trans);
-    T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(output_data_trans);
-    int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(indices_data_trans);
-
-    int r = xpu::transpose(dev_ctx.x_context(),
-                           input_data,
-                           input_data_trans,
-                           data_shape,
-                           permute_vec);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
-
-    input_data = input_data_trans;
-    output_data = output_data_trans;
-    indices_data = indices_data_trans;
+  bool int64_need_cast = false;
+  bool index_need_cast = false;
+  if (std::is_same<T, int64_t>::value) {
+    if ((n > 10240) && (n <= 16384)) {
+      int64_need_cast = true;
+    }
+    if ((n > 8192) && (n <= 10240)) {
+      index_need_cast = true;
+    }
+  } else {
+    if ((n > 10240) && (n <= 16384)) {
+      index_need_cast = true;
+    }
   }
 
-  int ret = xpu::sort(dev_ctx.x_context(),
+  if (int64_need_cast) {
+    XPUArgsort<T, true, true>()(dev_ctx.x_context(),
+                                input_data,
+                                output_data,
+                                indices_data,
+                                data_shape,
+                                permute_vec,
+                                descending);
+  } else if (index_need_cast) {
+    XPUArgsort<T, false, true>()(dev_ctx.x_context(),
+                                 input_data,
+                                 output_data,
+                                 indices_data,
+                                 data_shape,
+                                 permute_vec,
+                                 descending);
+  } else {
+    XPUArgsort<T, false, false>()(dev_ctx.x_context(),
                       input_data,
                       output_data,
                       indices_data,
-                      m,
-                      n,
+                      data_shape,
+                      permute_vec,
                       descending);
-  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sort");
-
-  if (is_need_transpose) {
-    int r = xpu::transpose(dev_ctx.x_context(),
-                           output_data,
-                           output->data<T>(),
-                           data_shape_trans,
-                           permute_vec);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
-
-    r = xpu::transpose(dev_ctx.x_context(),
-                       indices_data,
-                       indices->data<int64_t>(),
-                       data_shape_trans,
-                       permute_vec);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
   }
 }
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py
index 12227622e65..70b988dcd1b 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -100,9 +100,92 @@ class XPUTestArgsortOp(XPUOpTestWrapper):
             self.check_grad_with_place(self.place, {'X'}, 'Out')
 
 
+class XPUTestArgsortOp_LargeN(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'argsort'
+        self.use_dynamic_create_class = False
+
+    class TestArgsortOpCase1(XPUOpTest):
+        def setUp(self):
+            self.set_xpu()
+            self.op_type = "argsort"
+            self.place = paddle.XPUPlace(0)
+            self.dtype = self.in_type
+            self.axis = -1 if not hasattr(self, 'init_axis') else self.init_axis
+            self.init_test_case()
+            self.descending = (
+                False
+                if not hasattr(self, 'init_descending')
+                else self.init_descending
+            )
+
+            np.random.seed(100)
+            if self.dtype == np.float32:
+                self.x = np.random.random(self.input_shape).astype(self.dtype)
+            else:
+                self.x = np.random.choice(
+                    1000000, self.input_shape, replace=False
+                ).astype(self.dtype)
+
+            self.inputs = {"X": self.x}
+            self.attrs = {"axis": self.axis, "descending": self.descending}
+            self.get_output()
+            self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
+
+        def get_output(self):
+            if self.descending:
+                self.indices = np.flip(
+                    np.argsort(self.x, kind='heapsort', axis=self.axis),
+                    self.axis,
+                )
+                self.sorted_x = np.flip(
+                    np.sort(self.x, kind='heapsort', axis=self.axis), self.axis
+                )
+            else:
+                self.indices = np.argsort(
+                    self.x, kind='heapsort', axis=self.axis
+                )
+                self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis)
+
+        def set_xpu(self):
+            self.__class__.use_xpu = True
+
+        def init_test_case(self):
+            self.input_shape = [2, 8732]  # test for 8192 < n <= 10240
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(self.place, {'X'}, 'Out')
+
+    class TestArgsortOpCase2(TestArgsortOpCase1):
+        def init_test_case(self):
+            self.input_shape = [2, 10241]  # test for 10240 < n <= 16384
+
+    class TestArgsortOpCase3(TestArgsortOpCase1):
+        def init_test_case(self):
+            self.input_shape = [
+                2,
+                8732,
+                1,
+            ]  # test for 8192 < n <= 10240 + need_transpose
+            self.axis = 1
+
+    class TestArgsortOpCase4(TestArgsortOpCase1):
+        def init_test_case(self):
+            self.input_shape = [
+                2,
+                10241,
+                1,
+            ]  # test for 10240 < n <= 16384 + need_transpose
+            self.axis = 1
+
+
 support_types = get_xpu_op_support_types('argsort')
 for stype in support_types:
     create_test_class(globals(), XPUTestArgsortOp, stype)
+    create_test_class(globals(), XPUTestArgsortOp_LargeN, stype)
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_pad3d_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_pad3d_op_xpu.py
index 2522fa9f6ce..4ecb8878ba9 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_pad3d_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_pad3d_op_xpu.py
@@ -457,6 +457,170 @@ class XPUTestPad3dOp(XPUOpTestWrapper):
             np.testing.assert_allclose(y2.numpy(), np_out2, rtol=1e-05)
             np.testing.assert_allclose(y3.numpy(), np_out3, rtol=1e-05)
 
+    class TestPad1dAPI(unittest.TestCase):
+        def _get_numpy_out(
+            self, input_data, pad, mode, value=0.0, data_format="NCL"
+        ):
+            if data_format == "NCL":
+                pad = [
+                    (0, 0),
+                    (0, 0),
+                    (pad[0], pad[1]),
+                ]
+            else:
+                pad = [
+                    (0, 0),
+                    (pad[0], pad[1]),
+                    (0, 0),
+                ]
+
+            if mode == "constant":
+                out = np.pad(input_data, pad, mode=mode, constant_values=value)
+            elif mode == "reflect":
+                out = np.pad(input_data, pad, mode=mode)
+            elif mode == "replicate":
+                out = np.pad(input_data, pad, mode="edge")
+            elif mode == "circular":
+                out = np.pad(input_data, pad, mode="wrap")
+
+            return out
+
+        def setUp(self):
+            self.places = [paddle.XPUPlace(0)]
+            self.dtype = self.in_type
+
+        def test_class(self):
+            paddle.disable_static()
+            for place in self.places:
+                input_shape = (3, 4, 5)
+                pad = [1, 2]
+                pad_int = 1
+                value = 100
+                input_data = np.random.rand(*input_shape).astype(self.dtype)
+
+                pad_reflection = nn.Pad1D(padding=pad, mode="reflect")
+                pad_replication = nn.Pad1D(padding=pad, mode="replicate")
+                pad_constant = nn.Pad1D(
+                    padding=pad, mode="constant", value=value
+                )
+                pad_constant_int = nn.Pad1D(
+                    padding=pad_int, mode="constant", value=value
+                )
+                pad_circular = nn.Pad1D(padding=pad, mode="circular")
+
+                data = paddle.to_tensor(input_data)
+
+                output = pad_reflection(data)
+                np_out = self._get_numpy_out(
+                    input_data, pad, "reflect", data_format="NCL"
+                )
+                np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
+
+                output = pad_replication(data)
+                np_out = self._get_numpy_out(
+                    input_data, pad, "replicate", data_format="NCL"
+                )
+                np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
+
+                output = pad_constant(data)
+                np_out = self._get_numpy_out(
+                    input_data, pad, "constant", value=value, data_format="NCL"
+                )
+                np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
+
+                output = pad_constant_int(data)
+                np_out = self._get_numpy_out(
+                    input_data,
+                    [pad_int] * 2,
+                    "constant",
+                    value=value,
+                    data_format="NCL",
+                )
+                np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
+
+    class TestPad2dAPI(unittest.TestCase):
+        def _get_numpy_out(
+            self, input_data, pad, mode, value=0.0, data_format="NCHW"
+        ):
+            if data_format == "NCHW":
+                pad = [
+                    (0, 0),
+                    (0, 0),
+                    (pad[2], pad[3]),
+                    (pad[0], pad[1]),
+                ]
+            else:
+                pad = [
+                    (0, 0),
+                    (pad[2], pad[3]),
+                    (pad[0], pad[1]),
+                    (0, 0),
+                ]
+
+            if mode == "constant":
+                out = np.pad(input_data, pad, mode=mode, constant_values=value)
+            elif mode == "reflect":
+                out = np.pad(input_data, pad, mode=mode)
+            elif mode == "replicate":
+                out = np.pad(input_data, pad, mode="edge")
+            elif mode == "circular":
+                out = np.pad(input_data, pad, mode="wrap")
+
+            return out
+
+        def setUp(self):
+            self.places = [paddle.XPUPlace(0)]
+            self.dtype = self.in_type
+
+        def test_class(self):
+            paddle.disable_static()
+            for place in self.places:
+                input_shape = (3, 4, 5, 6)
+                pad = [1, 2, 2, 1]
+                pad_int = 1
+                value = 100
+                input_data = np.random.rand(*input_shape).astype(self.dtype)
+
+                pad_reflection = nn.Pad2D(padding=pad, mode="reflect")
+                pad_replication = nn.Pad2D(padding=pad, mode="replicate")
+                pad_constant = nn.Pad2D(
+                    padding=pad, mode="constant", value=value
+                )
+                pad_constant_int = nn.Pad2D(
+                    padding=pad_int, mode="constant", value=value
+                )
+                pad_circular = nn.Pad2D(padding=pad, mode="circular")
+
+                data = paddle.to_tensor(input_data)
+
+                output = pad_reflection(data)
+                np_out = self._get_numpy_out(
+                    input_data, pad, "reflect", data_format="NCHW"
+                )
+                np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
+
+                output = pad_replication(data)
+                np_out = self._get_numpy_out(
+                    input_data, pad, "replicate", data_format="NCHW"
+                )
+                np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
+
+                output = pad_constant(data)
+                np_out = self._get_numpy_out(
+                    input_data, pad, "constant", value=value, data_format="NCHW"
+                )
+                np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
+
+                output = pad_constant_int(data)
+                np_out = self._get_numpy_out(
+                    input_data,
+                    [pad_int] * 4,
+                    "constant",
+                    value=value,
+                    data_format="NCHW",
+                )
+                np.testing.assert_allclose(output.numpy(), np_out, rtol=1e-05)
+
     class TestPad3dAPI(unittest.TestCase):
         def _get_numpy_out(
             self, input_data, pad, mode, value=0.0, data_format="NCDHW"
-- 
GitLab
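
Note on the dispatch logic in paddle/phi/kernels/xpu/argsort_kernel.cc above: ArgsortKernel now selects one of three XPUArgsort variants from the element type T and the sort-axis length n; only the flag computation differs between the branches. Below is a minimal standalone sketch of that selection rule, for reference only. pick_argsort_variant is an illustrative name that does not appear in the patch, and the remark about why the thresholds exist is an assumption, not something stated in the change.

#include <cstdint>
#include <type_traits>

// Mirrors the flag computation in ArgsortKernel: int64 data with
// 8192 < n <= 10240 sorts with int32 indices and casts them back afterwards;
// int64 data with 10240 < n <= 16384 casts both values and indices; other
// dtypes cast only the indices once n exceeds 10240 (presumably the ranges the
// native XPU sort cannot handle directly). All other cases take the plain
// transpose + sort path, i.e. XPUArgsort<T, false, false>.
template <typename T>
void pick_argsort_variant(int n, bool* value_need_cast, bool* index_need_cast) {
  *value_need_cast = false;
  *index_need_cast = false;
  if (std::is_same<T, int64_t>::value) {
    *value_need_cast = (n > 10240) && (n <= 16384);
    *index_need_cast = (n > 8192) && (n <= 10240);
  } else {
    *index_need_cast = (n > 10240) && (n <= 16384);
  }
}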