未验证 提交 4cf01462 编写于 作者: L liym27 提交者: GitHub

Polish code for slice and set_value op (#32947)

上级 a039fd7b
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h" #include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -59,106 +60,6 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) { ...@@ -59,106 +60,6 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) {
return value_name; return value_name;
} }
inline void CheckAndUpdateSlice(const framework::DDim in_dims,
const std::vector<int64_t> axes,
std::vector<int64_t>* starts,
std::vector<int64_t>* ends,
std::vector<int64_t>* steps) {
for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i];
int64_t dim_value = in_dims[axis];
int64_t start =
(*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
int64_t end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
start = std::max(start, static_cast<int64_t>(0));
end = std::min(end, dim_value);
int64_t step = (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<int64_t>(0));
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start));
} else {
// NOTE(liym27): When step < 0, start should less and equal to dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<int64_t>(-1));
PADDLE_ENFORCE_GT(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
}
(*starts)[i] = start;
(*ends)[i] = end;
}
}
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& steps) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i];
int64_t start = starts[i];
int64_t end = ends[i];
int64_t step = steps[i];
if (step > 0) {
slice_dims[axis] = (end - start + step - 1) / step;
} else {
slice_dims[axis] = (end - start + step + 1) / step;
}
}
return slice_dims;
}
inline framework::DDim GetDecreasedDims(
const framework::DDim slice_dims,
const std::vector<int64_t>& decrease_axes) {
// Get dims after decreasing axes.
framework::DDim decreased_dims(slice_dims);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
int64_t axis = decrease_axes[i];
PADDLE_ENFORCE_EQ(
decreased_dims[axis], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
decreased_dims[axis] = 0;
}
std::vector<int64_t> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) {
new_shape.push_back(decreased_dims[i]);
}
}
// NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
// uses [1] instead.
if (new_shape.size() == 0) {
new_shape.push_back(1);
}
decreased_dims = framework::make_ddim(new_shape);
}
return decreased_dims;
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> { class SetValueKernel : public framework::OpKernel<T> {
public: public:
...@@ -225,8 +126,8 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -225,8 +126,8 @@ class SetValueKernel : public framework::OpKernel<T> {
} }
auto in_dims = in->dims(); auto in_dims = in->dims();
CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps); CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps); auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes); auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
......
...@@ -28,13 +28,10 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -28,13 +28,10 @@ class SliceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "slice");
platform::errors::InvalidArgument( OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "slice");
"Input (Input) of slice op should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, // Case 1: Special treatment when input is a tensor array.
platform::errors::InvalidArgument(
"Output (Out) of slice op should not be null."));
auto x_var_type = ctx->GetInputsVarType("Input")[0]; auto x_var_type = ctx->GetInputsVarType("Input")[0];
auto axes = ctx->Attrs().Get<std::vector<int>>("axes"); auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) { if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
...@@ -57,6 +54,8 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -57,6 +54,8 @@ class SliceOp : public framework::OperatorWithKernel {
return; return;
} }
} }
// Case 2: input is a tensor.
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(in_dims.size(), 7, PADDLE_ENFORCE_LT(in_dims.size(), 7,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -65,101 +64,54 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -65,101 +64,54 @@ class SliceOp : public framework::OperatorWithKernel {
auto starts = ctx->Attrs().Get<std::vector<int>>("starts"); auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int>>("ends"); auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis"); auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto starts_size = starts.size();
auto ends_size = ends.size();
if (infer_flags.empty()) { if (infer_flags.empty()) {
// Initialize infer_flags with 1. // Initialize infer_flags with 1.
// To be compatible with other op tests in which infer_flags is not set. // To be compatible with other op tests in which infer_flags is not set.
infer_flags = std::vector<int>(axes.size(), 1); infer_flags = std::vector<int>(axes.size(), 1);
} }
// 2.1 Check attrs.
auto starts_size = starts.size();
auto ends_size = ends.size();
if (ctx->HasInputs("StartsTensorList")) { if (ctx->HasInputs("StartsTensorList")) {
auto StartsTensorList = ctx->Inputs("StartsTensorList"); starts_size = ctx->Inputs("StartsTensorList").size();
PADDLE_ENFORCE_GT(StartsTensorList.size(), 0, PADDLE_ENFORCE_GT(starts_size, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"StartsTensorList size can't be zero")); "StartsTensorList size can't be zero"));
starts_size = StartsTensorList.size();
} }
if (ctx->HasInputs("EndsTensorList")) { if (ctx->HasInputs("EndsTensorList")) {
auto EndsTensorList = ctx->Inputs("EndsTensorList"); ends_size = ctx->Inputs("EndsTensorList").size();
PADDLE_ENFORCE_GT(EndsTensorList.size(), 0, PADDLE_ENFORCE_GT(ends_size, 0, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "EndsTensorList size can't be zero"));
"EndsTensorList size can't be zero"));
ends_size = EndsTensorList.size();
} }
if (ctx->HasInput("StartsTensor") == false) { if (!ctx->HasInput("StartsTensor")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
starts_size, axes.size(), starts_size, axes.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The size of starts must be equal to the size of axes.")); "The size of starts must be equal to the size of axes."));
} }
if (ctx->HasInput("EndsTensor") == false) { if (!ctx->HasInput("EndsTensor")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ends_size, axes.size(), ends_size, axes.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes.")); "The size of ends must be equal to the size of axes."));
} }
int dim_value, start, end; CheckAndUpdateSliceAttrs<int>(in_dims, axes, &starts, &ends, nullptr,
for (size_t i = 0; i < axes.size(); ++i) { &infer_flags);
PADDLE_ENFORCE_LT(static_cast<int>(axes[i]), in_dims.size(),
platform::errors::InvalidArgument(
"The index of dimension in axes must be less "
"than the size of input shape."));
if (infer_flags[i] == -1) {
out_dims[axes[i]] = -1;
} else {
// infer out_dim shape
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
PADDLE_ENFORCE_LE(start, dim_value,
platform::errors::InvalidArgument(
"start should be less than or equal to the "
"dimension value, but received "
"start = %d, shape[%d] = %d.",
starts[i], axes[i], out_dims[axes[i]]));
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"end should greater than start, but received "
"end = %d, start = %d.",
ends[i], starts[i]));
out_dims[axes[i]] = end - start;
}
}
}
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
}
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) { auto slice_dims =
if (out_dims[i] != 0) { GetSliceDims<int>(in_dims, axes, starts, ends, nullptr, &infer_flags);
new_out_shape.push_back(out_dims[i]); if (ctx->IsRuntime()) {
} out_dims = GetDecreasedDims<int>(slice_dims, decrease_axis, &infer_flags);
} } else {
if (new_out_shape.size() == 0) { out_dims = GetDecreasedDims<int>(slice_dims, decrease_axis, nullptr);
new_out_shape.push_back(1);
}
out_dims = framework::make_ddim(new_out_shape);
} }
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (axes[0] != 0) { if (axes[0] != 0) {
ctx->ShareLoD("Input", /*->*/ "Out"); ctx->ShareLoD("Input", /*->*/ "Out");
...@@ -185,6 +137,7 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -185,6 +137,7 @@ class SliceOp : public framework::OperatorWithKernel {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor, const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
......
...@@ -19,21 +19,67 @@ limitations under the License. */ ...@@ -19,21 +19,67 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using Variable = framework::Variable;
using LoDTensorArray = framework::LoDTensorArray;
using DDim = framework::DDim;
inline void DealTensorArray(const framework::ExecutionContext& ctx,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
bool out_is_array) {
auto in_array = ctx.Input<LoDTensorArray>("Input");
// If the input is LoDTensorArray, the rank of input is 1.
int64_t in_size = in_array->size();
int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, in_size);
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received end = %d, start = %d.",
ends[0], starts[0]));
int64_t out_size = end - start;
if (out_is_array) {
auto out_array = ctx.Output<LoDTensorArray>("Out");
out_array->resize(out_size);
for (int i = 0; i < out_size; ++i) {
auto* out_tensor = &out_array->at(i);
auto in_tensor = in_array->at(i + start);
out_tensor->set_lod(in_tensor.lod());
if (in_tensor.memory_size() > 0) {
TensorCopy(in_tensor, ctx.GetPlace(), out_tensor);
} else {
VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
<< i << "].";
}
}
} else {
auto out = ctx.Output<Tensor>("Out");
auto in_tensor = in_array->at(start);
TensorCopy(in_tensor, ctx.GetPlace(), out);
}
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> { class SliceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Variable* input_var = ctx.InputVar("Input"); const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<framework::LoDTensorArray>(); bool is_tensor_array = input_var->IsType<LoDTensorArray>();
int rank = is_tensor_array int rank = is_tensor_array ? 1 : ctx.Input<Tensor>("Input")->dims().size();
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) { switch (rank) {
case 1: case 1:
...@@ -54,53 +100,45 @@ class SliceKernel : public framework::OpKernel<T> { ...@@ -54,53 +100,45 @@ class SliceKernel : public framework::OpKernel<T> {
case 6: case 6:
SliceCompute<6>(ctx); SliceCompute<6>(ctx);
break; break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
} }
} }
private: private:
template <size_t D> template <size_t D>
void SliceCompute(const framework::ExecutionContext& context) const { void SliceCompute(const framework::ExecutionContext& ctx) const {
auto& place = const Variable* input_var = ctx.InputVar("Input");
*context.template device_context<DeviceContext>().eigen_device(); Variable* out_var = ctx.OutputVar("Out");
const framework::Variable* input_var = context.InputVar("Input"); bool input_is_array = input_var->IsType<LoDTensorArray>();
framework::Variable* out_var = context.OutputVar("Out"); bool out_is_array = out_var->IsType<LoDTensorArray>();
bool input_is_tensor_array = input_var->IsType<framework::LoDTensorArray>();
bool out_is_tensor_array = out_var->IsType<framework::LoDTensorArray>(); auto axes_int = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto axes = context.Attr<std::vector<int>>("axes"); auto ends_int = ctx.Attr<std::vector<int>>("ends");
std::vector<int64_t> axes(axes_int.begin(), axes_int.end());
auto starts_int = context.Attr<std::vector<int>>("starts");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end()); std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
auto ends_int = context.Attr<std::vector<int>>("ends");
std::vector<int64_t> ends(ends_int.begin(), ends_int.end()); std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags"); auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto list_new_ends_tensor = auto infer_flags = ctx.Attr<std::vector<int>>("infer_flags");
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor = // Step 1: Get the accurate attribute value of starts and ends
context.MultiInput<framework::Tensor>("StartsTensorList"); auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
bool need_infer = false; starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
if (context.HasInput("StartsTensor") || context.HasInput("EndsTensor")) { } else if (starts_tensor_list.size() > 0) {
need_infer = true; starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
if (list_new_starts_tensor.size() > 0 || list_new_ends_tensor.size() > 0) {
need_infer = true;
} }
if (need_infer) {
if (context.HasInput("StartsTensor")) { auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor"); if (ctx.HasInput("EndsTensor")) {
starts = GetDataFromTensor<int64_t>(starts_tensor); ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (list_new_starts_tensor.size() > 0) { } else if (ends_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor); ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
} else if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
}
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
starts.size(), axes.size(), starts.size(), axes.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -109,175 +147,74 @@ class SliceKernel : public framework::OpKernel<T> { ...@@ -109,175 +147,74 @@ class SliceKernel : public framework::OpKernel<T> {
ends.size(), axes.size(), ends.size(), axes.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes.")); "The size of ends must be equal to the size of axes."));
if (input_is_tensor_array) {
auto in_array = context.Input<framework::LoDTensorArray>("Input");
// If the input is LoDTensorArray, the rank of input is 1.
int64_t in_size = in_array->size();
int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, in_size);
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received end = %d, start = %d.",
ends[0], starts[0]));
int64_t out_size = end - start;
if (out_is_tensor_array) {
auto out_array = context.Output<framework::LoDTensorArray>("Out");
out_array->resize(out_size);
for (int i = 0; i < out_size; ++i) {
auto* out_tensor = &out_array->at(i);
auto in_tensor = in_array->at(i + start);
out_tensor->set_lod(in_tensor.lod());
if (in_tensor.memory_size() > 0) {
TensorCopy(in_tensor, context.GetPlace(), out_tensor);
} else {
VLOG(10)
<< "WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
<< i << "].";
}
}
} else {
auto out = context.Output<framework::Tensor>("Out");
auto in_tensor = in_array->at(start);
TensorCopy(in_tensor, context.GetPlace(), out);
}
// Step 2: Compute output
if (input_is_array) {
DealTensorArray(ctx, starts, ends, out_is_array);
return; return;
} } else {
auto in = ctx.Input<Tensor>("Input");
auto out = ctx.Output<Tensor>("Out");
auto in = context.Input<framework::Tensor>("Input"); auto in_dims = in->dims();
auto out = context.Output<framework::Tensor>("Out"); auto out_dims = out->dims();
auto slice_dims = out_dims;
auto out_dims = out->dims(); // 2.1 Infer output dims
auto in_dims = in->dims();
if (need_infer) {
out_dims = in_dims;
int64_t dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]]; // when start == -1 && end == start+1
if (dim_value > 0) { if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
// when end = start+1 and start == -1 auto ret =
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) { std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
auto ret = if (ret != decrease_axis.end()) {
std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); ends[i] = in_dims[axes[i]];
if (ret != decrease_axis.end()) {
ends[i] = 10000000;
}
}
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, dim_value);
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received end = %d, start = %d.",
ends[i], starts[i]));
out_dims[axes[i]] = end - start;
}
}
out->Resize(out_dims);
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
} }
} }
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = framework::make_ddim(new_out_shape);
} }
}
// resize out_dims CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
if (decrease_axis.size() > 0) { slice_dims =
if (decrease_axis.size() == (size_t)in_dims.size()) { GetSliceDims<int64_t>(in_dims, axes, starts, ends, nullptr, nullptr);
std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1); out_dims = GetDecreasedDims(slice_dims, decrease_axis);
out->Resize(framework::make_ddim(vec_origin_out_shape));
} else {
std::vector<int> vec_origin_out_shape(
out_dims.size() + decrease_axis.size(), -1);
for (size_t i = 0; i < decrease_axis.size(); ++i) { // 2.2 Get output
vec_origin_out_shape[decrease_axis[i]] = 1; auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
} auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
int index = 0; for (size_t i = 0; i < D; ++i) {
for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) { offsets[i] = 0;
if (vec_origin_out_shape[i] == -1) { extents[i] = slice_dims[i];
vec_origin_out_shape[i] = out_dims[index];
++index;
}
}
out->Resize(framework::make_ddim(vec_origin_out_shape));
} }
} for (size_t i = 0; i < axes.size(); ++i) {
offsets[axes[i]] = starts[i];
out->mutable_data<T>(context.GetPlace());
auto new_out_dims = out->dims();
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = new_out_dims[i];
}
int64_t start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
} }
start = std::max(start, static_cast<int64_t>(0));
offsets[axes[i]] = start;
}
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, new_out_dims);
if (in->numel() <= Eigen::NumTraits<int>::highest()) { out->Resize(slice_dims);
// similar to tf.slice: out->mutable_data<T>(ctx.GetPlace());
// if element number less than INT_MAX, change the type of index to int
Eigen::DSizes<int, D> offsets_32bit, extents_32bit; auto in_t = framework::EigenTensor<T, D>::From(*in, in_dims);
for (size_t i = 0; i < D; i++) { auto out_t = framework::EigenTensor<T, D>::From(*out, slice_dims);
offsets_32bit[i] = offsets[i]; auto& eigen_place =
extents_32bit[i] = extents[i]; *ctx.template device_context<DeviceContext>().eigen_device();
if (in->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.slice:
// if element number less than INT_MAX, change the type of index to int
Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
for (size_t i = 0; i < D; i++) {
offsets_32bit[i] = offsets[i];
extents_32bit[i] = extents[i];
}
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, framework::To32BitIndex(out_t),
framework::To32BitIndex(in_t), offsets_32bit, extents_32bit);
} else {
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, out_t, in_t, offsets, extents);
} }
EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(
place, framework::To32BitIndex(out_t), framework::To32BitIndex(in_t),
offsets_32bit, extents_32bit);
} else {
EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(place, out_t, in_t,
offsets, extents);
}
out->Resize(out_dims); out->Resize(out_dims);
}
} }
}; };
...@@ -285,11 +222,9 @@ template <typename DeviceContext, typename T> ...@@ -285,11 +222,9 @@ template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> { class SliceGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Variable* input_var = ctx.InputVar("Input"); const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<framework::LoDTensorArray>(); bool is_array = input_var->IsType<LoDTensorArray>();
size_t rank = is_tensor_array size_t rank = is_array ? 1 : ctx.Input<Tensor>("Input")->dims().size();
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) { switch (rank) {
case 1: case 1:
...@@ -310,53 +245,48 @@ class SliceGradKernel : public framework::OpKernel<T> { ...@@ -310,53 +245,48 @@ class SliceGradKernel : public framework::OpKernel<T> {
case 6: case 6:
SliceCompute<6>(ctx); SliceCompute<6>(ctx);
break; break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
} }
} }
private: private:
template <size_t D> template <size_t D>
void SliceCompute(const framework::ExecutionContext& context) const { void SliceCompute(const framework::ExecutionContext& ctx) const {
auto axes = context.Attr<std::vector<int>>("axes"); auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto starts_int = context.Attr<std::vector<int>>("starts"); auto ends_int = ctx.Attr<std::vector<int>>("ends");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end()); std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
auto ends_int = context.Attr<std::vector<int>>("ends");
std::vector<int64_t> ends(ends_int.begin(), ends_int.end()); std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
auto list_new_ends_tensor = // Get the accurate attribute value of starts and ends
context.MultiInput<framework::Tensor>("EndsTensorList"); auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
auto list_new_starts_tensor = if (ctx.HasInput("StartsTensor")) {
context.MultiInput<framework::Tensor>("StartsTensorList"); starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
if (list_new_starts_tensor.size() > 0) { starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
} }
if (list_new_ends_tensor.size() > 0) { auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor); if (ctx.HasInput("EndsTensor")) {
} else if (context.HasInput("EndsTensor")) { ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor"); } else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensor<int64_t>(ends_tensor); ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
} }
framework::Variable* d_input_var =
context.OutputVar(framework::GradVarName("Input")); Variable* d_input_var = ctx.OutputVar(framework::GradVarName("Input"));
const framework::Variable* d_out_var = const Variable* d_out_var = ctx.InputVar(framework::GradVarName("Out"));
context.InputVar(framework::GradVarName("Out")); bool d_input_is_array = d_input_var->IsType<LoDTensorArray>();
bool d_input_is_tensor_array = bool d_out_is_array = d_out_var->IsType<LoDTensorArray>();
d_input_var->IsType<framework::LoDTensorArray>();
bool d_out_is_tensor_array = d_out_var->IsType<framework::LoDTensorArray>(); if (d_input_is_array) {
auto* input_array = ctx.Input<LoDTensorArray>("Input");
if (d_input_is_tensor_array) { auto* d_in_arr =
auto* input_array = context.Input<framework::LoDTensorArray>("Input"); ctx.Output<LoDTensorArray>(framework::GradVarName("Input"));
auto* d_input_array = context.Output<framework::LoDTensorArray>(
framework::GradVarName("Input"));
int64_t d_in_size = input_array->size(); int64_t d_in_size = input_array->size();
d_input_array->resize(d_in_size); d_in_arr->resize(d_in_size);
// If the input is LoDTensorArray, the rank of input is 1. // If the input is LoDTensorArray, the rank of input is 1.
// So only use the 0th element of starts. // So only use the 0th element of starts.
int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0]; int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
...@@ -364,68 +294,60 @@ class SliceGradKernel : public framework::OpKernel<T> { ...@@ -364,68 +294,60 @@ class SliceGradKernel : public framework::OpKernel<T> {
// set zero // set zero
platform::DeviceContextPool& pool = platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(context.GetPlace()); auto& dev_ctx = *pool.Get(ctx.GetPlace());
T value = T(0);
math::SetConstant<DeviceContext, T> functor; math::SetConstant<DeviceContext, T> functor;
for (int i = 0; i < d_in_size; ++i) { for (int i = 0; i < d_in_size; ++i) {
auto dim = input_array->at(i).dims(); auto dim = input_array->at(i).dims();
d_input_array->at(i).Resize(dim); d_in_arr->at(i).Resize(dim);
d_input_array->at(i).mutable_data<T>(context.GetPlace()); d_in_arr->at(i).mutable_data<T>(ctx.GetPlace());
functor(reinterpret_cast<const DeviceContext&>(dev_ctx), functor(reinterpret_cast<const DeviceContext&>(dev_ctx),
&d_input_array->at(i), static_cast<T>(value)); &d_in_arr->at(i), static_cast<T>(0));
} }
if (d_out_is_tensor_array) { if (d_out_is_array) {
auto* d_out_array = context.Input<framework::LoDTensorArray>( auto* d_out_arr =
framework::GradVarName("Out")); ctx.Input<LoDTensorArray>(framework::GradVarName("Out"));
int d_out_size = d_out_array->size(); int d_out_size = d_out_arr->size();
for (int i = 0; i < d_out_size; ++i) { for (int i = 0; i < d_out_size; ++i) {
TensorCopy(d_out_array->at(i), context.GetPlace(), TensorCopy(d_out_arr->at(i), ctx.GetPlace(),
&(d_input_array->at(start + i))); &(d_in_arr->at(start + i)));
} }
} else { } else {
auto* d_out = auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
context.Input<framework::Tensor>(framework::GradVarName("Out")); TensorCopy(*d_out, ctx.GetPlace(), &(d_in_arr->at(start)));
TensorCopy(*d_out, context.GetPlace(), &(d_input_array->at(start)));
} }
return; return;
} }
auto* d_out = auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
context.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_input = ctx.Output<Tensor>(framework::GradVarName("Input"));
d_input->mutable_data<T>(ctx.GetPlace());
auto* d_input =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_input->mutable_data<T>(context.GetPlace());
auto out_dims = d_out->dims(); auto out_dims = d_out->dims();
auto in_dims = d_input->dims(); auto in_dims = d_input->dims();
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis"); auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
if (decrease_axis.size() > 0) { auto decrease_size = decrease_axis.size();
if (decrease_axis.size() == (size_t)in_dims.size()) { if (decrease_size > 0) {
if (decrease_size == (size_t)in_dims.size()) {
// all dims decrease // all dims decrease
std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1); std::vector<int> origin_out_shape(decrease_size, 1);
out_dims = framework::make_ddim(vec_origin_out_shape); out_dims = framework::make_ddim(std::vector<int>(decrease_size, 1));
} else { } else {
std::vector<int> vec_origin_out_shape( std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
out_dims.size() + decrease_axis.size(), -1); for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
vec_origin_out_shape[decrease_axis[i]] = 1;
} }
int index = 0; int index = 0;
for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) { for (size_t i = 0; i < origin_out_shape.size(); ++i) {
if (vec_origin_out_shape[i] == -1) { if (origin_out_shape[i] == -1) {
vec_origin_out_shape[i] = out_dims[index]; origin_out_shape[i] = out_dims[index];
++index; ++index;
} }
} }
out_dims = framework::make_ddim(vec_origin_out_shape); out_dims = framework::make_ddim(origin_out_shape);
} }
} }
...@@ -435,28 +357,26 @@ class SliceGradKernel : public framework::OpKernel<T> { ...@@ -435,28 +357,26 @@ class SliceGradKernel : public framework::OpKernel<T> {
offsets[i] = 0; offsets[i] = 0;
extents[i] = out_dims[i]; extents[i] = out_dims[i];
} }
int64_t start;
for (size_t i = 0; i < axes.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i]; int axis = axes[i];
if (start < 0) { int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
start = (start + in_dims[axes[i]]);
}
start = std::max(start, static_cast<int64_t>(0)); start = std::max(start, static_cast<int64_t>(0));
offsets[axes[i]] = start; offsets[axis] = start;
} }
Eigen::array<std::pair<int64_t, int64_t>, D> paddings; Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i]; paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i]; paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
} }
EigenPaddingCompute(context, d_input, in_dims, d_out, out_dims, paddings); EigenPaddingCompute(ctx, d_input, in_dims, d_out, out_dims, paddings);
} }
template <size_t D> template <size_t D>
void EigenPaddingCompute( void EigenPaddingCompute(
const framework::ExecutionContext& context, framework::Tensor* d_input, const framework::ExecutionContext& context, Tensor* d_input,
const framework::DDim& in_dims, const framework::Tensor* d_out, const DDim& in_dims, const Tensor* d_out, const DDim& out_dims,
const framework::DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const { const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const {
if (D <= 3) { if (D <= 3) {
// if dimension less than 3, cannot reduce dimension // if dimension less than 3, cannot reduce dimension
...@@ -512,10 +432,8 @@ class SliceGradKernel : public framework::OpKernel<T> { ...@@ -512,10 +432,8 @@ class SliceGradKernel : public framework::OpKernel<T> {
out_tore_shape[1] = out_dims[pad_dim]; out_tore_shape[1] = out_dims[pad_dim];
// convert array from std::vector to DDim // convert array from std::vector to DDim
framework::DDim reshaped_in_dims = DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
framework::make_ddim(in_tore_shape); DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
framework::DDim reshaped_out_dims =
framework::make_ddim(out_tore_shape);
// after reshape: the first dimension do not need padding, // after reshape: the first dimension do not need padding,
// set padding[0] zero // set padding[0] zero
...@@ -543,10 +461,8 @@ class SliceGradKernel : public framework::OpKernel<T> { ...@@ -543,10 +461,8 @@ class SliceGradKernel : public framework::OpKernel<T> {
} }
// convert array from std::vector to DDim // convert array from std::vector to DDim
framework::DDim reshaped_in_dims = DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
framework::make_ddim(in_tore_shape); DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
framework::DDim reshaped_out_dims =
framework::make_ddim(out_tore_shape);
// after reshape: // after reshape:
// the first dimension is the previous padding dimension // the first dimension is the previous padding dimension
...@@ -579,10 +495,8 @@ class SliceGradKernel : public framework::OpKernel<T> { ...@@ -579,10 +495,8 @@ class SliceGradKernel : public framework::OpKernel<T> {
} }
// convert array from std::vector to DDim // convert array from std::vector to DDim
framework::DDim reshaped_in_dims = DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
framework::make_ddim(in_tore_shape); DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
framework::DDim reshaped_out_dims =
framework::make_ddim(out_tore_shape);
// after reshape: // after reshape:
// the first dimension do not need padding, set padding[0] zero // the first dimension do not need padding, set padding[0] zero
...@@ -606,9 +520,8 @@ class SliceGradKernel : public framework::OpKernel<T> { ...@@ -606,9 +520,8 @@ class SliceGradKernel : public framework::OpKernel<T> {
template <size_t D> template <size_t D>
void LaunchEigenPadding( void LaunchEigenPadding(
const framework::ExecutionContext& context, framework::Tensor* d_input, const framework::ExecutionContext& context, Tensor* d_input,
const framework::DDim& in_dims, const framework::Tensor* d_out, const DDim& in_dims, const Tensor* d_out, const DDim& out_dims,
const framework::DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const { const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const {
auto& place = auto& place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/fluid/framework/operator.h>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T = int64_t>
inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
const std::vector<T>& axes,
std::vector<T>* starts,
std::vector<T>* ends,
std::vector<int64_t>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
T dim_value = in_dims[axis];
if (dim_value > 0) {
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
start = std::max(start, static_cast<T>(0));
T end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
end = std::min(end, dim_value);
T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start));
} else {
// NOTE(liym27): When step < 0, start should less and equal to
// dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GT(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
}
(*starts)[i] = start;
(*ends)[i] = end;
}
}
}
template <typename T = int64_t>
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<T>& axes,
const std::vector<T>& starts,
const std::vector<T>& ends,
std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
slice_dims[axis] = -1;
continue;
}
T start = starts[i];
T end = ends[i];
T step = steps == nullptr ? 1 : (*steps)[i];
if (step > 0) {
slice_dims[axis] = (end - start + step - 1) / step;
} else {
slice_dims[axis] = (end - start + step + 1) / step;
}
}
return slice_dims;
}
template <typename T = int64_t>
inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
framework::DDim decreased_dims(slice_dims);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
T axis = decrease_axes[i];
if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(
decreased_dims[axis], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
}
decreased_dims[axis] = 0;
}
std::vector<T> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) {
new_shape.push_back(decreased_dims[i]);
}
}
// NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
// uses [1] instead.
if (new_shape.size() == 0) {
new_shape.push_back(1);
}
decreased_dims = framework::make_ddim(new_shape);
}
return decreased_dims;
}
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册