未验证 提交 4cf01462 编写于 作者: L liym27 提交者: GitHub

Polish code for slice and set_value op (#32947)

上级 a039fd7b
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h" #include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -59,106 +60,6 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) { ...@@ -59,106 +60,6 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) {
return value_name; return value_name;
} }
inline void CheckAndUpdateSlice(const framework::DDim in_dims,
const std::vector<int64_t> axes,
std::vector<int64_t>* starts,
std::vector<int64_t>* ends,
std::vector<int64_t>* steps) {
for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i];
int64_t dim_value = in_dims[axis];
int64_t start =
(*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
int64_t end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
start = std::max(start, static_cast<int64_t>(0));
end = std::min(end, dim_value);
int64_t step = (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<int64_t>(0));
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start));
} else {
// NOTE(liym27): When step < 0, start should less and equal to dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<int64_t>(-1));
PADDLE_ENFORCE_GT(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
}
(*starts)[i] = start;
(*ends)[i] = end;
}
}
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& steps) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i];
int64_t start = starts[i];
int64_t end = ends[i];
int64_t step = steps[i];
if (step > 0) {
slice_dims[axis] = (end - start + step - 1) / step;
} else {
slice_dims[axis] = (end - start + step + 1) / step;
}
}
return slice_dims;
}
inline framework::DDim GetDecreasedDims(
const framework::DDim slice_dims,
const std::vector<int64_t>& decrease_axes) {
// Get dims after decreasing axes.
framework::DDim decreased_dims(slice_dims);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
int64_t axis = decrease_axes[i];
PADDLE_ENFORCE_EQ(
decreased_dims[axis], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
decreased_dims[axis] = 0;
}
std::vector<int64_t> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) {
new_shape.push_back(decreased_dims[i]);
}
}
// NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
// uses [1] instead.
if (new_shape.size() == 0) {
new_shape.push_back(1);
}
decreased_dims = framework::make_ddim(new_shape);
}
return decreased_dims;
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> { class SetValueKernel : public framework::OpKernel<T> {
public: public:
...@@ -225,8 +126,8 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -225,8 +126,8 @@ class SetValueKernel : public framework::OpKernel<T> {
} }
auto in_dims = in->dims(); auto in_dims = in->dims();
CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps); CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps); auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes); auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
......
...@@ -28,13 +28,10 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -28,13 +28,10 @@ class SliceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "slice");
platform::errors::InvalidArgument( OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "slice");
"Input (Input) of slice op should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, // Case 1: Special treatment when input is a tensor array.
platform::errors::InvalidArgument(
"Output (Out) of slice op should not be null."));
auto x_var_type = ctx->GetInputsVarType("Input")[0]; auto x_var_type = ctx->GetInputsVarType("Input")[0];
auto axes = ctx->Attrs().Get<std::vector<int>>("axes"); auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) { if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
...@@ -57,6 +54,8 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -57,6 +54,8 @@ class SliceOp : public framework::OperatorWithKernel {
return; return;
} }
} }
// Case 2: input is a tensor.
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(in_dims.size(), 7, PADDLE_ENFORCE_LT(in_dims.size(), 7,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -65,101 +64,54 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -65,101 +64,54 @@ class SliceOp : public framework::OperatorWithKernel {
auto starts = ctx->Attrs().Get<std::vector<int>>("starts"); auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int>>("ends"); auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis"); auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto starts_size = starts.size();
auto ends_size = ends.size();
if (infer_flags.empty()) { if (infer_flags.empty()) {
// Initialize infer_flags with 1. // Initialize infer_flags with 1.
// To be compatible with other op tests in which infer_flags is not set. // To be compatible with other op tests in which infer_flags is not set.
infer_flags = std::vector<int>(axes.size(), 1); infer_flags = std::vector<int>(axes.size(), 1);
} }
// 2.1 Check attrs.
auto starts_size = starts.size();
auto ends_size = ends.size();
if (ctx->HasInputs("StartsTensorList")) { if (ctx->HasInputs("StartsTensorList")) {
auto StartsTensorList = ctx->Inputs("StartsTensorList"); starts_size = ctx->Inputs("StartsTensorList").size();
PADDLE_ENFORCE_GT(StartsTensorList.size(), 0, PADDLE_ENFORCE_GT(starts_size, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"StartsTensorList size can't be zero")); "StartsTensorList size can't be zero"));
starts_size = StartsTensorList.size();
} }
if (ctx->HasInputs("EndsTensorList")) { if (ctx->HasInputs("EndsTensorList")) {
auto EndsTensorList = ctx->Inputs("EndsTensorList"); ends_size = ctx->Inputs("EndsTensorList").size();
PADDLE_ENFORCE_GT(EndsTensorList.size(), 0, PADDLE_ENFORCE_GT(ends_size, 0, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "EndsTensorList size can't be zero"));
"EndsTensorList size can't be zero"));
ends_size = EndsTensorList.size();
} }
if (ctx->HasInput("StartsTensor") == false) { if (!ctx->HasInput("StartsTensor")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
starts_size, axes.size(), starts_size, axes.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The size of starts must be equal to the size of axes.")); "The size of starts must be equal to the size of axes."));
} }
if (ctx->HasInput("EndsTensor") == false) { if (!ctx->HasInput("EndsTensor")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ends_size, axes.size(), ends_size, axes.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes.")); "The size of ends must be equal to the size of axes."));
} }
int dim_value, start, end; CheckAndUpdateSliceAttrs<int>(in_dims, axes, &starts, &ends, nullptr,
for (size_t i = 0; i < axes.size(); ++i) { &infer_flags);
PADDLE_ENFORCE_LT(static_cast<int>(axes[i]), in_dims.size(),
platform::errors::InvalidArgument(
"The index of dimension in axes must be less "
"than the size of input shape."));
if (infer_flags[i] == -1) {
out_dims[axes[i]] = -1;
} else {
// infer out_dim shape
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
PADDLE_ENFORCE_LE(start, dim_value,
platform::errors::InvalidArgument(
"start should be less than or equal to the "
"dimension value, but received "
"start = %d, shape[%d] = %d.",
starts[i], axes[i], out_dims[axes[i]]));
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"end should greater than start, but received "
"end = %d, start = %d.",
ends[i], starts[i]));
out_dims[axes[i]] = end - start;
}
}
}
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
}
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) { auto slice_dims =
if (out_dims[i] != 0) { GetSliceDims<int>(in_dims, axes, starts, ends, nullptr, &infer_flags);
new_out_shape.push_back(out_dims[i]); if (ctx->IsRuntime()) {
} out_dims = GetDecreasedDims<int>(slice_dims, decrease_axis, &infer_flags);
} } else {
if (new_out_shape.size() == 0) { out_dims = GetDecreasedDims<int>(slice_dims, decrease_axis, nullptr);
new_out_shape.push_back(1);
}
out_dims = framework::make_ddim(new_out_shape);
} }
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (axes[0] != 0) { if (axes[0] != 0) {
ctx->ShareLoD("Input", /*->*/ "Out"); ctx->ShareLoD("Input", /*->*/ "Out");
...@@ -185,6 +137,7 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -185,6 +137,7 @@ class SliceOp : public framework::OperatorWithKernel {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor, const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
......
此差异已折叠。
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/fluid/framework/operator.h>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T = int64_t>
inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
const std::vector<T>& axes,
std::vector<T>* starts,
std::vector<T>* ends,
std::vector<int64_t>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
T dim_value = in_dims[axis];
if (dim_value > 0) {
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
start = std::max(start, static_cast<T>(0));
T end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
end = std::min(end, dim_value);
T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start));
} else {
// NOTE(liym27): When step < 0, start should less and equal to
// dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GT(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
}
(*starts)[i] = start;
(*ends)[i] = end;
}
}
}
template <typename T = int64_t>
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<T>& axes,
const std::vector<T>& starts,
const std::vector<T>& ends,
std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
slice_dims[axis] = -1;
continue;
}
T start = starts[i];
T end = ends[i];
T step = steps == nullptr ? 1 : (*steps)[i];
if (step > 0) {
slice_dims[axis] = (end - start + step - 1) / step;
} else {
slice_dims[axis] = (end - start + step + 1) / step;
}
}
return slice_dims;
}
template <typename T = int64_t>
inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
framework::DDim decreased_dims(slice_dims);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
T axis = decrease_axes[i];
if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(
decreased_dims[axis], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
}
decreased_dims[axis] = 0;
}
std::vector<T> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) {
new_shape.push_back(decreased_dims[i]);
}
}
// NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
// uses [1] instead.
if (new_shape.size() == 0) {
new_shape.push_back(1);
}
decreased_dims = framework::make_ddim(new_shape);
}
return decreased_dims;
}
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册