From 2ff949da9d894a51b679753b03e74d77f4fc5a98 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Thu, 6 Jul 2023 11:41:37 +0800 Subject: [PATCH] [XPU] speed up for special case of strided_slice op. (#55166) --- .../kernels/xpu/stride_slice_grad_kernel.cc | 66 +++++++++++++++++++ paddle/phi/kernels/xpu/stride_slice_kernel.cc | 52 +++++++++++++++ paddle/phi/kernels/xpu/stride_slice_util.h | 54 +++++++++++++++ test/xpu/test_strided_slice_op_xpu.py | 18 +++++ 4 files changed, 190 insertions(+) create mode 100644 paddle/phi/kernels/xpu/stride_slice_util.h diff --git a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc index 70bd235688c..fbc7a0bf6ab 100644 --- a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc @@ -16,6 +16,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/xpu/stride_slice_util.h" namespace phi { @@ -77,6 +78,71 @@ void StridedSliceRawGradKernel(const Context& dev_ctx, strides_in[cur_axe] = strides_[i]; } + if (is_strided_slice_special_case(xshape, starts_in, ends_in, strides_in)) { + PADDLE_ENFORCE_EQ( + x.numel(), + x_grad->numel(), + errors::PreconditionNotMet( + "x.numel() should be equal to x_grad->numel() in special case.")); + PADDLE_ENFORCE_EQ( + x.numel(), + out_grad.numel() * 2, + errors::PreconditionNotMet("x.numel() should be equal to " + "out_grad->numel() * 2 in special case.")); + + /* + * sample input: [1 2 3 4 5] + * starts = [0/1] + * strides = [2] + * sample output: [1 0 2 0 3 0 4 0 5 0] (last value in starts is 0) + * sample output: [0 1 0 2 0 3 0 4 0 5] (last value in starts is 1) + */ + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + XPUType* x_transpose = RAII_GUARD.alloc_l3_or_gm(x.numel()); + + // step 1: set all value to 0 + + // int constant(Context* ctx, T* x, int len, T val) + int r = xpu::constant( + 
dev_ctx.x_context(), x_transpose, x.numel(), static_cast(0)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + + /* + * step 2: copy dy to dx: + * if starts from 0: [1 2 3 4 5 0 0 0 0 0] + * if starts from 1: [0 0 0 0 0 1 2 3 4 5] + */ + int offset = 0; + if (starts_in.back() == 1) { + offset = x.numel() / 2; + } + // int copy(Context* ctx, const T* x, T* y, int64_t len) + r = xpu::copy(dev_ctx.x_context(), + reinterpret_cast(out_grad.data()), + x_transpose + offset, + x.numel() / 2); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); + /* + * step3: transpose, input shape is (2, x.numel/2): + * input: + * [1 2 3 4 5 + * 0 0 0 0 0] + * after transpose: + * [1 0 + * 2 0 + * 3 0 + * 4 0 + * 5 0] + */ + r = xpu::transpose(dev_ctx.x_context(), + x_transpose, + reinterpret_cast(x_grad->data()), + {2, x.numel() / 2}, + {1, 0}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + return; + } + int r = xpu::strided_slice_grad( dev_ctx.x_context(), reinterpret_cast(out_grad.data()), diff --git a/paddle/phi/kernels/xpu/stride_slice_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_kernel.cc index da181376997..a2de8c2c8ff 100644 --- a/paddle/phi/kernels/xpu/stride_slice_kernel.cc +++ b/paddle/phi/kernels/xpu/stride_slice_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/strided_slice.h" +#include "paddle/phi/kernels/xpu/stride_slice_util.h" namespace phi { @@ -99,6 +100,57 @@ void StridedSliceRawKernel(const Context& dev_ctx, strides_in[cur_axe] = strides_[i]; } + if (is_strided_slice_special_case(xshape, starts_in, ends_in, strides_in)) { + PADDLE_ENFORCE_EQ( + x.numel(), + out->numel() * 2, + errors::PreconditionNotMet( + "x.numel() should be equal to out->numel() * 2 in special case.")); + /* + * sample input: [1 2 3 4 5 6 7 8 9 10] + * starts = [0/1] + * strides = [2] + * sample output: [1 3 5 7 9] (last value in starts is 0) + * sample output: [2 4 6 8 10] (last value in starts is 
1) + */ + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + XPUType* x_transpose = RAII_GUARD.alloc_l3_or_gm(x.numel()); + /* + * step 1: transpose, input shape is (x.numel/2, 2): + * input: + * [1 2 + * 3 4 + * 5 6 + * 7 8 + * 9 10] + * after transpose: + * [1 3 5 7 9 + * 2 4 6 8 10] + */ + // int transpose(Context* ctx, const T* x, T* y, const std::vector& + // xshape, const std::vector& permute) + int r = + xpu::transpose(dev_ctx.x_context(), + reinterpret_cast(x.data()), + x_transpose, + {x.numel() / 2, 2}, + {1, 0}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + // step 2: if starts from 0, use "first half" data as result, otherwise use + // "second half". + int offset = 0; + if (starts_in.back() == 1) { + offset = x.numel() / 2; + } + // int copy(Context* ctx, const T* x, T* y, int64_t len) + r = xpu::copy(dev_ctx.x_context(), + x_transpose + offset, + reinterpret_cast(out->data()), + x.numel() / 2); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); + return; + } + int r = xpu::strided_slice(dev_ctx.x_context(), reinterpret_cast(x.data()), reinterpret_cast(out->data()), diff --git a/paddle/phi/kernels/xpu/stride_slice_util.h b/paddle/phi/kernels/xpu/stride_slice_util.h new file mode 100644 index 00000000000..70e36456351 --- /dev/null +++ b/paddle/phi/kernels/xpu/stride_slice_util.h @@ -0,0 +1,54 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#pragma once

#include <vector>

namespace phi {

/**
 * Tells whether a strided slice matches the "every second element of the
 * last axis" pattern that the XPU kernels handle with transpose + copy.
 *
 * The pattern holds when all of the following are true:
 *   - every axis except the last starts at 0 with stride 1;
 *   - the last axis starts at 0 or 1 and has stride 2;
 *   - `ends` is element-wise equal to `xshape` (the slice runs to the end of
 *     every axis);
 *   - the extent of the last axis is an even number.
 *
 * @param xshape  extent of each axis of the input.
 * @param starts  per-axis start index.
 * @param ends    per-axis end index.
 * @param strides per-axis stride.
 * @return true if the arguments match the pattern, false otherwise. Empty
 *         vectors, or `starts`/`strides` whose length differs from
 *         `xshape`, never match.
 *
 * NOTE(review): the element type is restored as `int`; the template
 * arguments were stripped from this copy of the patch -- confirm against the
 * vectors the calling kernels pass in.
 */
inline bool is_strided_slice_special_case(const std::vector<int>& xshape,
                                          const std::vector<int>& starts,
                                          const std::vector<int>& ends,
                                          const std::vector<int>& strides) {
  // Reject empty or rank-mismatched arguments up front. Without this guard
  // `size() - 1` wraps around for an empty vector (size_t is unsigned) and
  // both the loop below and the last-axis reads would run out of bounds.
  if (xshape.empty() || starts.size() != xshape.size() ||
      strides.size() != xshape.size()) {
    return false;
  }
  const size_t last = xshape.size() - 1;

  // Leading axes: starts match {0, ..., 0} and strides match {1, ..., 1}.
  for (size_t i = 0; i < last; i++) {
    if (starts[i] != 0 || strides[i] != 1) {
      return false;
    }
  }
  // Last axis: start is 0 or 1, stride is exactly 2.
  if (starts[last] != 0 && starts[last] != 1) {
    return false;
  }
  if (strides[last] != 2) {
    return false;
  }
  // The slice must run to the end of every axis (this also rejects an
  // `ends` vector of a different length).
  if (xshape != ends) {
    return false;
  }
  // The last axis must have an even extent.
  return xshape[last] % 2 == 0;
}

}  // namespace phi

/*
 * The rest of this patch line is the hunk for the Python unit test and the
 * patch trailer. It is kept verbatim (only re-wrapped onto separate lines)
 * so that no content of the patch is dropped.
 *
diff --git a/test/xpu/test_strided_slice_op_xpu.py b/test/xpu/test_strided_slice_op_xpu.py
index 63954dfd785..e86bc8606f0 100644
--- a/test/xpu/test_strided_slice_op_xpu.py
+++ b/test/xpu/test_strided_slice_op_xpu.py
@@ -174,6 +174,24 @@ class XPUTestStrideSliceOp(XPUOpTestWrapper):
             self.strides = [1, 1, 1, 1, 1, 2]
             self.infer_flags = [1, 1, 1, 1, 1]
 
+    class XPUTestStrideSliceOp_eb_1(XPUTestStrideSliceOp):
+        def initTestCase(self):
+            self.inshape = (1, 4, 4096, 128)
+            self.axes = [0, 1, 2, 3]
+            self.starts = [0, 0, 0, 0]
+            self.ends = [1, 4, 4096, 128]
+            self.strides = [1, 1, 1, 2]
+            self.infer_flags = [1, 1, 1, 1]
+
+    class XPUTestStrideSliceOp_eb_2(XPUTestStrideSliceOp):
+        def initTestCase(self):
+            self.inshape = (1, 4, 4096, 128)
+            self.axes = [0, 1, 2, 3]
+            self.starts = [0, 0, 0, 1]
+            self.ends = [1, 4, 4096, 128]
+            self.strides = [1, 1, 1, 2]
+            self.infer_flags = [1, 1, 1, 1]
+
 
 support_types = get_xpu_op_support_types('strided_slice')
 for stype in support_types:
-- 
GitLab
 */