From 4cf01462014f2e18e4af5003e98f0f934b878327 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Wed, 9 Jun 2021 15:12:02 +0800 Subject: [PATCH] Polish code for slice and set_value op (#32947) --- paddle/fluid/operators/set_value_op.h | 105 +----- paddle/fluid/operators/slice_op.cc | 101 ++---- paddle/fluid/operators/slice_op.h | 505 +++++++++++--------------- paddle/fluid/operators/slice_utils.h | 143 ++++++++ 4 files changed, 382 insertions(+), 472 deletions(-) create mode 100644 paddle/fluid/operators/slice_utils.h diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h index eca51147f8..c7b61333cd 100644 --- a/paddle/fluid/operators/set_value_op.h +++ b/paddle/fluid/operators/set_value_op.h @@ -23,6 +23,7 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/assign_value_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/slice_utils.h" #include "paddle/fluid/operators/utils.h" #include "paddle/fluid/platform/enforce.h" @@ -59,106 +60,6 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) { return value_name; } -inline void CheckAndUpdateSlice(const framework::DDim in_dims, - const std::vector axes, - std::vector* starts, - std::vector* ends, - std::vector* steps) { - for (size_t i = 0; i < axes.size(); ++i) { - int64_t axis = axes[i]; - int64_t dim_value = in_dims[axis]; - - int64_t start = - (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; - int64_t end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; - start = std::max(start, static_cast(0)); - end = std::min(end, dim_value); - - int64_t step = (*steps)[i]; - PADDLE_ENFORCE_NE( - step, 0, platform::errors::InvalidArgument( - "Step should not be 0, but received step = %d.", step)); - if (step > 0) { - start = std::min(start, dim_value); - end = std::max(end, static_cast(0)); - PADDLE_ENFORCE_GT( - end, start, - platform::errors::InvalidArgument( - "When step > 0, end should be greater than start, but " - "received end = %d, start = %d.", - end, start)); - } else { - // NOTE(liym27): When step < 0, start should less and equal to dim_value-1 - // "end is -1" means contain the 0-th element of this axis. - start = std::min(start, dim_value - 1); - end = std::max(end, static_cast(-1)); - PADDLE_ENFORCE_GT( - start, end, - platform::errors::InvalidArgument( - "When step < 0, start should be greater than end, but " - "received start = %d, end = %d.", - start, end)); - } - - (*starts)[i] = start; - (*ends)[i] = end; - } -} - -inline framework::DDim GetSliceDims(const framework::DDim in_dims, - const std::vector& axes, - const std::vector& starts, - const std::vector& ends, - const std::vector& steps) { - framework::DDim slice_dims(in_dims); - - for (size_t i = 0; i < axes.size(); ++i) { - int64_t axis = axes[i]; - int64_t start = starts[i]; - int64_t end = ends[i]; - int64_t step = steps[i]; - - if (step > 0) { - slice_dims[axis] = (end - start + step - 1) / step; - } else { - slice_dims[axis] = (end - start + step + 1) / step; - } - } - return slice_dims; -} - -inline framework::DDim GetDecreasedDims( - const framework::DDim slice_dims, - const std::vector& decrease_axes) { - // Get dims after decreasing axes. - framework::DDim decreased_dims(slice_dims); - if (decrease_axes.size() > 0) { - for (size_t i = 0; i < decrease_axes.size(); ++i) { - int64_t axis = decrease_axes[i]; - PADDLE_ENFORCE_EQ( - decreased_dims[axis], 1, - platform::errors::InvalidArgument("decrease dim should be 1")); - decreased_dims[axis] = 0; - } - - std::vector new_shape; - for (int i = 0; i < decreased_dims.size(); ++i) { - if (decreased_dims[i] != 0) { - new_shape.push_back(decreased_dims[i]); - } - } - - // NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and - // uses [1] instead. - if (new_shape.size() == 0) { - new_shape.push_back(1); - } - - decreased_dims = framework::make_ddim(new_shape); - } - return decreased_dims; -} - template class SetValueKernel : public framework::OpKernel { public: @@ -225,8 +126,8 @@ class SetValueKernel : public framework::OpKernel { } auto in_dims = in->dims(); - CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps); - auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps); + CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps); + auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps); auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes); auto place = ctx.GetPlace(); diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index b529897972..01daba7c07 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -28,13 +28,10 @@ class SliceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, - platform::errors::InvalidArgument( - "Input (Input) of slice op should not be null.")); + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "slice"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "slice"); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Output (Out) of slice op should not be null.")); + // Case 1: Special treatment when input is a tensor array. auto x_var_type = ctx->GetInputsVarType("Input")[0]; auto axes = ctx->Attrs().Get>("axes"); if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) { @@ -57,6 +54,8 @@ class SliceOp : public framework::OperatorWithKernel { return; } } + + // Case 2: input is a tensor. auto in_dims = ctx->GetInputDim("Input"); PADDLE_ENFORCE_LT(in_dims.size(), 7, platform::errors::InvalidArgument( @@ -65,101 +64,54 @@ class SliceOp : public framework::OperatorWithKernel { auto starts = ctx->Attrs().Get>("starts"); auto ends = ctx->Attrs().Get>("ends"); - auto infer_flags = ctx->Attrs().Get>("infer_flags"); auto decrease_axis = ctx->Attrs().Get>("decrease_axis"); - - auto starts_size = starts.size(); - auto ends_size = ends.size(); + auto infer_flags = ctx->Attrs().Get>("infer_flags"); if (infer_flags.empty()) { // Initialize infer_flags with 1. // To be compatible with other op tests in which infer_flags is not set. infer_flags = std::vector(axes.size(), 1); } + // 2.1 Check attrs. + auto starts_size = starts.size(); + auto ends_size = ends.size(); + if (ctx->HasInputs("StartsTensorList")) { - auto StartsTensorList = ctx->Inputs("StartsTensorList"); - PADDLE_ENFORCE_GT(StartsTensorList.size(), 0, + starts_size = ctx->Inputs("StartsTensorList").size(); + PADDLE_ENFORCE_GT(starts_size, 0, platform::errors::InvalidArgument( "StartsTensorList size can't be zero")); - starts_size = StartsTensorList.size(); } if (ctx->HasInputs("EndsTensorList")) { - auto EndsTensorList = ctx->Inputs("EndsTensorList"); - PADDLE_ENFORCE_GT(EndsTensorList.size(), 0, - platform::errors::InvalidArgument( - "EndsTensorList size can't be zero")); - ends_size = EndsTensorList.size(); + ends_size = ctx->Inputs("EndsTensorList").size(); + PADDLE_ENFORCE_GT(ends_size, 0, platform::errors::InvalidArgument( + "EndsTensorList size can't be zero")); } - if (ctx->HasInput("StartsTensor") == false) { + if (!ctx->HasInput("StartsTensor")) { PADDLE_ENFORCE_EQ( starts_size, axes.size(), platform::errors::InvalidArgument( "The size of starts must be equal to the size of axes.")); } - if (ctx->HasInput("EndsTensor") == false) { + if (!ctx->HasInput("EndsTensor")) { PADDLE_ENFORCE_EQ( ends_size, axes.size(), platform::errors::InvalidArgument( "The size of ends must be equal to the size of axes.")); } - int dim_value, start, end; - for (size_t i = 0; i < axes.size(); ++i) { - PADDLE_ENFORCE_LT(static_cast(axes[i]), in_dims.size(), - platform::errors::InvalidArgument( - "The index of dimension in axes must be less " - "than the size of input shape.")); - if (infer_flags[i] == -1) { - out_dims[axes[i]] = -1; - } else { - // infer out_dim shape - dim_value = out_dims[axes[i]]; - if (dim_value > 0) { - start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i]; - end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i]; - start = std::max(start, 0); - end = std::max(end, 0); - end = std::min(end, dim_value); - - PADDLE_ENFORCE_LE(start, dim_value, - platform::errors::InvalidArgument( - "start should be less than or equal to the " - "dimension value, but received " - "start = %d, shape[%d] = %d.", - starts[i], axes[i], out_dims[axes[i]])); - PADDLE_ENFORCE_GT(end, start, - platform::errors::InvalidArgument( - "end should greater than start, but received " - "end = %d, start = %d.", - ends[i], starts[i])); - out_dims[axes[i]] = end - start; - } - } - } - // generate new shape - if (decrease_axis.size() > 0) { - std::vector new_out_shape; - for (size_t i = 0; i < decrease_axis.size(); ++i) { - if (ctx->IsRuntime() && infer_flags[i] != -1) { - PADDLE_ENFORCE_EQ( - out_dims[decrease_axis[i]], 1, - platform::errors::InvalidArgument("decrease dim should be 1")); - } - out_dims[decrease_axis[i]] = 0; - } + CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, nullptr, + &infer_flags); - for (int i = 0; i < out_dims.size(); ++i) { - if (out_dims[i] != 0) { - new_out_shape.push_back(out_dims[i]); - } - } - if (new_out_shape.size() == 0) { - new_out_shape.push_back(1); - } - - out_dims = framework::make_ddim(new_out_shape); + auto slice_dims = + GetSliceDims(in_dims, axes, starts, ends, nullptr, &infer_flags); + if (ctx->IsRuntime()) { + out_dims = GetDecreasedDims(slice_dims, decrease_axis, &infer_flags); + } else { + out_dims = GetDecreasedDims(slice_dims, decrease_axis, nullptr); } + ctx->SetOutputDim("Out", out_dims); if (axes[0] != 0) { ctx->ShareLoD("Input", /*->*/ "Out"); @@ -185,6 +137,7 @@ class SliceOp : public framework::OperatorWithKernel { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); } + framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { diff --git a/paddle/fluid/operators/slice_op.h b/paddle/fluid/operators/slice_op.h index 3d294ae238..96b8ea11d6 100644 --- a/paddle/fluid/operators/slice_op.h +++ b/paddle/fluid/operators/slice_op.h @@ -19,21 +19,67 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/slice_utils.h" #include "paddle/fluid/operators/utils.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using Variable = framework::Variable; +using LoDTensorArray = framework::LoDTensorArray; +using DDim = framework::DDim; + +inline void DealTensorArray(const framework::ExecutionContext& ctx, + const std::vector& starts, + const std::vector& ends, + bool out_is_array) { + auto in_array = ctx.Input("Input"); + // If the input is LoDTensorArray, the rank of input is 1. + int64_t in_size = in_array->size(); + int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0]; + int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0]; + + start = std::max(start, static_cast(0)); + end = std::max(end, static_cast(0)); + end = std::min(end, in_size); + + PADDLE_ENFORCE_GT(end, start, + platform::errors::InvalidArgument( + "Attr(ends) should be greater than attr(starts) in " + "slice op. But received end = %d, start = %d.", + ends[0], starts[0])); + int64_t out_size = end - start; + + if (out_is_array) { + auto out_array = ctx.Output("Out"); + out_array->resize(out_size); + + for (int i = 0; i < out_size; ++i) { + auto* out_tensor = &out_array->at(i); + auto in_tensor = in_array->at(i + start); + out_tensor->set_lod(in_tensor.lod()); + if (in_tensor.memory_size() > 0) { + TensorCopy(in_tensor, ctx.GetPlace(), out_tensor); + } else { + VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so " + "nothing has been written to output array[" + << i << "]."; + } + } + } else { + auto out = ctx.Output("Out"); + auto in_tensor = in_array->at(start); + TensorCopy(in_tensor, ctx.GetPlace(), out); + } +} template class SliceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - const framework::Variable* input_var = ctx.InputVar("Input"); - bool is_tensor_array = input_var->IsType(); - int rank = is_tensor_array - ? 1 - : ctx.Input("Input")->dims().size(); + const Variable* input_var = ctx.InputVar("Input"); + bool is_tensor_array = input_var->IsType(); + int rank = is_tensor_array ? 1 : ctx.Input("Input")->dims().size(); switch (rank) { case 1: @@ -54,53 +100,45 @@ class SliceKernel : public framework::OpKernel { case 6: SliceCompute<6>(ctx); break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", rank)); } } private: template - void SliceCompute(const framework::ExecutionContext& context) const { - auto& place = - *context.template device_context().eigen_device(); - const framework::Variable* input_var = context.InputVar("Input"); - framework::Variable* out_var = context.OutputVar("Out"); - bool input_is_tensor_array = input_var->IsType(); - bool out_is_tensor_array = out_var->IsType(); - - auto axes = context.Attr>("axes"); - - auto starts_int = context.Attr>("starts"); + void SliceCompute(const framework::ExecutionContext& ctx) const { + const Variable* input_var = ctx.InputVar("Input"); + Variable* out_var = ctx.OutputVar("Out"); + bool input_is_array = input_var->IsType(); + bool out_is_array = out_var->IsType(); + + auto axes_int = ctx.Attr>("axes"); + auto starts_int = ctx.Attr>("starts"); + auto ends_int = ctx.Attr>("ends"); + std::vector axes(axes_int.begin(), axes_int.end()); std::vector starts(starts_int.begin(), starts_int.end()); - auto ends_int = context.Attr>("ends"); std::vector ends(ends_int.begin(), ends_int.end()); - auto decrease_axis = context.Attr>("decrease_axis"); - auto infer_flags = context.Attr>("infer_flags"); - auto list_new_ends_tensor = - context.MultiInput("EndsTensorList"); - auto list_new_starts_tensor = - context.MultiInput("StartsTensorList"); - - bool need_infer = false; - if (context.HasInput("StartsTensor") || context.HasInput("EndsTensor")) { - need_infer = true; - } - if (list_new_starts_tensor.size() > 0 || list_new_ends_tensor.size() > 0) { - need_infer = true; + + auto decrease_axis = ctx.Attr>("decrease_axis"); + auto infer_flags = ctx.Attr>("infer_flags"); + + // Step 1: Get the accurate attribute value of starts and ends + auto starts_tensor_list = ctx.MultiInput("StartsTensorList"); + if (ctx.HasInput("StartsTensor")) { + starts = GetDataFromTensor(ctx.Input("StartsTensor")); + } else if (starts_tensor_list.size() > 0) { + starts = GetDataFromTensorList(starts_tensor_list); } - if (need_infer) { - if (context.HasInput("StartsTensor")) { - auto* starts_tensor = context.Input("StartsTensor"); - starts = GetDataFromTensor(starts_tensor); - } else if (list_new_starts_tensor.size() > 0) { - starts = GetDataFromTensorList(list_new_starts_tensor); - } - if (context.HasInput("EndsTensor")) { - auto* ends_tensor = context.Input("EndsTensor"); - ends = GetDataFromTensor(ends_tensor); - } else if (list_new_ends_tensor.size() > 0) { - ends = GetDataFromTensorList(list_new_ends_tensor); - } + + auto ends_tensor_list = ctx.MultiInput("EndsTensorList"); + if (ctx.HasInput("EndsTensor")) { + ends = GetDataFromTensor(ctx.Input("EndsTensor")); + } else if (ends_tensor_list.size() > 0) { + ends = GetDataFromTensorList(ends_tensor_list); } + PADDLE_ENFORCE_EQ( starts.size(), axes.size(), platform::errors::InvalidArgument( @@ -109,175 +147,74 @@ class SliceKernel : public framework::OpKernel { ends.size(), axes.size(), platform::errors::InvalidArgument( "The size of ends must be equal to the size of axes.")); - if (input_is_tensor_array) { - auto in_array = context.Input("Input"); - // If the input is LoDTensorArray, the rank of input is 1. - int64_t in_size = in_array->size(); - int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0]; - int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0]; - - start = std::max(start, static_cast(0)); - end = std::max(end, static_cast(0)); - end = std::min(end, in_size); - - PADDLE_ENFORCE_GT(end, start, - platform::errors::InvalidArgument( - "Attr(ends) should be greater than attr(starts) in " - "slice op. But received end = %d, start = %d.", - ends[0], starts[0])); - int64_t out_size = end - start; - - if (out_is_tensor_array) { - auto out_array = context.Output("Out"); - out_array->resize(out_size); - - for (int i = 0; i < out_size; ++i) { - auto* out_tensor = &out_array->at(i); - auto in_tensor = in_array->at(i + start); - out_tensor->set_lod(in_tensor.lod()); - if (in_tensor.memory_size() > 0) { - TensorCopy(in_tensor, context.GetPlace(), out_tensor); - } else { - VLOG(10) - << "WARNING: The input tensor 'x_tensor' holds no memory, so " - "nothing has been written to output array[" - << i << "]."; - } - } - } else { - auto out = context.Output("Out"); - auto in_tensor = in_array->at(start); - TensorCopy(in_tensor, context.GetPlace(), out); - } + // Step 2: Compute output + if (input_is_array) { + DealTensorArray(ctx, starts, ends, out_is_array); return; - } + } else { + auto in = ctx.Input("Input"); + auto out = ctx.Output("Out"); - auto in = context.Input("Input"); - auto out = context.Output("Out"); + auto in_dims = in->dims(); + auto out_dims = out->dims(); + auto slice_dims = out_dims; - auto out_dims = out->dims(); - auto in_dims = in->dims(); - if (need_infer) { - out_dims = in_dims; - int64_t dim_value, start, end; + // 2.1 Infer output dims for (size_t i = 0; i < axes.size(); ++i) { - dim_value = out_dims[axes[i]]; - if (dim_value > 0) { - // when end = start+1 and start == -1 - if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) { - auto ret = - std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); - if (ret != decrease_axis.end()) { - ends[i] = 10000000; - } - } - - start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i]; - end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i]; - start = std::max(start, static_cast(0)); - end = std::max(end, static_cast(0)); - end = std::min(end, dim_value); - PADDLE_ENFORCE_GT( - end, start, - platform::errors::InvalidArgument( - "Attr(ends) should be greater than attr(starts) in " - "slice op. But received end = %d, start = %d.", - ends[i], starts[i])); - out_dims[axes[i]] = end - start; - } - } - out->Resize(out_dims); - // generate new shape - if (decrease_axis.size() > 0) { - std::vector new_out_shape; - for (size_t i = 0; i < decrease_axis.size(); ++i) { - PADDLE_ENFORCE_EQ( - out_dims[decrease_axis[i]], 1, - platform::errors::InvalidArgument("decrease dim should be 1")); - out_dims[decrease_axis[i]] = 0; - } - - for (int i = 0; i < out_dims.size(); ++i) { - if (out_dims[i] != 0) { - new_out_shape.push_back(out_dims[i]); + // when start == -1 && end == start+1 + if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) { + auto ret = + std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); + if (ret != decrease_axis.end()) { + ends[i] = in_dims[axes[i]]; } } - if (new_out_shape.size() == 0) { - new_out_shape.push_back(1); - } - - out_dims = framework::make_ddim(new_out_shape); } - } - // resize out_dims - if (decrease_axis.size() > 0) { - if (decrease_axis.size() == (size_t)in_dims.size()) { - std::vector vec_origin_out_shape(decrease_axis.size(), 1); - out->Resize(framework::make_ddim(vec_origin_out_shape)); - } else { - std::vector vec_origin_out_shape( - out_dims.size() + decrease_axis.size(), -1); + CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends); + slice_dims = + GetSliceDims(in_dims, axes, starts, ends, nullptr, nullptr); + out_dims = GetDecreasedDims(slice_dims, decrease_axis); - for (size_t i = 0; i < decrease_axis.size(); ++i) { - vec_origin_out_shape[decrease_axis[i]] = 1; - } + // 2.2 Get output + auto offsets = Eigen::DSizes(); + auto extents = Eigen::DSizes(); - int index = 0; - for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) { - if (vec_origin_out_shape[i] == -1) { - vec_origin_out_shape[i] = out_dims[index]; - ++index; - } - } - - out->Resize(framework::make_ddim(vec_origin_out_shape)); + for (size_t i = 0; i < D; ++i) { + offsets[i] = 0; + extents[i] = slice_dims[i]; } - } - - out->mutable_data(context.GetPlace()); - - auto new_out_dims = out->dims(); - auto offsets = Eigen::DSizes(); - auto extents = Eigen::DSizes(); - for (size_t i = 0; i < D; ++i) { - offsets[i] = 0; - extents[i] = new_out_dims[i]; - } - int64_t start; - for (size_t i = 0; i < axes.size(); ++i) { - start = starts[i]; - if (start < 0) { - start = (start + in_dims[axes[i]]); + for (size_t i = 0; i < axes.size(); ++i) { + offsets[axes[i]] = starts[i]; } - start = std::max(start, static_cast(0)); - offsets[axes[i]] = start; - } - auto in_t = - framework::EigenTensor::From( - *in); - auto out_t = - framework::EigenTensor::From( - *out, new_out_dims); - if (in->numel() <= Eigen::NumTraits::highest()) { - // similar to tf.slice: - // if element number less than INT_MAX, change the type of index to int - Eigen::DSizes offsets_32bit, extents_32bit; - for (size_t i = 0; i < D; i++) { - offsets_32bit[i] = offsets[i]; - extents_32bit[i] = extents[i]; + out->Resize(slice_dims); + out->mutable_data(ctx.GetPlace()); + + auto in_t = framework::EigenTensor::From(*in, in_dims); + auto out_t = framework::EigenTensor::From(*out, slice_dims); + auto& eigen_place = + *ctx.template device_context().eigen_device(); + + if (in->numel() <= Eigen::NumTraits::highest()) { + // similar to tf.slice: + // if element number less than INT_MAX, change the type of index to int + Eigen::DSizes offsets_32bit, extents_32bit; + for (size_t i = 0; i < D; i++) { + offsets_32bit[i] = offsets[i]; + extents_32bit[i] = extents[i]; + } + EigenSlice, T, D>::Eval( + eigen_place, framework::To32BitIndex(out_t), + framework::To32BitIndex(in_t), offsets_32bit, extents_32bit); + } else { + EigenSlice, T, D>::Eval( + eigen_place, out_t, in_t, offsets, extents); } - EigenSlice, T, D>::Eval( - place, framework::To32BitIndex(out_t), framework::To32BitIndex(in_t), - offsets_32bit, extents_32bit); - } else { - EigenSlice, T, D>::Eval(place, out_t, in_t, - offsets, extents); - } - out->Resize(out_dims); + out->Resize(out_dims); + } } }; @@ -285,11 +222,9 @@ template class SliceGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - const framework::Variable* input_var = ctx.InputVar("Input"); - bool is_tensor_array = input_var->IsType(); - size_t rank = is_tensor_array - ? 1 - : ctx.Input("Input")->dims().size(); + const Variable* input_var = ctx.InputVar("Input"); + bool is_array = input_var->IsType(); + size_t rank = is_array ? 1 : ctx.Input("Input")->dims().size(); switch (rank) { case 1: @@ -310,53 +245,48 @@ class SliceGradKernel : public framework::OpKernel { case 6: SliceCompute<6>(ctx); break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", rank)); } } private: template - void SliceCompute(const framework::ExecutionContext& context) const { - auto axes = context.Attr>("axes"); - - auto starts_int = context.Attr>("starts"); + void SliceCompute(const framework::ExecutionContext& ctx) const { + auto axes = ctx.Attr>("axes"); + auto starts_int = ctx.Attr>("starts"); + auto ends_int = ctx.Attr>("ends"); std::vector starts(starts_int.begin(), starts_int.end()); - - auto ends_int = context.Attr>("ends"); std::vector ends(ends_int.begin(), ends_int.end()); - auto list_new_ends_tensor = - context.MultiInput("EndsTensorList"); - auto list_new_starts_tensor = - context.MultiInput("StartsTensorList"); - - if (list_new_starts_tensor.size() > 0) { - starts = GetDataFromTensorList(list_new_starts_tensor); - } else if (context.HasInput("StartsTensor")) { - auto* starts_tensor = context.Input("StartsTensor"); - starts = GetDataFromTensor(starts_tensor); + // Get the accurate attribute value of starts and ends + auto starts_tensor_list = ctx.MultiInput("StartsTensorList"); + if (ctx.HasInput("StartsTensor")) { + starts = GetDataFromTensor(ctx.Input("StartsTensor")); + } else if (starts_tensor_list.size() > 0) { + starts = GetDataFromTensorList(starts_tensor_list); } - if (list_new_ends_tensor.size() > 0) { - ends = GetDataFromTensorList(list_new_ends_tensor); - } else if (context.HasInput("EndsTensor")) { - auto* ends_tensor = context.Input("EndsTensor"); - ends = GetDataFromTensor(ends_tensor); + auto ends_tensor_list = ctx.MultiInput("EndsTensorList"); + if (ctx.HasInput("EndsTensor")) { + ends = GetDataFromTensor(ctx.Input("EndsTensor")); + } else if (ends_tensor_list.size() > 0) { + ends = GetDataFromTensorList(ends_tensor_list); } - framework::Variable* d_input_var = - context.OutputVar(framework::GradVarName("Input")); - const framework::Variable* d_out_var = - context.InputVar(framework::GradVarName("Out")); - bool d_input_is_tensor_array = - d_input_var->IsType(); - bool d_out_is_tensor_array = d_out_var->IsType(); - - if (d_input_is_tensor_array) { - auto* input_array = context.Input("Input"); - auto* d_input_array = context.Output( - framework::GradVarName("Input")); + + Variable* d_input_var = ctx.OutputVar(framework::GradVarName("Input")); + const Variable* d_out_var = ctx.InputVar(framework::GradVarName("Out")); + bool d_input_is_array = d_input_var->IsType(); + bool d_out_is_array = d_out_var->IsType(); + + if (d_input_is_array) { + auto* input_array = ctx.Input("Input"); + auto* d_in_arr = + ctx.Output(framework::GradVarName("Input")); int64_t d_in_size = input_array->size(); - d_input_array->resize(d_in_size); + d_in_arr->resize(d_in_size); // If the input is LoDTensorArray, the rank of input is 1. // So only use the 0th element of starts. int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0]; @@ -364,68 +294,60 @@ class SliceGradKernel : public framework::OpKernel { // set zero platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& dev_ctx = *pool.Get(context.GetPlace()); - T value = T(0); + auto& dev_ctx = *pool.Get(ctx.GetPlace()); math::SetConstant functor; for (int i = 0; i < d_in_size; ++i) { auto dim = input_array->at(i).dims(); - d_input_array->at(i).Resize(dim); - d_input_array->at(i).mutable_data(context.GetPlace()); + d_in_arr->at(i).Resize(dim); + d_in_arr->at(i).mutable_data(ctx.GetPlace()); functor(reinterpret_cast(dev_ctx), - &d_input_array->at(i), static_cast(value)); + &d_in_arr->at(i), static_cast(0)); } - if (d_out_is_tensor_array) { - auto* d_out_array = context.Input( - framework::GradVarName("Out")); - int d_out_size = d_out_array->size(); + if (d_out_is_array) { + auto* d_out_arr = + ctx.Input(framework::GradVarName("Out")); + int d_out_size = d_out_arr->size(); for (int i = 0; i < d_out_size; ++i) { - TensorCopy(d_out_array->at(i), context.GetPlace(), - &(d_input_array->at(start + i))); + TensorCopy(d_out_arr->at(i), ctx.GetPlace(), + &(d_in_arr->at(start + i))); } - } else { - auto* d_out = - context.Input(framework::GradVarName("Out")); - TensorCopy(*d_out, context.GetPlace(), &(d_input_array->at(start))); + auto* d_out = ctx.Input(framework::GradVarName("Out")); + TensorCopy(*d_out, ctx.GetPlace(), &(d_in_arr->at(start))); } return; } - auto* d_out = - context.Input(framework::GradVarName("Out")); - - auto* d_input = - context.Output(framework::GradVarName("Input")); - - d_input->mutable_data(context.GetPlace()); + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* d_input = ctx.Output(framework::GradVarName("Input")); + d_input->mutable_data(ctx.GetPlace()); auto out_dims = d_out->dims(); auto in_dims = d_input->dims(); - auto decrease_axis = context.Attr>("decrease_axis"); - if (decrease_axis.size() > 0) { - if (decrease_axis.size() == (size_t)in_dims.size()) { + auto decrease_axis = ctx.Attr>("decrease_axis"); + auto decrease_size = decrease_axis.size(); + if (decrease_size > 0) { + if (decrease_size == (size_t)in_dims.size()) { // all dims decrease - std::vector vec_origin_out_shape(decrease_axis.size(), 1); - out_dims = framework::make_ddim(vec_origin_out_shape); + std::vector origin_out_shape(decrease_size, 1); + out_dims = framework::make_ddim(std::vector(decrease_size, 1)); } else { - std::vector vec_origin_out_shape( - out_dims.size() + decrease_axis.size(), -1); - - for (size_t i = 0; i < decrease_axis.size(); ++i) { - vec_origin_out_shape[decrease_axis[i]] = 1; + std::vector origin_out_shape(out_dims.size() + decrease_size, -1); + for (size_t i = 0; i < decrease_size; ++i) { + origin_out_shape[decrease_axis[i]] = 1; } int index = 0; - for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) { - if (vec_origin_out_shape[i] == -1) { - vec_origin_out_shape[i] = out_dims[index]; + for (size_t i = 0; i < origin_out_shape.size(); ++i) { + if (origin_out_shape[i] == -1) { + origin_out_shape[i] = out_dims[index]; ++index; } } - out_dims = framework::make_ddim(vec_origin_out_shape); + out_dims = framework::make_ddim(origin_out_shape); } } @@ -435,28 +357,26 @@ class SliceGradKernel : public framework::OpKernel { offsets[i] = 0; extents[i] = out_dims[i]; } - int64_t start; + for (size_t i = 0; i < axes.size(); ++i) { - start = starts[i]; - if (start < 0) { - start = (start + in_dims[axes[i]]); - } + int axis = axes[i]; + int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i]; start = std::max(start, static_cast(0)); - offsets[axes[i]] = start; + offsets[axis] = start; } + Eigen::array, D> paddings; for (size_t i = 0; i < paddings.size(); ++i) { paddings[i].first = offsets[i]; paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i]; } - EigenPaddingCompute(context, d_input, in_dims, d_out, out_dims, paddings); + EigenPaddingCompute(ctx, d_input, in_dims, d_out, out_dims, paddings); } template void EigenPaddingCompute( - const framework::ExecutionContext& context, framework::Tensor* d_input, - const framework::DDim& in_dims, const framework::Tensor* d_out, - const framework::DDim& out_dims, + const framework::ExecutionContext& context, Tensor* d_input, + const DDim& in_dims, const Tensor* d_out, const DDim& out_dims, const Eigen::array, D>& paddings) const { if (D <= 3) { // if dimension less than 3, cannot reduce dimension @@ -512,10 +432,8 @@ class SliceGradKernel : public framework::OpKernel { out_tore_shape[1] = out_dims[pad_dim]; // convert array from std::vector to DDim - framework::DDim reshaped_in_dims = - framework::make_ddim(in_tore_shape); - framework::DDim reshaped_out_dims = - framework::make_ddim(out_tore_shape); + DDim reshaped_in_dims = framework::make_ddim(in_tore_shape); + DDim reshaped_out_dims = framework::make_ddim(out_tore_shape); // after reshape: the first dimension do not need padding, // set padding[0] zero @@ -543,10 +461,8 @@ class SliceGradKernel : public framework::OpKernel { } // convert array from std::vector to DDim - framework::DDim reshaped_in_dims = - framework::make_ddim(in_tore_shape); - framework::DDim reshaped_out_dims = - framework::make_ddim(out_tore_shape); + DDim reshaped_in_dims = framework::make_ddim(in_tore_shape); + DDim reshaped_out_dims = framework::make_ddim(out_tore_shape); // after reshape: // the first dimension is the previous padding dimension @@ -579,10 +495,8 @@ class SliceGradKernel : public framework::OpKernel { } // convert array from std::vector to DDim - framework::DDim reshaped_in_dims = - framework::make_ddim(in_tore_shape); - framework::DDim reshaped_out_dims = - framework::make_ddim(out_tore_shape); + DDim reshaped_in_dims = framework::make_ddim(in_tore_shape); + DDim reshaped_out_dims = framework::make_ddim(out_tore_shape); // after reshape: // the first dimension do not need padding, set padding[0] zero @@ -606,9 +520,8 @@ class SliceGradKernel : public framework::OpKernel { template void LaunchEigenPadding( - const framework::ExecutionContext& context, framework::Tensor* d_input, - const framework::DDim& in_dims, const framework::Tensor* d_out, - const framework::DDim& out_dims, + const framework::ExecutionContext& context, Tensor* d_input, + const DDim& in_dims, const Tensor* d_out, const DDim& out_dims, const Eigen::array, D>& paddings) const { auto& place = *context.template device_context().eigen_device(); diff --git a/paddle/fluid/operators/slice_utils.h b/paddle/fluid/operators/slice_utils.h new file mode 100644 index 0000000000..60782a9a92 --- /dev/null +++ b/paddle/fluid/operators/slice_utils.h @@ -0,0 +1,143 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, + const std::vector& axes, + std::vector* starts, + std::vector* ends, + std::vector* steps = nullptr, + std::vector* infer_flags = nullptr) { + for (size_t i = 0; i < axes.size(); ++i) { + T axis = axes[i]; + T dim_value = in_dims[axis]; + + if (dim_value > 0) { + if (infer_flags != nullptr && (*infer_flags)[i] == -1) { + continue; + } + T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; + start = std::max(start, static_cast(0)); + + T end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; + end = std::min(end, dim_value); + + T step = steps == nullptr ? 1 : (*steps)[i]; + PADDLE_ENFORCE_NE( + step, 0, platform::errors::InvalidArgument( + "Step should not be 0, but received step = %d.", step)); + + if (step > 0) { + start = std::min(start, dim_value); + end = std::max(end, static_cast(0)); + PADDLE_ENFORCE_GT( + end, start, + platform::errors::InvalidArgument( + "When step > 0, end should be greater than start, but " + "received end = %d, start = %d.", + end, start)); + } else { + // NOTE(liym27): When step < 0, start should less and equal to + // dim_value-1 + // "end is -1" means contain the 0-th element of this axis. + start = std::min(start, dim_value - 1); + end = std::max(end, static_cast(-1)); + PADDLE_ENFORCE_GT( + start, end, + platform::errors::InvalidArgument( + "When step < 0, start should be greater than end, but " + "received start = %d, end = %d.", + start, end)); + } + + (*starts)[i] = start; + (*ends)[i] = end; + } + } +} + +template +inline framework::DDim GetSliceDims(const framework::DDim in_dims, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + std::vector* steps = nullptr, + std::vector* infer_flags = nullptr) { + framework::DDim slice_dims(in_dims); + + for (size_t i = 0; i < axes.size(); ++i) { + T axis = axes[i]; + if (infer_flags != nullptr && (*infer_flags)[i] == -1) { + slice_dims[axis] = -1; + continue; + } + + T start = starts[i]; + T end = ends[i]; + T step = steps == nullptr ? 1 : (*steps)[i]; + + if (step > 0) { + slice_dims[axis] = (end - start + step - 1) / step; + } else { + slice_dims[axis] = (end - start + step + 1) / step; + } + } + return slice_dims; +} + +template +inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, + const std::vector& decrease_axes, + std::vector* infer_flags = nullptr) { + framework::DDim decreased_dims(slice_dims); + if (decrease_axes.size() > 0) { + for (size_t i = 0; i < decrease_axes.size(); ++i) { + T axis = decrease_axes[i]; + if (infer_flags && (*infer_flags)[i] != -1) { + PADDLE_ENFORCE_EQ( + decreased_dims[axis], 1, + platform::errors::InvalidArgument("decrease dim should be 1")); + } + decreased_dims[axis] = 0; + } + + std::vector new_shape; + for (int i = 0; i < decreased_dims.size(); ++i) { + if (decreased_dims[i] != 0) { + new_shape.push_back(decreased_dims[i]); + } + } + + // NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and + // uses [1] instead. + if (new_shape.size() == 0) { + new_shape.push_back(1); + } + + decreased_dims = framework::make_ddim(new_shape); + } + return decreased_dims; +} + +} // namespace operators +} // namespace paddle -- GitLab