未验证 提交 382d099d 编写于 作者: W wangchaochaohu 提交者: GitHub

add support tensor and tensorlist for strided_slice OP (#19929)

* add support tensor and tensorlist for strided_slice OP test=develop

* fix the commnet test=develop

* fix test=develop

* fix the bug test=develop

* delete log test=develop

* fix API.spec test=develop

* fix test=develop
上级 fe218df3
......@@ -251,7 +251,7 @@ paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype
paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2'))
paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310'))
paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff'))
paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', 'a2e5296d34c081f2a67890aaa5f02238'))
paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', '340d8d656272ea396b441aab848429a2'))
paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b'))
paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3'))
paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe'))
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/operators/slice_op.h"
namespace paddle {
namespace operators {
......@@ -26,7 +28,7 @@ class StridedSliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input (Input) of slice op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
......@@ -39,56 +41,56 @@ class StridedSliceOp : public framework::OperatorWithKernel {
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
PADDLE_ENFORCE_EQ(starts.size(), ends.size(),
"starts and ends dim size must to be same");
PADDLE_ENFORCE_EQ(ends.size(), strides.size(),
"ends and strides dim size must to be same");
PADDLE_ENFORCE_EQ(ends.size(), axes.size(),
"axes, end and start dim size must to be same");
auto starts_size = starts.size();
auto ends_size = ends.size();
auto strides_size = strides.size();
if (ctx->HasInputs("StartsTensorList")) {
auto StartsTensorList = ctx->Inputs("StartsTensorList");
PADDLE_ENFORCE_GT(StartsTensorList.size(), 0,
"StartsTensorList size can't be zero");
starts_size = StartsTensorList.size();
}
if (ctx->HasInputs("EndsTensorList")) {
auto EndsTensorList = ctx->Inputs("EndsTensorList");
PADDLE_ENFORCE_GT(EndsTensorList.size(), 0,
"EndsTensorList size can't be zero");
ends_size = EndsTensorList.size();
}
if (ctx->HasInputs("StridesTensorList")) {
auto StridesTensorList = ctx->Inputs("StridesTensorList");
PADDLE_ENFORCE_GT(StridesTensorList.size(), 0,
"StridesTensorList size can't be zero");
strides_size = StridesTensorList.size();
}
auto tensor_input = false;
if (ctx->HasInput("EndsTensor") || ctx->HasInput("StartsTensor") ||
ctx->HasInput("StridesTensor")) {
tensor_input = true;
}
if (ctx->HasInput("EndsTensor") == false) {
PADDLE_ENFORCE_EQ(ends_size, axes.size(),
"The size of ends must be equal to the size of axes.");
}
if (ctx->HasInput("StartsTensor") == false) {
PADDLE_ENFORCE_EQ(
starts_size, axes.size(),
"The size of starts must be equal to the size of axes.");
}
if (ctx->HasInput("StridesTensor") == false) {
PADDLE_ENFORCE_EQ(
strides_size, axes.size(),
"The size of strides must be equal to the size of axes.");
}
// we need to analysis strided slice op is valid for
// the parameter that we get from python front
int stride_index, start_index, end_index;
std::vector<int> out_dims_vector(in_dims.size());
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
for (size_t i = 0; i < starts.size(); i++) {
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero");
int axes_index = axes[i];
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
int axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
}
if (start_index < 0) {
start_index = start_index + axis_size;
}
if (end_index < 0) {
end_index = end_index + axis_size;
}
if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
}
bool zero_dim_condition =
((stride_index < 0 && (start_index <= end_index)) ||
(stride_index > 0 && (start_index >= end_index)));
PADDLE_ENFORCE_EQ(zero_dim_condition, false,
"starts and end must meet requirement in different "
"stride conditiont");
int left = std::max(0, std::min(start_index, end_index));
int right = std::min(axis_size, std::max(start_index, end_index));
int step = std::abs(stride_index);
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
out_dims_vector[axes_index] = out_dims_index;
std::vector<int> out_dims_vector(in_dims.size(), -1);
if (!tensor_input) {
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
out_dims_vector.data(), axes.size(), true);
}
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
......@@ -98,26 +100,83 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.Input<Tensor>("Input")->place());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
var_name == "StridesTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StridesTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "Tensor of data to extract slices from.");
AddOutput("Out", "Sliced data tensor.");
AddOutput("Out", "Strided Sliced data tensor.");
AddInput("StartsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StartsTensor, StartsTensorList "
"and attr(starts).")
.AsDispensable();
AddInput("EndsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of EndsTensor, EndsTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StridesTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StridesTensor, StridesTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StartsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(starts).")
.AsDuplicable()
.AsDispensable();
AddInput(
"EndsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(ends).")
.AsDuplicable()
.AsDispensable();
AddInput(
"StridesTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(strides).")
.AsDuplicable()
.AsDispensable();
AddAttr<std::vector<int>>(
"axes", "(list<int> Axes stride from the start to the end)");
"axes", "(list<int>) Axes that `starts` and `ends` apply to.");
AddAttr<std::vector<int>>(
"starts", "(list<int>) start that the tensor slice start.");
"starts", "(list<int>) Start indices for the strided slice start.")
.SetDefault({});
AddAttr<std::vector<int>>("ends",
"(list<int>) end that the tensor slice end");
"(list<int>) End indices the tensor slice end")
.SetDefault({});
AddAttr<std::vector<int>>(
"strides", "(list<int> stride stride from the start to the end)");
"strides", "(list<int> Stride step from the start to the end)")
.SetDefault({});
AddAttr<std::vector<int>>(
"infer_flags", "(list<int>) Flags of inferring dims in attributes.")
.SetDefault({});
AddComment(R"DOC(
Strided Slice Operator.
Instead of calling this op directly most users will want to use the
......@@ -133,7 +192,7 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null");
......@@ -145,11 +204,23 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker {
......@@ -158,9 +229,15 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker {
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc();
auto *bind = new framework::OpDesc();
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetInput("Input", Input("Input"));
bind->SetInput("StartsTensor", Input("StartsTensor"));
bind->SetInput("EndsTensor", Input("EndsTensor"));
bind->SetInput("StridesTensor", Input("StridesTensor"));
bind->SetInput("StartsTensorList", Input("StartsTensorList"));
bind->SetInput("EndsTensorList", Input("EndsTensorList"));
bind->SetInput("StridesTensorList", Input("StridesTensorList"));
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
bind->SetAttrMap(Attrs());
bind->SetType("strided_slice_grad");
......
......@@ -19,9 +19,62 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h"
namespace paddle {
namespace operators {
static void StridedSliceOutDims(
const std::vector<int>& starts, const std::vector<int>& ends,
const std::vector<int>& strides, const std::vector<int>& axes,
const std::vector<int>& infer_flags, const framework::DDim in_dims,
int* out_dims_vector, const size_t size, bool infer_shape) {
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
int stride_index, start_index, end_index;
for (size_t i = 0; i < size; i++) {
int axes_index = axes[i];
if (infer_shape && infer_flags[i] == -1) {
out_dims_vector[axes_index] = -1;
continue;
}
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero");
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
int axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
}
if (start_index < 0) {
start_index = start_index + axis_size;
}
if (end_index < 0) {
end_index = end_index + axis_size;
}
if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
}
bool zero_dim_condition =
((stride_index < 0 && (start_index <= end_index)) ||
(stride_index > 0 && (start_index >= end_index)));
PADDLE_ENFORCE_EQ(zero_dim_condition, false,
"starts and end must meet requirement in different "
"stride conditiont");
int left = std::max(0, std::min(start_index, end_index));
int right = std::min(axis_size, std::max(start_index, end_index));
int step = std::abs(stride_index);
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
out_dims_vector[axes_index] = out_dims_index;
}
}
static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
int* reverse_axis, const framework::DDim dims,
const size_t size) {
......@@ -91,19 +144,52 @@ class StridedSliceKernel : public framework::OpKernel<T> {
*context.template device_context<DeviceContext>().eigen_device();
auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out");
auto out_dims = out->dims();
auto in_dims = in->dims();
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = get_new_data_from_tensorlist(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = get_new_data_from_tensor(strides_tensor);
}
std::vector<int> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
out_dims_vector.data(), axes.size(), false);
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, starts.size());
......@@ -112,6 +198,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
reverse_axis[axis] = false;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
......@@ -124,6 +211,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace());
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace());
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
......@@ -189,6 +277,34 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = get_new_data_from_tensorlist(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = get_new_data_from_tensor(strides_tensor);
}
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
......
......@@ -11396,57 +11396,148 @@ def strided_slice(input, axes, starts, ends, strides):
axes = [0, 1]
starts = [1, 0]
ends = [2, 3]
strides = [1, 1]
strides=[1, 1]
Then:
result = [ [5, 6, 7] ]
result = [ [5, 6, 7], ]
Case2:
Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1]
starts = [0, -1]
ends = [-1, 0]
strides = [1, -1]
starts = [0, 1]
ends = [-1, 1000]
strides = [1, 3]
Then:
result = [ [4, 3, 2] ]
Atrgs:
input (Varibale): the input variable.
axes(List):axis we need to slice
starts (List): the start index in axis
ends (List): the end index in axis
strides (List): the stride length when we do slice operation
Returns
out(Variable): the result by strided_slice Op
result = [ [2], ]
Args:
input (Variable): ${input_comment}.
axes (List): ${axes_comment}
starts (List|Variable): ${starts_comment}
ends (List|Variable): ${ends_comment}
Returns:
out (Variable): ${out_comment}
Examples:
.. code-block:: python
import paddle.fluid as fluid
starts = [1, 0, 2]
ends = [3, 3, 4]
axes = [0, 1, 2]
strides= [1, 1, 1]
input = fluid.layers.data(
name="input", shape=[3, 4, 5, 6], dtype='float32')
out = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides)
# example 1:
# attr starts is a list which doesn't contain tensor Variable.
axes = [0, 1, 2]
starts = [-3, 0, 2]
ends = [3, 2, 4]
strides=[1, 1, 1]
sliced_1 = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides)
# example 2:
# attr starts is a list which contain tensor Variable.
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
sliced_2 = fluid.layers.strided_slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides)
"""
if not isinstance(starts, (list, tuple, Variable)):
raise ValueError(
"Input starts must be an Variable, python list or tuple.")
if not isinstance(ends, (list, tuple, Variable)):
raise ValueError(
"Input ends must be an Variable, python list or tuple.")
if not isinstance(strides, (list, tuple, Variable)):
raise ValueError(
"Input strides must be an Variable, python list or tuple.")
helper = LayerHelper('strided_slice', **locals())
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op(
type='strided_slice',
inputs={'Input': input},
outputs={'Out': out},
attrs={
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': input}
attrs = {'axes': axes}
infer_flags = list(1 for i in range(len(axes)))
if in_dygraph_mode():
inputs = {'Input': input}
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'strides': strides
})
'strides': strides,
'infer_flags': infer_flags
}
else:
# starts
if isinstance(starts, Variable):
starts.stop_gradient = True
inputs['StartsTensor'] = starts
elif isinstance(starts, (list, tuple)):
attrs['starts'] = []
if not contain_var(starts):
attrs['starts'] = starts
else:
inputs['StartsTensorList'] = get_new_list_tensor(starts)
for i, dim in enumerate(starts):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
# ends
if isinstance(ends, Variable):
ends.stop_gradient = True
inputs['EndsTensor'] = ends
elif isinstance(ends, (list, tuple)):
attrs['ends'] = []
if not contain_var(ends):
attrs['ends'] = ends
else:
inputs['EndsTensorList'] = get_new_list_tensor(ends)
for i, dim in enumerate(ends):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
# strides
if isinstance(strides, Variable):
strides.stop_gradient = True
inputs['StridesTensor'] = strides
elif isinstance(strides, (list, tuple)):
attrs['strides'] = []
if not contain_var(strides):
attrs['strides'] = strides
else:
inputs['StridesTensorList'] = get_new_list_tensor(strides)
for i, dim in enumerate(strides):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
attrs['infer_flags'] = infer_flags
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op(
type='strided_slice', inputs=inputs, attrs=attrs, outputs={'Out': out})
return out
......
......@@ -15,6 +15,7 @@
from op_test import OpTest
import numpy as np
import unittest
import paddle.fluid as fluid
def strided_slice_native_forward(input, axes, starts, ends, strides):
......@@ -63,7 +64,8 @@ class TestStrideSliceOp(OpTest):
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'strides': self.strides
'strides': self.strides,
'infer_flags': self.infer_flags
}
def test_check_output(self):
......@@ -78,6 +80,7 @@ class TestStrideSliceOp(OpTest):
self.starts = [-4]
self.ends = [-3]
self.strides = [1]
self.infer_flags = [1]
class TestStrideSliceOp1(TestStrideSliceOp):
......@@ -87,6 +90,7 @@ class TestStrideSliceOp1(TestStrideSliceOp):
self.starts = [3]
self.ends = [8]
self.strides = [1]
self.infer_flags = [1]
class TestStrideSliceOp2(TestStrideSliceOp):
......@@ -96,6 +100,7 @@ class TestStrideSliceOp2(TestStrideSliceOp):
self.starts = [5]
self.ends = [0]
self.strides = [-1]
self.infer_flags = [1]
class TestStrideSliceOp3(TestStrideSliceOp):
......@@ -105,6 +110,7 @@ class TestStrideSliceOp3(TestStrideSliceOp):
self.starts = [-1]
self.ends = [-3]
self.strides = [-1]
self.infer_flags = [1]
class TestStrideSliceOp4(TestStrideSliceOp):
......@@ -114,6 +120,7 @@ class TestStrideSliceOp4(TestStrideSliceOp):
self.starts = [0, -1, 0]
self.ends = [2, -3, 5]
self.strides = [1, -1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOp5(TestStrideSliceOp):
......@@ -123,6 +130,7 @@ class TestStrideSliceOp5(TestStrideSliceOp):
self.starts = [1, 0, 0]
self.ends = [2, 1, 3]
self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOp6(TestStrideSliceOp):
......@@ -132,6 +140,7 @@ class TestStrideSliceOp6(TestStrideSliceOp):
self.starts = [1, -1, 0]
self.ends = [2, -3, 3]
self.strides = [1, -1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOp7(TestStrideSliceOp):
......@@ -141,6 +150,7 @@ class TestStrideSliceOp7(TestStrideSliceOp):
self.starts = [1, 0, 0]
self.ends = [2, 2, 3]
self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOp8(TestStrideSliceOp):
......@@ -150,6 +160,7 @@ class TestStrideSliceOp8(TestStrideSliceOp):
self.starts = [1]
self.ends = [2]
self.strides = [1]
self.infer_flags = [1]
class TestStrideSliceOp9(TestStrideSliceOp):
......@@ -159,6 +170,7 @@ class TestStrideSliceOp9(TestStrideSliceOp):
self.starts = [-1]
self.ends = [-2]
self.strides = [-1]
self.infer_flags = [1]
class TestStrideSliceOp10(TestStrideSliceOp):
......@@ -168,6 +180,7 @@ class TestStrideSliceOp10(TestStrideSliceOp):
self.starts = [1, 0]
self.ends = [2, 2]
self.strides = [1, 1]
self.infer_flags = [1, 1]
class TestStrideSliceOp11(TestStrideSliceOp):
......@@ -177,6 +190,7 @@ class TestStrideSliceOp11(TestStrideSliceOp):
self.starts = [1, 0, 0, 0]
self.ends = [2, 2, 3, 4]
self.strides = [1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1]
class TestStrideSliceOp12(TestStrideSliceOp):
......@@ -186,6 +200,7 @@ class TestStrideSliceOp12(TestStrideSliceOp):
self.starts = [1, 0, 0, 0, 0]
self.ends = [2, 2, 3, 4, 4]
self.strides = [1, 1, 1, 1, 1]
self.infer_flags = [1, 1, 1, 1]
class TestStrideSliceOp13(TestStrideSliceOp):
......@@ -195,6 +210,295 @@ class TestStrideSliceOp13(TestStrideSliceOp):
self.starts = [1, 0, 0, 0, 1, 2]
self.ends = [2, 2, 3, 1, 2, 8]
self.strides = [1, 1, 1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1, 1]
class TestStridedSliceOp_starts_ListTensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
starts_tensor = []
for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts_infer,
'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [1, -1, 1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
self.starts_infer = [1, 10, 2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_ends_ListTensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
ends_tensor = []
for index, ele in enumerate(self.ends):
ends_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'EndsTensorList': ends_tensor}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends_infer,
'strides': self.strides,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 0]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 2]
self.infer_flags = [1, -1, 1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
self.ends_infer = [3, 1, 4]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_starts_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_ends_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"EndsTensor": np.array(
self.ends, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
#'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_listTensor_Tensor(OpTest):
def setUp(self):
self.config()
ends_tensor = []
for index, ele in enumerate(self.ends):
ends_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.op_type = "strided_slice"
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32"),
"EndsTensorList": ends_tensor
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
#'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_strides_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"StridesTensor": np.array(
self.strides, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
#'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, -1, 2]
self.ends = [2, 0, 4]
self.axes = [0, 1, 2]
self.strides = [1, -1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
# Test python API
class TestSliceAPI(OpTest):
def test_1(self):
input = np.random.random([3, 4, 5, 6]).astype("float32")
minus_1 = fluid.layers.fill_constant([1], "int32", -1)
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
starts = fluid.layers.data(
name='starts', shape=[3], append_batch_size=False)
ends = fluid.layers.data(
name='ends', shape=[3], append_batch_size=False)
strides = fluid.layers.data(
name='strides', shape=[3], append_batch_size=False)
x = fluid.layers.data(
name="x",
shape=[3, 4, 5, 6],
append_batch_size=False,
dtype="float32")
out_1 = fluid.layers.strided_slice(
x,
axes=[0, 1, 2],
starts=[-3, 0, 2],
ends=[3, 100, -1],
strides=[1, 1, 1])
out_2 = fluid.layers.strided_slice(
x,
axes=[0, 1, 3],
starts=[minus_3, 0, 2],
ends=[3, 100, -1],
strides=[1, 1, 1])
out_3 = fluid.layers.strided_slice(
x,
axes=[0, 1, 3],
starts=[minus_3, 0, 2],
ends=[3, 100, minus_1],
strides=[1, 1, 1])
out_4 = fluid.layers.strided_slice(
x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides)
out_5 = x[-3:3, 0:100, 2:-1]
out_6 = x[minus_3:3, 0:100, :, 2:-1]
out_7 = x[minus_1, 0:100, :, 2:minus_1]
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run(
fluid.default_main_program(),
feed={
"x": input,
'starts': np.array([-3, 0, 2]).astype("int32"),
'ends': np.array([3, 100, -1]).astype("int32"),
'strides': np.array([1, 1, 1]).astype("int32")
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_5, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_6, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册