diff --git a/paddle/phi/kernels/funcs/slice_utils.h b/paddle/phi/kernels/funcs/slice_utils.h index ed403c75dbdc8a08d6f1f7a7aeb0875dfbd9aed2..a56a5e16f6503d79ad99ae11d8579f2bf67aef54 100644 --- a/paddle/phi/kernels/funcs/slice_utils.h +++ b/paddle/phi/kernels/funcs/slice_utils.h @@ -101,6 +101,49 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims, } } +template +inline void UpdateSliceAttrs(const DDim in_dims, + const std::vector& axes, + std::vector* starts, + std::vector* ends, + std::vector* steps = nullptr, + std::vector* infer_flags = nullptr) { + for (size_t i = 0; i < axes.size(); ++i) { + T axis = axes[i]; + if (infer_flags != nullptr && (*infer_flags)[i] == -1) { + continue; + } + T dim_value = in_dims[axis]; + if (dim_value > 0) { + T step = steps == nullptr ? 1 : (*steps)[i]; + T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; + start = std::max(start, static_cast(0)); + T end = + 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; + end = std::min(end, dim_value); + + if (step > 0) { + start = std::min(start, dim_value); + end = std::max(end, static_cast(0)); + } else { + // NOTE: When step < 0, start should less and equal to + // dim_value-1 + // "end is -1" means contain the 0-th element of this axis. + start = std::min(start, dim_value - 1); + if (end < -1) { + end += dim_value; + } + end = std::max(end, static_cast(-1)); + } + (*starts)[i] = start; + (*ends)[i] = end; + } else if (dim_value == 0) { + (*starts)[i] = 0; + (*ends)[i] = 0; + } + } +} + template inline phi::DDim GetSliceDims(const phi::DDim in_dims, const std::vector& axes, diff --git a/paddle/phi/kernels/impl/slice_kernel_impl.h b/paddle/phi/kernels/impl/slice_kernel_impl.h index 2b1c2d87dbd61deffb6ec7b60b7cb5a0889b24cd..8a8925e3e0e228c2e17a9d9cbec49849778c5946 100644 --- a/paddle/phi/kernels/impl/slice_kernel_impl.h +++ b/paddle/phi/kernels/impl/slice_kernel_impl.h @@ -53,7 +53,7 @@ void SliceCompute(const Context& ctx, } } - funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends); + funcs::UpdateSliceAttrs(in_dims, axes, &starts, &ends); slice_dims = funcs::GetSliceDims( in_dims, axes, starts, ends, nullptr, nullptr); out_dims = funcs::GetDecreasedDims(slice_dims, decrease_axis);