未验证 提交 2ff949da 编写于 作者: H houj04 提交者: GitHub

[XPU] speed up for special case of strided_slice op. (#55166)

上级 a0951187
......@@ -16,6 +16,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/stride_slice_util.h"
namespace phi {
......@@ -77,6 +78,71 @@ void StridedSliceRawGradKernel(const Context& dev_ctx,
strides_in[cur_axe] = strides_[i];
}
if (is_strided_slice_special_case(xshape, starts_in, ends_in, strides_in)) {
PADDLE_ENFORCE_EQ(
x.numel(),
x_grad->numel(),
errors::PreconditionNotMet(
"x.numel() should be equal to x_grad->numel() in special case."));
PADDLE_ENFORCE_EQ(
x.numel(),
out_grad.numel() * 2,
errors::PreconditionNotMet("x.numel() should be equal to "
"out_grad->numel() * 2 in special case."));
/*
* sample input: [1 2 3 4 5]
* starts = [0/1]
* strides = [2]
* sample output: [1 0 2 0 3 0 4 0 5 0] (last value in starts is 0)
* sample output: [0 1 0 2 0 3 0 4 0 5] (last value in starts is 1)
*/
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* x_transpose = RAII_GUARD.alloc_l3_or_gm<XPUType>(x.numel());
// step 1: set all value to 0
// int constant(Context* ctx, T* x, int len, T val)
int r = xpu::constant(
dev_ctx.x_context(), x_transpose, x.numel(), static_cast<XPUType>(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
/*
* step 2: copy dy to dx:
* if starts from 0: [1 2 3 4 5 0 0 0 0 0]
* if starts from 1: [0 0 0 0 0 1 2 3 4 5]
*/
int offset = 0;
if (starts_in.back() == 1) {
offset = x.numel() / 2;
}
// int copy(Context* ctx, const T* x, T* y, int64_t len)
r = xpu::copy<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
x_transpose + offset,
x.numel() / 2);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
/*
* step3: transpose, input shape is (2, x.numel/2):
* input:
* [1 2 3 4 5
* 0 0 0 0 0]
* after transpose:
* [1 0
* 2 0
* 3 0
* 4 0
* 5 0]
*/
r = xpu::transpose<XPUType>(dev_ctx.x_context(),
x_transpose,
reinterpret_cast<XPUType*>(x_grad->data<T>()),
{2, x.numel() / 2},
{1, 0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
return;
}
int r = xpu::strided_slice_grad(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/phi/kernels/xpu/stride_slice_util.h"
namespace phi {
......@@ -99,6 +100,57 @@ void StridedSliceRawKernel(const Context& dev_ctx,
strides_in[cur_axe] = strides_[i];
}
if (is_strided_slice_special_case(xshape, starts_in, ends_in, strides_in)) {
PADDLE_ENFORCE_EQ(
x.numel(),
out->numel() * 2,
errors::PreconditionNotMet(
"x.numel() should be equal to out->numel() * 2 in special case."));
/*
* sample input: [1 2 3 4 5 6 7 8 9 10]
* starts = [0/1]
* strides = [2]
* sample output: [1 3 5 7 9] (last value in starts is 0)
* sample output: [2 4 6 8 10] (last value in starts is 1)
*/
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* x_transpose = RAII_GUARD.alloc_l3_or_gm<XPUType>(x.numel());
/*
* step 1: transpose, input shape is (x.numel/2, 2):
* input:
* [1 2
* 3 4
* 5 6
* 7 8
* 9 10]
* after transpose:
* [1 3 5 7 9
* 2 4 6 8 10]
*/
// int transpose(Context* ctx, const T* x, T* y, const std::vector<int>&
// xshape, const std::vector<int>& permute)
int r =
xpu::transpose<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
x_transpose,
{x.numel() / 2, 2},
{1, 0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
// step 2: if starts from 0, use "first half" data as result, otherwise use
// "second half".
int offset = 0;
if (starts_in.back() == 1) {
offset = x.numel() / 2;
}
// int copy(Context* ctx, const T* x, T* y, int64_t len)
r = xpu::copy<XPUType>(dev_ctx.x_context(),
x_transpose + offset,
reinterpret_cast<XPUType*>(out->data<T>()),
x.numel() / 2);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
int r = xpu::strided_slice(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
namespace phi {
/**
 * @brief Tell whether a strided-slice request matches the "take every other
 *        element of the last dimension" fast path.
 *
 * All four vectors are expected to hold one entry per dimension of the input.
 * The function returns true only when every condition below holds:
 *   - the rank is at least 1 and all four vectors have the same length;
 *   - ends is identical to xshape (each dimension is sliced to its end);
 *   - starts is {0, ..., 0, 0} or {0, ..., 0, 1};
 *   - strides is {1, ..., 1, 2};
 *   - the last entry of xshape is an even number.
 *
 * @param xshape  Shape of the input tensor.
 * @param starts  Per-dimension start index.
 * @param ends    Per-dimension end index.
 * @param strides Per-dimension stride.
 * @return true if the request is the special case, false otherwise. Empty or
 *         rank-mismatched inputs yield false instead of reading out of bounds.
 */
inline bool is_strided_slice_special_case(const std::vector<int>& xshape,
                                          const std::vector<int>& starts,
                                          const std::vector<int>& ends,
                                          const std::vector<int>& strides) {
  // Reject rank-0 and rank-mismatched inputs up front. Without this guard,
  // `size() - 1` wraps around for an empty vector and `back()` is undefined.
  const size_t rank = xshape.size();
  if (rank == 0 || starts.size() != rank || strides.size() != rank) {
    return false;
  }
  // xshape match ends (this also enforces ends.size() == rank)
  if (xshape != ends) {
    return false;
  }
  // starts match {0, 0, ..., 0, 0} or {0, 0, ..., 0, 1}
  // `i + 1 < rank` is the underflow-safe spelling of `i < rank - 1`.
  for (size_t i = 0; i + 1 < rank; i++) {
    if (starts[i] != 0) {
      return false;
    }
  }
  if (starts.back() != 0 && starts.back() != 1) {
    return false;
  }
  // strides match {1, 1, ..., 1, 2}
  for (size_t i = 0; i + 1 < rank; i++) {
    if (strides[i] != 1) {
      return false;
    }
  }
  if (strides.back() != 2) {
    return false;
  }
  // last dim of xshape is even number
  if (xshape.back() % 2 != 0) {
    return false;
  }
  return true;
}
} // namespace phi
......@@ -174,6 +174,24 @@ class XPUTestStrideSliceOp(XPUOpTestWrapper):
self.strides = [1, 1, 1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1, 1]
class XPUTestStrideSliceOp_eb_1(XPUTestStrideSliceOp):
    """Stride-2 slice over the last axis of a 4-D tensor, starting at index 0.

    The configuration slices every axis up to its full extent, uses stride 1
    on all but the last axis, and has an even last dimension.
    """

    def initTestCase(self):
        shape = (1, 4, 4096, 128)
        rank = len(shape)
        self.inshape = shape
        # One entry per axis: axes 0..rank-1, all starts at 0.
        self.axes = list(range(rank))
        self.starts = [0] * rank
        # Ends equal the input shape, so each axis is sliced to its end.
        self.ends = list(shape)
        # Stride 1 everywhere except stride 2 on the last axis.
        self.strides = [1] * (rank - 1) + [2]
        self.infer_flags = [1] * rank
class XPUTestStrideSliceOp_eb_2(XPUTestStrideSliceOp):
    """Stride-2 slice over the last axis of a 4-D tensor, starting at index 1.

    Same configuration as the start-at-0 variant, except that the start index
    on the last axis is 1.
    """

    def initTestCase(self):
        shape = (1, 4, 4096, 128)
        rank = len(shape)
        self.inshape = shape
        self.axes = list(range(rank))
        # All starts are 0 except the last axis, which starts at 1.
        self.starts = [0] * (rank - 1) + [1]
        # Ends equal the input shape, so each axis is sliced to its end.
        self.ends = list(shape)
        # Stride 1 everywhere except stride 2 on the last axis.
        self.strides = [1] * (rank - 1) + [2]
        self.infer_flags = [1] * rank
support_types = get_xpu_op_support_types('strided_slice')
for stype in support_types:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册