From 8cecc9bd3571a812440d2dc82f85269a2841d0d1 Mon Sep 17 00:00:00 2001 From: phlrain Date: Thu, 17 Feb 2022 07:50:20 +0000 Subject: [PATCH] merge develop; test=develop --- .../kernels/funcs}/slice_utils.h | 59 +++++++++++-------- .../kernels/impl/slice_grad_kernel_impl.h | 2 +- paddle/pten/kernels/impl/slice_kernel_impl.h | 11 ++-- 3 files changed, 40 insertions(+), 32 deletions(-) rename paddle/{fluid/operators => pten/kernels/funcs}/slice_utils.h (79%) diff --git a/paddle/fluid/operators/slice_utils.h b/paddle/pten/kernels/funcs/slice_utils.h similarity index 79% rename from paddle/fluid/operators/slice_utils.h rename to paddle/pten/kernels/funcs/slice_utils.h index c02e54a8a2c..ddeadaf274e 100644 --- a/paddle/fluid/operators/slice_utils.h +++ b/paddle/pten/kernels/funcs/slice_utils.h @@ -13,12 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include +#include #include #include -namespace paddle { -namespace operators { +namespace pten { template inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, @@ -30,11 +29,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, for (size_t i = 0; i < axes.size(); ++i) { T axis = axes[i]; PADDLE_ENFORCE_LT( - axis, in_dims.size(), - platform::errors::InvalidArgument( + axis, + in_dims.size(), + pten::errors::InvalidArgument( "The axis value should be less than the rank of input, " "but received axes[%d] = %d, rank of input is %d.", - i, axis, in_dims.size())); + i, + axis, + in_dims.size())); if (infer_flags != nullptr && (*infer_flags)[i] == -1) { continue; @@ -45,8 +47,10 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, if (dim_value > 0) { T step = steps == nullptr ? 1 : (*steps)[i]; PADDLE_ENFORCE_NE( - step, 0, platform::errors::InvalidArgument( - "Step should not be 0, but received step = %d.", step)); + step, + 0, + pten::errors::InvalidArgument( + "Step should not be 0, but received step = %d.", step)); T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; start = std::max(start, static_cast(0)); @@ -59,11 +63,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, start = std::min(start, dim_value); end = std::max(end, static_cast(0)); PADDLE_ENFORCE_GE( - end, start, - platform::errors::InvalidArgument( + end, + start, + pten::errors::InvalidArgument( "When step > 0, end should be greater than start, but " "received end = %d, start = %d.", - end, start)); + end, + start)); } else { // NOTE(liym27): When step < 0, start should less and equal to // dim_value-1 @@ -71,11 +77,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, start = std::min(start, dim_value - 1); end = std::max(end, static_cast(-1)); PADDLE_ENFORCE_GE( - start, end, - platform::errors::InvalidArgument( + start, + end, + pten::errors::InvalidArgument( "When step < 0, start should be greater than end, but " "received start = %d, end = %d.", - start, end)); + start, + end)); } (*starts)[i] = start; @@ -88,13 +96,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, } template -inline framework::DDim GetSliceDims(const framework::DDim in_dims, - const std::vector& axes, - const std::vector& starts, - const std::vector& ends, - std::vector* steps = nullptr, - std::vector* infer_flags = nullptr) { - framework::DDim slice_dims(in_dims); +inline pten::framework::DDim GetSliceDims( + const pten::framework::DDim in_dims, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + std::vector* steps = nullptr, + std::vector* infer_flags = nullptr) { + pten::framework::DDim slice_dims(in_dims); for (size_t i = 0; i < axes.size(); ++i) { T axis = axes[i]; @@ -127,8 +136,9 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, T axis = decrease_axes[i]; decrease_flag[axis] = 1; if (infer_flags && (*infer_flags)[i] != -1) { - PADDLE_ENFORCE_EQ(decreased_dims[axis], 1, - platform::errors::InvalidArgument( + PADDLE_ENFORCE_EQ(decreased_dims[axis], + 1, + pten::errors::InvalidArgument( "Decrease dim should be 1, but now received %d", decreased_dims[axis])); } @@ -152,5 +162,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, return decreased_dims; } -} // namespace operators -} // namespace paddle +} // namespace pten diff --git a/paddle/pten/kernels/impl/slice_grad_kernel_impl.h b/paddle/pten/kernels/impl/slice_grad_kernel_impl.h index 2f442bdaf8e..e188522fc7b 100644 --- a/paddle/pten/kernels/impl/slice_grad_kernel_impl.h +++ b/paddle/pten/kernels/impl/slice_grad_kernel_impl.h @@ -14,9 +14,9 @@ #pragma once -#include "paddle/fluid/operators/slice_utils.h" #include "paddle/pten/kernels/funcs/eigen/common.h" #include "paddle/pten/kernels/funcs/eigen/eigen_function.h" +#include "paddle/pten/kernels/funcs/slice_utils.h" #include "paddle/pten/kernels/slice_grad_kernel.h" namespace pten { diff --git a/paddle/pten/kernels/impl/slice_kernel_impl.h b/paddle/pten/kernels/impl/slice_kernel_impl.h index b3c4f65de4b..b9075e12384 100644 --- a/paddle/pten/kernels/impl/slice_kernel_impl.h +++ b/paddle/pten/kernels/impl/slice_kernel_impl.h @@ -14,9 +14,9 @@ #pragma once -#include "paddle/fluid/operators/slice_utils.h" #include "paddle/pten/kernels/funcs/eigen/common.h" #include "paddle/pten/kernels/funcs/eigen/eigen_function.h" +#include "paddle/pten/kernels/funcs/slice_utils.h" namespace pten { @@ -60,11 +60,10 @@ void SliceCompute(const Context& ctx, } } - paddle::operators::CheckAndUpdateSliceAttrs( - in_dims, axes, &starts, &ends); - slice_dims = paddle::operators::GetSliceDims( - in_dims, axes, starts, ends, nullptr, nullptr); - out_dims = paddle::operators::GetDecreasedDims(slice_dims, decrease_axis); + CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends); + slice_dims = + GetSliceDims(in_dims, axes, starts, ends, nullptr, nullptr); + out_dims = GetDecreasedDims(slice_dims, decrease_axis); // 2.2 Get output auto offsets = Eigen::DSizes(); -- GitLab