Unverified commit 4cf01462 authored by liym27, committed by GitHub

Polish code for slice and set_value op (#32947)

Parent a039fd7b
@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"
@@ -59,106 +60,6 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) {
return value_name;
}
inline void CheckAndUpdateSlice(const framework::DDim in_dims,
const std::vector<int64_t> axes,
std::vector<int64_t>* starts,
std::vector<int64_t>* ends,
std::vector<int64_t>* steps) {
for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i];
int64_t dim_value = in_dims[axis];
int64_t start =
(*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
int64_t end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
start = std::max(start, static_cast<int64_t>(0));
end = std::min(end, dim_value);
int64_t step = (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<int64_t>(0));
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start));
} else {
// NOTE(liym27): When step < 0, start should be less than or equal to
// dim_value - 1, and "end is -1" means the slice contains the 0-th element
// of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<int64_t>(-1));
PADDLE_ENFORCE_GT(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
}
(*starts)[i] = start;
(*ends)[i] = end;
}
}
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& steps) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i];
int64_t start = starts[i];
int64_t end = ends[i];
int64_t step = steps[i];
if (step > 0) {
slice_dims[axis] = (end - start + step - 1) / step;
} else {
slice_dims[axis] = (end - start + step + 1) / step;
}
}
return slice_dims;
}
inline framework::DDim GetDecreasedDims(
const framework::DDim slice_dims,
const std::vector<int64_t>& decrease_axes) {
// Get dims after decreasing axes.
framework::DDim decreased_dims(slice_dims);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
int64_t axis = decrease_axes[i];
PADDLE_ENFORCE_EQ(
decreased_dims[axis], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
decreased_dims[axis] = 0;
}
std::vector<int64_t> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) {
new_shape.push_back(decreased_dims[i]);
}
}
// NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
// uses [1] instead.
if (new_shape.size() == 0) {
new_shape.push_back(1);
}
decreased_dims = framework::make_ddim(new_shape);
}
return decreased_dims;
}
template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
public:
@@ -225,8 +126,8 @@ class SetValueKernel : public framework::OpKernel<T> {
}
auto in_dims = in->dims();
CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps);
CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
auto place = ctx.GetPlace();
......
@@ -28,13 +28,10 @@ class SliceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
platform::errors::InvalidArgument(
"Input (Input) of slice op should not be null."));
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "slice");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "slice");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output (Out) of slice op should not be null."));
// Case 1: Special treatment when input is a tensor array.
auto x_var_type = ctx->GetInputsVarType("Input")[0];
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
@@ -57,6 +54,8 @@ class SliceOp : public framework::OperatorWithKernel {
return;
}
}
// Case 2: input is a tensor.
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(in_dims.size(), 7,
platform::errors::InvalidArgument(
@@ -65,101 +64,54 @@ class SliceOp : public framework::OperatorWithKernel {
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
auto starts_size = starts.size();
auto ends_size = ends.size();
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
if (infer_flags.empty()) {
// Initialize infer_flags with 1.
// To be compatible with other op tests in which infer_flags is not set.
infer_flags = std::vector<int>(axes.size(), 1);
}
// 2.1 Check attrs.
auto starts_size = starts.size();
auto ends_size = ends.size();
if (ctx->HasInputs("StartsTensorList")) {
auto StartsTensorList = ctx->Inputs("StartsTensorList");
PADDLE_ENFORCE_GT(StartsTensorList.size(), 0,
starts_size = ctx->Inputs("StartsTensorList").size();
PADDLE_ENFORCE_GT(starts_size, 0,
platform::errors::InvalidArgument(
"StartsTensorList size can't be zero"));
starts_size = StartsTensorList.size();
}
if (ctx->HasInputs("EndsTensorList")) {
auto EndsTensorList = ctx->Inputs("EndsTensorList");
PADDLE_ENFORCE_GT(EndsTensorList.size(), 0,
platform::errors::InvalidArgument(
ends_size = ctx->Inputs("EndsTensorList").size();
PADDLE_ENFORCE_GT(ends_size, 0, platform::errors::InvalidArgument(
"EndsTensorList size can't be zero"));
ends_size = EndsTensorList.size();
}
if (ctx->HasInput("StartsTensor") == false) {
if (!ctx->HasInput("StartsTensor")) {
PADDLE_ENFORCE_EQ(
starts_size, axes.size(),
platform::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
}
if (ctx->HasInput("EndsTensor") == false) {
if (!ctx->HasInput("EndsTensor")) {
PADDLE_ENFORCE_EQ(
ends_size, axes.size(),
platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
}
int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
PADDLE_ENFORCE_LT(static_cast<int>(axes[i]), in_dims.size(),
platform::errors::InvalidArgument(
"The index of dimension in axes must be less "
"than the size of input shape."));
if (infer_flags[i] == -1) {
out_dims[axes[i]] = -1;
} else {
// infer out_dim shape
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
PADDLE_ENFORCE_LE(start, dim_value,
platform::errors::InvalidArgument(
"start should be less than or equal to the "
"dimension value, but received "
"start = %d, shape[%d] = %d.",
starts[i], axes[i], out_dims[axes[i]]));
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"end should greater than start, but received "
"end = %d, start = %d.",
ends[i], starts[i]));
out_dims[axes[i]] = end - start;
}
}
}
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
}
out_dims[decrease_axis[i]] = 0;
}
CheckAndUpdateSliceAttrs<int>(in_dims, axes, &starts, &ends, nullptr,
&infer_flags);
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
auto slice_dims =
GetSliceDims<int>(in_dims, axes, starts, ends, nullptr, &infer_flags);
if (ctx->IsRuntime()) {
out_dims = GetDecreasedDims<int>(slice_dims, decrease_axis, &infer_flags);
} else {
out_dims = GetDecreasedDims<int>(slice_dims, decrease_axis, nullptr);
}
out_dims = framework::make_ddim(new_out_shape);
}
ctx->SetOutputDim("Out", out_dims);
if (axes[0] != 0) {
ctx->ShareLoD("Input", /*->*/ "Out");
@@ -185,6 +137,7 @@ class SliceOp : public framework::OperatorWithKernel {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
......
@@ -19,21 +19,67 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/utils.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using Variable = framework::Variable;
using LoDTensorArray = framework::LoDTensorArray;
using DDim = framework::DDim;
inline void DealTensorArray(const framework::ExecutionContext& ctx,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
bool out_is_array) {
auto in_array = ctx.Input<LoDTensorArray>("Input");
// If the input is LoDTensorArray, the rank of input is 1.
int64_t in_size = in_array->size();
int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, in_size);
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received end = %d, start = %d.",
ends[0], starts[0]));
int64_t out_size = end - start;
if (out_is_array) {
auto out_array = ctx.Output<LoDTensorArray>("Out");
out_array->resize(out_size);
for (int i = 0; i < out_size; ++i) {
auto* out_tensor = &out_array->at(i);
auto in_tensor = in_array->at(i + start);
out_tensor->set_lod(in_tensor.lod());
if (in_tensor.memory_size() > 0) {
TensorCopy(in_tensor, ctx.GetPlace(), out_tensor);
} else {
VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
<< i << "].";
}
}
} else {
auto out = ctx.Output<Tensor>("Out");
auto in_tensor = in_array->at(start);
TensorCopy(in_tensor, ctx.GetPlace(), out);
}
}
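// Worked example (illustrative values, not from the diff): with in_size = 4,
// starts = {-3} and ends = {4}, start normalizes to 1 and end to 4, so
// out_size = 3 and array elements 1..3 are copied to the output.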
template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<framework::LoDTensorArray>();
int rank = is_tensor_array
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<LoDTensorArray>();
int rank = is_tensor_array ? 1 : ctx.Input<Tensor>("Input")->dims().size();
switch (rank) {
case 1:
@@ -54,53 +100,45 @@ class SliceKernel : public framework::OpKernel<T> {
case 6:
SliceCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
private:
template <size_t D>
void SliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
const framework::Variable* input_var = context.InputVar("Input");
framework::Variable* out_var = context.OutputVar("Out");
bool input_is_tensor_array = input_var->IsType<framework::LoDTensorArray>();
bool out_is_tensor_array = out_var->IsType<framework::LoDTensorArray>();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts_int = context.Attr<std::vector<int>>("starts");
void SliceCompute(const framework::ExecutionContext& ctx) const {
const Variable* input_var = ctx.InputVar("Input");
Variable* out_var = ctx.OutputVar("Out");
bool input_is_array = input_var->IsType<LoDTensorArray>();
bool out_is_array = out_var->IsType<LoDTensorArray>();
auto axes_int = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto ends_int = ctx.Attr<std::vector<int>>("ends");
std::vector<int64_t> axes(axes_int.begin(), axes_int.end());
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
auto ends_int = context.Attr<std::vector<int>>("ends");
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
bool need_infer = false;
if (context.HasInput("StartsTensor") || context.HasInput("EndsTensor")) {
need_infer = true;
}
if (list_new_starts_tensor.size() > 0 || list_new_ends_tensor.size() > 0) {
need_infer = true;
}
if (need_infer) {
if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
} else if (list_new_starts_tensor.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
}
if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
} else if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto infer_flags = ctx.Attr<std::vector<int>>("infer_flags");
// Step 1: Get the accurate attribute value of starts and ends
auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
PADDLE_ENFORCE_EQ(
starts.size(), axes.size(),
platform::errors::InvalidArgument(
@@ -109,157 +147,55 @@ class SliceKernel : public framework::OpKernel<T> {
ends.size(), axes.size(),
platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
if (input_is_tensor_array) {
auto in_array = context.Input<framework::LoDTensorArray>("Input");
// If the input is LoDTensorArray, the rank of input is 1.
int64_t in_size = in_array->size();
int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, in_size);
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received end = %d, start = %d.",
ends[0], starts[0]));
int64_t out_size = end - start;
if (out_is_tensor_array) {
auto out_array = context.Output<framework::LoDTensorArray>("Out");
out_array->resize(out_size);
for (int i = 0; i < out_size; ++i) {
auto* out_tensor = &out_array->at(i);
auto in_tensor = in_array->at(i + start);
out_tensor->set_lod(in_tensor.lod());
if (in_tensor.memory_size() > 0) {
TensorCopy(in_tensor, context.GetPlace(), out_tensor);
} else {
VLOG(10)
<< "WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
<< i << "].";
}
}
} else {
auto out = context.Output<framework::Tensor>("Out");
auto in_tensor = in_array->at(start);
TensorCopy(in_tensor, context.GetPlace(), out);
}
// Step 2: Compute output
if (input_is_array) {
DealTensorArray(ctx, starts, ends, out_is_array);
return;
}
auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out");
} else {
auto in = ctx.Input<Tensor>("Input");
auto out = ctx.Output<Tensor>("Out");
auto out_dims = out->dims();
auto in_dims = in->dims();
if (need_infer) {
out_dims = in_dims;
int64_t dim_value, start, end;
auto out_dims = out->dims();
auto slice_dims = out_dims;
// 2.1 Infer output dims
for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
// when end = start+1 and start == -1
// when start == -1 && end == start+1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret =
std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = 10000000;
}
}
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, dim_value);
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received end = %d, start = %d.",
ends[i], starts[i]));
out_dims[axes[i]] = end - start;
}
}
out->Resize(out_dims);
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
out_dims[decrease_axis[i]] = 0;
ends[i] = in_dims[axes[i]];
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = framework::make_ddim(new_out_shape);
}
}
// resize out_dims
if (decrease_axis.size() > 0) {
if (decrease_axis.size() == (size_t)in_dims.size()) {
std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1);
out->Resize(framework::make_ddim(vec_origin_out_shape));
} else {
std::vector<int> vec_origin_out_shape(
out_dims.size() + decrease_axis.size(), -1);
for (size_t i = 0; i < decrease_axis.size(); ++i) {
vec_origin_out_shape[decrease_axis[i]] = 1;
}
CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
slice_dims =
GetSliceDims<int64_t>(in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = GetDecreasedDims(slice_dims, decrease_axis);
int index = 0;
for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) {
if (vec_origin_out_shape[i] == -1) {
vec_origin_out_shape[i] = out_dims[index];
++index;
}
}
out->Resize(framework::make_ddim(vec_origin_out_shape));
}
}
out->mutable_data<T>(context.GetPlace());
auto new_out_dims = out->dims();
// 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = new_out_dims[i];
extents[i] = slice_dims[i];
}
int64_t start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
start = std::max(start, static_cast<int64_t>(0));
offsets[axes[i]] = start;
offsets[axes[i]] = starts[i];
}
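// e.g. (illustrative values) in_dims = [4, 6], axes = {1}, starts = {2},
// ends = {5}: offsets = {0, 2} and extents = {4, 3}, matching slice_dims.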
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, new_out_dims);
out->Resize(slice_dims);
out->mutable_data<T>(ctx.GetPlace());
auto in_t = framework::EigenTensor<T, D>::From(*in, in_dims);
auto out_t = framework::EigenTensor<T, D>::From(*out, slice_dims);
auto& eigen_place =
*ctx.template device_context<DeviceContext>().eigen_device();
if (in->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.slice:
@@ -269,27 +205,26 @@ class SliceKernel : public framework::OpKernel<T> {
offsets_32bit[i] = offsets[i];
extents_32bit[i] = extents[i];
}
EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(
place, framework::To32BitIndex(out_t), framework::To32BitIndex(in_t),
offsets_32bit, extents_32bit);
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, framework::To32BitIndex(out_t),
framework::To32BitIndex(in_t), offsets_32bit, extents_32bit);
} else {
EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(place, out_t, in_t,
offsets, extents);
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, out_t, in_t, offsets, extents);
}
out->Resize(out_dims);
}
}
};
template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<framework::LoDTensorArray>();
size_t rank = is_tensor_array
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
const Variable* input_var = ctx.InputVar("Input");
bool is_array = input_var->IsType<LoDTensorArray>();
size_t rank = is_array ? 1 : ctx.Input<Tensor>("Input")->dims().size();
switch (rank) {
case 1:
@@ -310,53 +245,48 @@ class SliceGradKernel : public framework::OpKernel<T> {
case 6:
SliceCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
private:
template <size_t D>
void SliceCompute(const framework::ExecutionContext& context) const {
auto axes = context.Attr<std::vector<int>>("axes");
auto starts_int = context.Attr<std::vector<int>>("starts");
void SliceCompute(const framework::ExecutionContext& ctx) const {
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto ends_int = ctx.Attr<std::vector<int>>("ends");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
auto ends_int = context.Attr<std::vector<int>>("ends");
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
}
framework::Variable* d_input_var =
context.OutputVar(framework::GradVarName("Input"));
const framework::Variable* d_out_var =
context.InputVar(framework::GradVarName("Out"));
bool d_input_is_tensor_array =
d_input_var->IsType<framework::LoDTensorArray>();
bool d_out_is_tensor_array = d_out_var->IsType<framework::LoDTensorArray>();
if (d_input_is_tensor_array) {
auto* input_array = context.Input<framework::LoDTensorArray>("Input");
auto* d_input_array = context.Output<framework::LoDTensorArray>(
framework::GradVarName("Input"));
// Get the accurate attribute value of starts and ends
auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
Variable* d_input_var = ctx.OutputVar(framework::GradVarName("Input"));
const Variable* d_out_var = ctx.InputVar(framework::GradVarName("Out"));
bool d_input_is_array = d_input_var->IsType<LoDTensorArray>();
bool d_out_is_array = d_out_var->IsType<LoDTensorArray>();
if (d_input_is_array) {
auto* input_array = ctx.Input<LoDTensorArray>("Input");
auto* d_in_arr =
ctx.Output<LoDTensorArray>(framework::GradVarName("Input"));
int64_t d_in_size = input_array->size();
d_input_array->resize(d_in_size);
d_in_arr->resize(d_in_size);
// If the input is LoDTensorArray, the rank of input is 1.
// So only use the 0th element of starts.
int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
@@ -364,68 +294,60 @@ class SliceGradKernel : public framework::OpKernel<T> {
// set zero
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(context.GetPlace());
T value = T(0);
auto& dev_ctx = *pool.Get(ctx.GetPlace());
math::SetConstant<DeviceContext, T> functor;
for (int i = 0; i < d_in_size; ++i) {
auto dim = input_array->at(i).dims();
d_input_array->at(i).Resize(dim);
d_input_array->at(i).mutable_data<T>(context.GetPlace());
d_in_arr->at(i).Resize(dim);
d_in_arr->at(i).mutable_data<T>(ctx.GetPlace());
functor(reinterpret_cast<const DeviceContext&>(dev_ctx),
&d_input_array->at(i), static_cast<T>(value));
&d_in_arr->at(i), static_cast<T>(0));
}
if (d_out_is_tensor_array) {
auto* d_out_array = context.Input<framework::LoDTensorArray>(
framework::GradVarName("Out"));
int d_out_size = d_out_array->size();
if (d_out_is_array) {
auto* d_out_arr =
ctx.Input<LoDTensorArray>(framework::GradVarName("Out"));
int d_out_size = d_out_arr->size();
for (int i = 0; i < d_out_size; ++i) {
TensorCopy(d_out_array->at(i), context.GetPlace(),
&(d_input_array->at(start + i)));
TensorCopy(d_out_arr->at(i), ctx.GetPlace(),
&(d_in_arr->at(start + i)));
}
} else {
auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
TensorCopy(*d_out, context.GetPlace(), &(d_input_array->at(start)));
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
TensorCopy(*d_out, ctx.GetPlace(), &(d_in_arr->at(start)));
}
return;
}
auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_input =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_input->mutable_data<T>(context.GetPlace());
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_input = ctx.Output<Tensor>(framework::GradVarName("Input"));
d_input->mutable_data<T>(ctx.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
if (decrease_axis.size() > 0) {
if (decrease_axis.size() == (size_t)in_dims.size()) {
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto decrease_size = decrease_axis.size();
if (decrease_size > 0) {
if (decrease_size == (size_t)in_dims.size()) {
// all dims decrease
std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1);
out_dims = framework::make_ddim(vec_origin_out_shape);
std::vector<int> origin_out_shape(decrease_size, 1);
out_dims = framework::make_ddim(std::vector<int>(decrease_size, 1));
} else {
std::vector<int> vec_origin_out_shape(
out_dims.size() + decrease_axis.size(), -1);
for (size_t i = 0; i < decrease_axis.size(); ++i) {
vec_origin_out_shape[decrease_axis[i]] = 1;
std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
}
int index = 0;
for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) {
if (vec_origin_out_shape[i] == -1) {
vec_origin_out_shape[i] = out_dims[index];
for (size_t i = 0; i < origin_out_shape.size(); ++i) {
if (origin_out_shape[i] == -1) {
origin_out_shape[i] = out_dims[index];
++index;
}
}
out_dims = framework::make_ddim(vec_origin_out_shape);
out_dims = framework::make_ddim(origin_out_shape);
}
}
@@ -435,28 +357,26 @@ class SliceGradKernel : public framework::OpKernel<T> {
offsets[i] = 0;
extents[i] = out_dims[i];
}
int64_t start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
int axis = axes[i];
int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
start = std::max(start, static_cast<int64_t>(0));
offsets[axes[i]] = start;
offsets[axis] = start;
}
Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
}
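// e.g. (illustrative values) in_dims = [6], out_dims = [3], offsets = {2}:
// paddings[0] = {2, 1}, i.e. the gradient is padded with 2 zeros before and
// 1 zero after along that axis to recover the input shape.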
EigenPaddingCompute(context, d_input, in_dims, d_out, out_dims, paddings);
EigenPaddingCompute(ctx, d_input, in_dims, d_out, out_dims, paddings);
}
template <size_t D>
void EigenPaddingCompute(
const framework::ExecutionContext& context, framework::Tensor* d_input,
const framework::DDim& in_dims, const framework::Tensor* d_out,
const framework::DDim& out_dims,
const framework::ExecutionContext& context, Tensor* d_input,
const DDim& in_dims, const Tensor* d_out, const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const {
if (D <= 3) {
// if dimension less than 3, cannot reduce dimension
@@ -512,10 +432,8 @@ class SliceGradKernel : public framework::OpKernel<T> {
out_tore_shape[1] = out_dims[pad_dim];
// convert array from std::vector to DDim
framework::DDim reshaped_in_dims =
framework::make_ddim(in_tore_shape);
framework::DDim reshaped_out_dims =
framework::make_ddim(out_tore_shape);
DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// after reshape: the first dimension do not need padding,
// set padding[0] zero
@@ -543,10 +461,8 @@ class SliceGradKernel : public framework::OpKernel<T> {
}
// convert array from std::vector to DDim
framework::DDim reshaped_in_dims =
framework::make_ddim(in_tore_shape);
framework::DDim reshaped_out_dims =
framework::make_ddim(out_tore_shape);
DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// after reshape:
// the first dimension is the previous padding dimension
@@ -579,10 +495,8 @@ class SliceGradKernel : public framework::OpKernel<T> {
}
// convert array from std::vector to DDim
framework::DDim reshaped_in_dims =
framework::make_ddim(in_tore_shape);
framework::DDim reshaped_out_dims =
framework::make_ddim(out_tore_shape);
DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// after reshape:
// the first dimension do not need padding, set padding[0] zero
@@ -606,9 +520,8 @@ class SliceGradKernel : public framework::OpKernel<T> {
template <size_t D>
void LaunchEigenPadding(
const framework::ExecutionContext& context, framework::Tensor* d_input,
const framework::DDim& in_dims, const framework::Tensor* d_out,
const framework::DDim& out_dims,
const framework::ExecutionContext& context, Tensor* d_input,
const DDim& in_dims, const Tensor* d_out, const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/fluid/framework/operator.h>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T = int64_t>
inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
const std::vector<T>& axes,
std::vector<T>* starts,
std::vector<T>* ends,
std::vector<int64_t>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
T dim_value = in_dims[axis];
if (dim_value > 0) {
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
start = std::max(start, static_cast<T>(0));
T end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
end = std::min(end, dim_value);
T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start));
} else {
// NOTE(liym27): When step < 0, start should be less than or equal to
// dim_value - 1, and "end is -1" means the slice contains the 0-th
// element of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GT(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
}
(*starts)[i] = start;
(*ends)[i] = end;
}
}
}
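// Worked example (illustrative values): for in_dims = [5], axes = {0},
// starts = {-3}, ends = {5}, steps = {1}: start wraps to -3 + 5 = 2, end is
// clamped to 5, and starts/ends are updated in place to {2}/{5}. With
// starts = {-1}, ends = {-6}, steps = {-1}: start wraps to 4 and end wraps
// and clamps to -1, which selects the whole axis in reverse.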
template <typename T = int64_t>
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<T>& axes,
const std::vector<T>& starts,
const std::vector<T>& ends,
std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
slice_dims[axis] = -1;
continue;
}
T start = starts[i];
T end = ends[i];
T step = steps == nullptr ? 1 : (*steps)[i];
if (step > 0) {
slice_dims[axis] = (end - start + step - 1) / step;
} else {
slice_dims[axis] = (end - start + step + 1) / step;
}
}
return slice_dims;
}
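// Worked example (illustrative values): dim = 5, start = 2, end = 5,
// step = 1 gives (5 - 2 + 1 - 1) / 1 = 3 elements; start = 4, end = -1,
// step = -1 gives (-1 - 4 - 1 + 1) / -1 = 5 elements, i.e.
// ceil(|end - start| / |step|) in both cases.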
template <typename T = int64_t>
inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
framework::DDim decreased_dims(slice_dims);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
T axis = decrease_axes[i];
if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(
decreased_dims[axis], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
}
decreased_dims[axis] = 0;
}
std::vector<T> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) {
new_shape.push_back(decreased_dims[i]);
}
}
// NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
// uses [1] instead.
if (new_shape.size() == 0) {
new_shape.push_back(1);
}
decreased_dims = framework::make_ddim(new_shape);
}
return decreased_dims;
}
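// Worked example (illustrative values): slice_dims = [1, 3, 1] with
// decrease_axes = {0, 2} yields [3]; slice_dims = [1] with
// decrease_axes = {0} would decrease to rank 0, so [1] is returned instead.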
} // namespace operators
} // namespace paddle
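To make the slicing arithmetic above concrete, the following is a minimal,
dependency-free sketch of the same normalization and size computation. It is
not part of this change: SliceDim is a hypothetical standalone helper, and
assert stands in for the PADDLE_ENFORCE_* checks.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring CheckAndUpdateSliceAttrs + GetSliceDims for a
// single axis: normalize (start, end, step) against dim, return output size.
int64_t SliceDim(int64_t dim, int64_t start, int64_t end, int64_t step) {
  assert(step != 0 && "Step should not be 0.");
  // Wrap negative indices, as the helpers above do.
  if (start < 0) start += dim;
  if (end < 0) end += dim;
  if (step > 0) {
    start = std::min(std::max(start, int64_t{0}), dim);
    end = std::min(std::max(end, int64_t{0}), dim);
    assert(end > start && "When step > 0, end should be greater than start.");
    return (end - start + step - 1) / step;  // ceil((end - start) / step)
  }
  // step < 0: start is clamped to dim - 1; end may stay -1 to include index 0.
  start = std::min(std::max(start, int64_t{0}), dim - 1);
  end = std::max(std::min(end, dim), int64_t{-1});
  assert(start > end && "When step < 0, start should be greater than end.");
  return (end - start + step + 1) / step;
}

int main() {
  std::cout << SliceDim(5, -3, 5, 1) << "\n";   // x[-3:5]     -> 3
  std::cout << SliceDim(5, -1, -6, -1) << "\n"; // x[-1:-6:-1] -> 5 (reversed)
  std::cout << SliceDim(10, 0, 10, 3) << "\n";  // x[0:10:3]   -> 4
  return 0;
}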