From 8aec0580e648c20f1178f7266fa5ba14a0380dd5 Mon Sep 17 00:00:00 2001 From: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> Date: Mon, 27 Feb 2023 14:34:00 +0800 Subject: [PATCH] Reduce redundant cpu computation in slice compute (#50348) * conflict * add UpdateSliceAttrs --- paddle/phi/kernels/funcs/slice_utils.h | 43 +++++++++++++++++++++ paddle/phi/kernels/impl/slice_kernel_impl.h | 2 +- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/slice_utils.h b/paddle/phi/kernels/funcs/slice_utils.h index ed403c75dbd..a56a5e16f65 100644 --- a/paddle/phi/kernels/funcs/slice_utils.h +++ b/paddle/phi/kernels/funcs/slice_utils.h @@ -101,6 +101,49 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims, } } +template +inline void UpdateSliceAttrs(const DDim in_dims, + const std::vector& axes, + std::vector* starts, + std::vector* ends, + std::vector* steps = nullptr, + std::vector* infer_flags = nullptr) { + for (size_t i = 0; i < axes.size(); ++i) { + T axis = axes[i]; + if (infer_flags != nullptr && (*infer_flags)[i] == -1) { + continue; + } + T dim_value = in_dims[axis]; + if (dim_value > 0) { + T step = steps == nullptr ? 1 : (*steps)[i]; + T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; + start = std::max(start, static_cast(0)); + T end = + 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; + end = std::min(end, dim_value); + + if (step > 0) { + start = std::min(start, dim_value); + end = std::max(end, static_cast(0)); + } else { + // NOTE: When step < 0, start should less and equal to + // dim_value-1 + // "end is -1" means contain the 0-th element of this axis. + start = std::min(start, dim_value - 1); + if (end < -1) { + end += dim_value; + } + end = std::max(end, static_cast(-1)); + } + (*starts)[i] = start; + (*ends)[i] = end; + } else if (dim_value == 0) { + (*starts)[i] = 0; + (*ends)[i] = 0; + } + } +} + template inline phi::DDim GetSliceDims(const phi::DDim in_dims, const std::vector& axes, diff --git a/paddle/phi/kernels/impl/slice_kernel_impl.h b/paddle/phi/kernels/impl/slice_kernel_impl.h index 2b1c2d87dbd..8a8925e3e0e 100644 --- a/paddle/phi/kernels/impl/slice_kernel_impl.h +++ b/paddle/phi/kernels/impl/slice_kernel_impl.h @@ -53,7 +53,7 @@ void SliceCompute(const Context& ctx, } } - funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends); + funcs::UpdateSliceAttrs(in_dims, axes, &starts, &ends); slice_dims = funcs::GetSliceDims( in_dims, axes, starts, ends, nullptr, nullptr); out_dims = funcs::GetDecreasedDims(slice_dims, decrease_axis); -- GitLab