From 382d099dcb16d23d2829b1797ba4a4f5b92ce516 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 25 Sep 2019 14:29:28 +0800 Subject: [PATCH] add support tensor and tensorlist for strided_slice OP (#19929) * add support tensor and tensorlist for strided_slice OP test=develop * fix the commnet test=develop * fix test=develop * fix the bug test=develop * delete log test=develop * fix API.spec test=develop * fix test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/strided_slice_op.cc | 189 +++++++---- paddle/fluid/operators/strided_slice_op.h | 118 ++++++- python/paddle/fluid/layers/nn.py | 151 +++++++-- .../tests/unittests/test_strided_slice_op.py | 306 +++++++++++++++++- 5 files changed, 677 insertions(+), 89 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 512ffbff5e..9080a53ab3 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -251,7 +251,7 @@ paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2')) paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310')) paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff')) -paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', 'a2e5296d34c081f2a67890aaa5f02238')) +paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', '340d8d656272ea396b441aab848429a2')) paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b')) paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3')) paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe')) diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index 7b0cc432f3..b6bbb071ac 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -15,7 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/strided_slice_op.h" #include #include +#include #include +#include "paddle/fluid/operators/slice_op.h" namespace paddle { namespace operators { @@ -26,7 +28,7 @@ class StridedSliceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input (Input) of slice op should not be null."); PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, @@ -39,56 +41,56 @@ class StridedSliceOp : public framework::OperatorWithKernel { auto ends = ctx->Attrs().Get>("ends"); auto strides = ctx->Attrs().Get>("strides"); auto axes = ctx->Attrs().Get>("axes"); + auto infer_flags = ctx->Attrs().Get>("infer_flags"); - PADDLE_ENFORCE_EQ(starts.size(), ends.size(), - "starts and ends dim size must to be same"); - PADDLE_ENFORCE_EQ(ends.size(), strides.size(), - "ends and strides dim size must to be same"); - PADDLE_ENFORCE_EQ(ends.size(), axes.size(), - "axes, end and start dim size must to be same"); + auto starts_size = starts.size(); + auto ends_size = ends.size(); + auto strides_size = strides.size(); + if (ctx->HasInputs("StartsTensorList")) { + auto StartsTensorList = ctx->Inputs("StartsTensorList"); + PADDLE_ENFORCE_GT(StartsTensorList.size(), 0, + "StartsTensorList size can't be zero"); + starts_size = StartsTensorList.size(); + } + if (ctx->HasInputs("EndsTensorList")) { + auto EndsTensorList = ctx->Inputs("EndsTensorList"); + PADDLE_ENFORCE_GT(EndsTensorList.size(), 0, + "EndsTensorList size can't be zero"); + ends_size = EndsTensorList.size(); + } + if (ctx->HasInputs("StridesTensorList")) { + auto StridesTensorList = ctx->Inputs("StridesTensorList"); + PADDLE_ENFORCE_GT(StridesTensorList.size(), 0, + "StridesTensorList size can't be zero"); + strides_size = StridesTensorList.size(); + } + + auto tensor_input = false; + if (ctx->HasInput("EndsTensor") || ctx->HasInput("StartsTensor") || + ctx->HasInput("StridesTensor")) { + tensor_input = true; + } + if (ctx->HasInput("EndsTensor") == false) { + PADDLE_ENFORCE_EQ(ends_size, axes.size(), + "The size of ends must be equal to the size of axes."); + } + if (ctx->HasInput("StartsTensor") == false) { + PADDLE_ENFORCE_EQ( + starts_size, axes.size(), + "The size of starts must be equal to the size of axes."); + } + if (ctx->HasInput("StridesTensor") == false) { + PADDLE_ENFORCE_EQ( + strides_size, axes.size(), + "The size of strides must be equal to the size of axes."); + } // we need to analysis strided slice op is valid for // the parameter that we get from python front - int stride_index, start_index, end_index; - std::vector out_dims_vector(in_dims.size()); - for (int i = 0; i < in_dims.size(); i++) { - out_dims_vector[i] = in_dims[i]; - } - for (size_t i = 0; i < starts.size(); i++) { - PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero"); - int axes_index = axes[i]; - start_index = starts[i]; - end_index = ends[i]; - stride_index = strides[i]; - int axis_size = in_dims[axes_index]; - if (axis_size < 0) { - continue; - } - - if (start_index < 0) { - start_index = start_index + axis_size; - } - if (end_index < 0) { - end_index = end_index + axis_size; - } - - if (stride_index < 0) { - start_index = start_index + 1; - end_index = end_index + 1; - } - - bool zero_dim_condition = - ((stride_index < 0 && (start_index <= end_index)) || - (stride_index > 0 && (start_index >= end_index))); - PADDLE_ENFORCE_EQ(zero_dim_condition, false, - "starts and end must meet requirement in different " - "stride conditiont"); - int left = std::max(0, std::min(start_index, end_index)); - int right = std::min(axis_size, std::max(start_index, end_index)); - int step = std::abs(stride_index); - auto out_dims_index = (std::abs(right - left) + step - 1) / step; - - out_dims_vector[axes_index] = out_dims_index; + std::vector out_dims_vector(in_dims.size(), -1); + if (!tensor_input) { + StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, + out_dims_vector.data(), axes.size(), true); } framework::DDim out_dims(framework::make_ddim(out_dims_vector)); @@ -98,26 +100,83 @@ class StridedSliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType(ctx.Input("Input")->type(), ctx.Input("Input")->place()); } + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "StartsTensor" || var_name == "EndsTensor" || + var_name == "StridesTensor") { + return expected_kernel_type; + } + if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || + var_name == "StridesTensorList") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Input", "Tensor of data to extract slices from."); - AddOutput("Out", "Sliced data tensor."); + AddOutput("Out", "Strided Sliced data tensor."); + AddInput("StartsTensor", + "(Tensor, optional) If provided, slice will use this." + "It has the highest priority of StartsTensor, StartsTensorList " + "and attr(starts).") + .AsDispensable(); + AddInput("EndsTensor", + "(Tensor, optional) If provided, slice will use this." + "It has the highest priority of EndsTensor, EndsTensorList and " + "attr(ends).") + .AsDispensable(); + AddInput( + "StridesTensor", + "(Tensor, optional) If provided, slice will use this." + "It has the highest priority of StridesTensor, StridesTensorList and " + "attr(ends).") + .AsDispensable(); + AddInput( + "StartsTensorList", + "(vector>, optional) If provided, slice will use this." + "The shape of the tensor in vector MUST BE [1]." + "It has higher priority compare with attr(starts).") + .AsDuplicable() + .AsDispensable(); + AddInput( + "EndsTensorList", + "(vector>, optional) If provided, slice will use this." + "The shape of the tensor in vector MUST BE [1]." + "It has higher priority compare with attr(ends).") + .AsDuplicable() + .AsDispensable(); + AddInput( + "StridesTensorList", + "(vector>, optional) If provided, slice will use this." + "The shape of the tensor in vector MUST BE [1]." + "It has higher priority compare with attr(strides).") + .AsDuplicable() + .AsDispensable(); AddAttr>( - "axes", "(list Axes stride from the start to the end)"); + "axes", "(list) Axes that `starts` and `ends` apply to."); AddAttr>( - "starts", "(list) start that the tensor slice start."); + "starts", "(list) Start indices for the strided slice start.") + .SetDefault({}); AddAttr>("ends", - "(list) end that the tensor slice end"); + "(list) End indices the tensor slice end") + .SetDefault({}); AddAttr>( - "strides", "(list stride stride from the start to the end)"); + "strides", "(list Stride step from the start to the end)") + .SetDefault({}); + AddAttr>( + "infer_flags", "(list) Flags of inferring dims in attributes.") + .SetDefault({}); AddComment(R"DOC( Strided Slice Operator. Instead of calling this op directly most users will want to use the @@ -133,7 +192,7 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null"); PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, "Input(Out@GRAD) should not be null"); @@ -145,11 +204,23 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( ctx.Input(framework::GradVarName("Out"))->type(), ctx.GetPlace()); } + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "StartsTensor" || var_name == "EndsTensor") { + return expected_kernel_type; + } + if (var_name == "StartsTensorList" || var_name == "EndsTensorList") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker { @@ -158,9 +229,15 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker { protected: std::unique_ptr Apply() const override { - auto* bind = new framework::OpDesc(); + auto *bind = new framework::OpDesc(); bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); bind->SetInput("Input", Input("Input")); + bind->SetInput("StartsTensor", Input("StartsTensor")); + bind->SetInput("EndsTensor", Input("EndsTensor")); + bind->SetInput("StridesTensor", Input("StridesTensor")); + bind->SetInput("StartsTensorList", Input("StartsTensorList")); + bind->SetInput("EndsTensorList", Input("EndsTensorList")); + bind->SetInput("StridesTensorList", Input("StridesTensorList")); bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); bind->SetAttrMap(Attrs()); bind->SetType("strided_slice_grad"); diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index ac39686900..57d33f29d8 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -19,9 +19,62 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/slice_op.h" namespace paddle { namespace operators { +static void StridedSliceOutDims( + const std::vector& starts, const std::vector& ends, + const std::vector& strides, const std::vector& axes, + const std::vector& infer_flags, const framework::DDim in_dims, + int* out_dims_vector, const size_t size, bool infer_shape) { + for (int i = 0; i < in_dims.size(); i++) { + out_dims_vector[i] = in_dims[i]; + } + int stride_index, start_index, end_index; + for (size_t i = 0; i < size; i++) { + int axes_index = axes[i]; + if (infer_shape && infer_flags[i] == -1) { + out_dims_vector[axes_index] = -1; + continue; + } + + PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero"); + start_index = starts[i]; + end_index = ends[i]; + stride_index = strides[i]; + int axis_size = in_dims[axes_index]; + if (axis_size < 0) { + continue; + } + + if (start_index < 0) { + start_index = start_index + axis_size; + } + if (end_index < 0) { + end_index = end_index + axis_size; + } + + if (stride_index < 0) { + start_index = start_index + 1; + end_index = end_index + 1; + } + + bool zero_dim_condition = + ((stride_index < 0 && (start_index <= end_index)) || + (stride_index > 0 && (start_index >= end_index))); + PADDLE_ENFORCE_EQ(zero_dim_condition, false, + "starts and end must meet requirement in different " + "stride conditiont"); + int left = std::max(0, std::min(start_index, end_index)); + int right = std::min(axis_size, std::max(start_index, end_index)); + int step = std::abs(stride_index); + auto out_dims_index = (std::abs(right - left) + step - 1) / step; + + out_dims_vector[axes_index] = out_dims_index; + } +} + static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, int* reverse_axis, const framework::DDim dims, const size_t size) { @@ -91,19 +144,52 @@ class StridedSliceKernel : public framework::OpKernel { *context.template device_context().eigen_device(); auto in = context.Input("Input"); auto out = context.Output("Out"); - auto out_dims = out->dims(); auto in_dims = in->dims(); auto starts = context.Attr>("starts"); auto ends = context.Attr>("ends"); auto strides = context.Attr>("strides"); auto axes = context.Attr>("axes"); + auto infer_flags = context.Attr>("infer_flags"); auto starts_indices = Eigen::DSizes(); auto ends_indices = Eigen::DSizes(); auto strides_indices = Eigen::DSizes(); auto reverse_axis = Eigen::array(); + auto list_new_ends_tensor = + context.MultiInput("EndsTensorList"); + auto list_new_starts_tensor = + context.MultiInput("StartsTensorList"); + auto list_new_strides_tensor = + context.MultiInput("StridesTensorList"); + + if (list_new_starts_tensor.size() > 0) { + starts = get_new_data_from_tensorlist(list_new_starts_tensor); + } else if (context.HasInput("StartsTensor")) { + auto* starts_tensor = context.Input("StartsTensor"); + starts = get_new_data_from_tensor(starts_tensor); + } + + if (list_new_ends_tensor.size() > 0) { + ends = get_new_data_from_tensorlist(list_new_ends_tensor); + } else if (context.HasInput("EndsTensor")) { + auto* ends_tensor = context.Input("EndsTensor"); + ends = get_new_data_from_tensor(ends_tensor); + } + + if (list_new_strides_tensor.size() > 0) { + strides = get_new_data_from_tensorlist(list_new_strides_tensor); + } else if (context.HasInput("StridesTensor")) { + auto* strides_tensor = context.Input("StridesTensor"); + strides = get_new_data_from_tensor(strides_tensor); + } + + std::vector out_dims_vector(in_dims.size(), -1); + StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, + out_dims_vector.data(), axes.size(), false); + framework::DDim out_dims(framework::make_ddim(out_dims_vector)); + std::vector reverse_vector(starts.size(), 0); StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), reverse_vector.data(), in_dims, starts.size()); @@ -112,6 +198,7 @@ class StridedSliceKernel : public framework::OpKernel { starts_indices[axis] = 0; ends_indices[axis] = out_dims[axis]; strides_indices[axis] = 1; + reverse_axis[axis] = false; } for (size_t axis = 0; axis < axes.size(); axis++) { int axis_index = axes[axis]; @@ -124,6 +211,7 @@ class StridedSliceKernel : public framework::OpKernel { framework::Tensor tmp; tmp.mutable_data(out_dims, context.GetPlace()); + out->Resize(out_dims); out->mutable_data(context.GetPlace()); auto in_t = framework::EigenTensor::From( @@ -189,6 +277,34 @@ class StridedSliceGradKernel : public framework::OpKernel { auto strides = context.Attr>("strides"); auto axes = context.Attr>("axes"); + auto list_new_ends_tensor = + context.MultiInput("EndsTensorList"); + auto list_new_starts_tensor = + context.MultiInput("StartsTensorList"); + auto list_new_strides_tensor = + context.MultiInput("StridesTensorList"); + + if (list_new_starts_tensor.size() > 0) { + starts = get_new_data_from_tensorlist(list_new_starts_tensor); + } else if (context.HasInput("StartsTensor")) { + auto* starts_tensor = context.Input("StartsTensor"); + starts = get_new_data_from_tensor(starts_tensor); + } + + if (list_new_ends_tensor.size() > 0) { + ends = get_new_data_from_tensorlist(list_new_ends_tensor); + } else if (context.HasInput("EndsTensor")) { + auto* ends_tensor = context.Input("EndsTensor"); + ends = get_new_data_from_tensor(ends_tensor); + } + + if (list_new_strides_tensor.size() > 0) { + strides = get_new_data_from_tensorlist(list_new_strides_tensor); + } else if (context.HasInput("StridesTensor")) { + auto* strides_tensor = context.Input("StridesTensor"); + strides = get_new_data_from_tensor(strides_tensor); + } + auto starts_indices = Eigen::DSizes(); auto ends_indices = Eigen::DSizes(); auto strides_indices = Eigen::DSizes(); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 270924dd77..ebd81515ad 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11396,57 +11396,148 @@ def strided_slice(input, axes, starts, ends, strides): axes = [0, 1] starts = [1, 0] ends = [2, 3] - strides = [1, 1] + strides=[1, 1] Then: - result = [ [5, 6, 7] ] + result = [ [5, 6, 7], ] Case2: Given: data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] axes = [0, 1] - starts = [0, -1] - ends = [-1, 0] - strides = [1, -1] + starts = [0, 1] + ends = [-1, 1000] + strides = [1, 3] Then: - result = [ [4, 3, 2] ] - Atrgs: - input (Varibale): the input variable. - axes(List):axis we need to slice - starts (List): the start index in axis - ends (List): the end index in axis - strides (List): the stride length when we do slice operation - Returns - out(Variable): the result by strided_slice Op - + result = [ [2], ] + Args: + input (Variable): ${input_comment}. + axes (List): ${axes_comment} + starts (List|Variable): ${starts_comment} + ends (List|Variable): ${ends_comment} + + Returns: + out (Variable): ${out_comment} + Examples: .. code-block:: python import paddle.fluid as fluid - - starts = [1, 0, 2] - ends = [3, 3, 4] - axes = [0, 1, 2] - strides= [1, 1, 1] input = fluid.layers.data( name="input", shape=[3, 4, 5, 6], dtype='float32') - out = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides) + # example 1: + # attr starts is a list which doesn't contain tensor Variable. + axes = [0, 1, 2] + starts = [-3, 0, 2] + ends = [3, 2, 4] + strides=[1, 1, 1] + sliced_1 = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides) + + # example 2: + # attr starts is a list which contain tensor Variable. + minus_3 = fluid.layers.fill_constant([1], "int32", -3) + sliced_2 = fluid.layers.strided_slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides) """ + if not isinstance(starts, (list, tuple, Variable)): + raise ValueError( + "Input starts must be an Variable, python list or tuple.") + if not isinstance(ends, (list, tuple, Variable)): + raise ValueError( + "Input ends must be an Variable, python list or tuple.") + if not isinstance(strides, (list, tuple, Variable)): + raise ValueError( + "Input strides must be an Variable, python list or tuple.") + helper = LayerHelper('strided_slice', **locals()) - out = helper.create_variable_for_type_inference( - dtype=helper.input_dtype('input')) - helper.append_op( - type='strided_slice', - inputs={'Input': input}, - outputs={'Out': out}, - attrs={ + def contain_var(one_list): + for ele in one_list: + if isinstance(ele, Variable): + return True + return False + + def get_new_list_tensor(old_list): + new_list_tensor = [] + for dim in old_list: + if isinstance(dim, Variable): + dim.stop_gradient = True + new_list_tensor.append(dim) + else: + assert (isinstance(dim, int)) + temp_out = helper.create_variable_for_type_inference('int32') + fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out) + new_list_tensor.append(temp_out) + return new_list_tensor + + inputs = {'Input': input} + attrs = {'axes': axes} + infer_flags = list(1 for i in range(len(axes))) + + if in_dygraph_mode(): + inputs = {'Input': input} + attrs = { 'axes': axes, 'starts': starts, 'ends': ends, - 'strides': strides - }) + 'strides': strides, + 'infer_flags': infer_flags + } + else: + # starts + if isinstance(starts, Variable): + starts.stop_gradient = True + inputs['StartsTensor'] = starts + elif isinstance(starts, (list, tuple)): + attrs['starts'] = [] + if not contain_var(starts): + attrs['starts'] = starts + else: + inputs['StartsTensorList'] = get_new_list_tensor(starts) + for i, dim in enumerate(starts): + if isinstance(dim, Variable): + attrs['starts'].append(-1) + infer_flags[i] = -1 + else: + attrs['starts'].append(dim) + + # ends + if isinstance(ends, Variable): + ends.stop_gradient = True + inputs['EndsTensor'] = ends + elif isinstance(ends, (list, tuple)): + attrs['ends'] = [] + if not contain_var(ends): + attrs['ends'] = ends + else: + inputs['EndsTensorList'] = get_new_list_tensor(ends) + for i, dim in enumerate(ends): + if isinstance(dim, Variable): + attrs['ends'].append(-1) + infer_flags[i] = -1 + else: + attrs['ends'].append(dim) + # strides + if isinstance(strides, Variable): + strides.stop_gradient = True + inputs['StridesTensor'] = strides + elif isinstance(strides, (list, tuple)): + attrs['strides'] = [] + if not contain_var(strides): + attrs['strides'] = strides + else: + inputs['StridesTensorList'] = get_new_list_tensor(strides) + for i, dim in enumerate(strides): + if isinstance(dim, Variable): + attrs['strides'].append(-1) + infer_flags[i] = -1 + else: + attrs['strides'].append(dim) + attrs['infer_flags'] = infer_flags + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype('input')) + helper.append_op( + type='strided_slice', inputs=inputs, attrs=attrs, outputs={'Out': out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index d7e79a91ed..bb327a8bd7 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -15,6 +15,7 @@ from op_test import OpTest import numpy as np import unittest +import paddle.fluid as fluid def strided_slice_native_forward(input, axes, starts, ends, strides): @@ -63,7 +64,8 @@ class TestStrideSliceOp(OpTest): 'axes': self.axes, 'starts': self.starts, 'ends': self.ends, - 'strides': self.strides + 'strides': self.strides, + 'infer_flags': self.infer_flags } def test_check_output(self): @@ -78,6 +80,7 @@ class TestStrideSliceOp(OpTest): self.starts = [-4] self.ends = [-3] self.strides = [1] + self.infer_flags = [1] class TestStrideSliceOp1(TestStrideSliceOp): @@ -87,6 +90,7 @@ class TestStrideSliceOp1(TestStrideSliceOp): self.starts = [3] self.ends = [8] self.strides = [1] + self.infer_flags = [1] class TestStrideSliceOp2(TestStrideSliceOp): @@ -96,6 +100,7 @@ class TestStrideSliceOp2(TestStrideSliceOp): self.starts = [5] self.ends = [0] self.strides = [-1] + self.infer_flags = [1] class TestStrideSliceOp3(TestStrideSliceOp): @@ -105,6 +110,7 @@ class TestStrideSliceOp3(TestStrideSliceOp): self.starts = [-1] self.ends = [-3] self.strides = [-1] + self.infer_flags = [1] class TestStrideSliceOp4(TestStrideSliceOp): @@ -114,6 +120,7 @@ class TestStrideSliceOp4(TestStrideSliceOp): self.starts = [0, -1, 0] self.ends = [2, -3, 5] self.strides = [1, -1, 1] + self.infer_flags = [1, 1, 1] class TestStrideSliceOp5(TestStrideSliceOp): @@ -123,6 +130,7 @@ class TestStrideSliceOp5(TestStrideSliceOp): self.starts = [1, 0, 0] self.ends = [2, 1, 3] self.strides = [1, 1, 1] + self.infer_flags = [1, 1, 1] class TestStrideSliceOp6(TestStrideSliceOp): @@ -132,6 +140,7 @@ class TestStrideSliceOp6(TestStrideSliceOp): self.starts = [1, -1, 0] self.ends = [2, -3, 3] self.strides = [1, -1, 1] + self.infer_flags = [1, 1, 1] class TestStrideSliceOp7(TestStrideSliceOp): @@ -141,6 +150,7 @@ class TestStrideSliceOp7(TestStrideSliceOp): self.starts = [1, 0, 0] self.ends = [2, 2, 3] self.strides = [1, 1, 1] + self.infer_flags = [1, 1, 1] class TestStrideSliceOp8(TestStrideSliceOp): @@ -150,6 +160,7 @@ class TestStrideSliceOp8(TestStrideSliceOp): self.starts = [1] self.ends = [2] self.strides = [1] + self.infer_flags = [1] class TestStrideSliceOp9(TestStrideSliceOp): @@ -159,6 +170,7 @@ class TestStrideSliceOp9(TestStrideSliceOp): self.starts = [-1] self.ends = [-2] self.strides = [-1] + self.infer_flags = [1] class TestStrideSliceOp10(TestStrideSliceOp): @@ -168,6 +180,7 @@ class TestStrideSliceOp10(TestStrideSliceOp): self.starts = [1, 0] self.ends = [2, 2] self.strides = [1, 1] + self.infer_flags = [1, 1] class TestStrideSliceOp11(TestStrideSliceOp): @@ -177,6 +190,7 @@ class TestStrideSliceOp11(TestStrideSliceOp): self.starts = [1, 0, 0, 0] self.ends = [2, 2, 3, 4] self.strides = [1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1] class TestStrideSliceOp12(TestStrideSliceOp): @@ -186,6 +200,7 @@ class TestStrideSliceOp12(TestStrideSliceOp): self.starts = [1, 0, 0, 0, 0] self.ends = [2, 2, 3, 4, 4] self.strides = [1, 1, 1, 1, 1] + self.infer_flags = [1, 1, 1, 1] class TestStrideSliceOp13(TestStrideSliceOp): @@ -195,6 +210,295 @@ class TestStrideSliceOp13(TestStrideSliceOp): self.starts = [1, 0, 0, 0, 1, 2] self.ends = [2, 2, 3, 1, 2, 8] self.strides = [1, 1, 1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1, 1] + + +class TestStridedSliceOp_starts_ListTensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + + starts_tensor = [] + for index, ele in enumerate(self.starts): + starts_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor} + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts_infer, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, 0, 2] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 1] + self.infer_flags = [1, -1, 1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + self.starts_infer = [1, 10, 2] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Input'], 'Out', max_relative_error=0.006) + + +class TestStridedSliceOp_ends_ListTensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + + ends_tensor = [] + for index, ele in enumerate(self.ends): + ends_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = {'Input': self.input, 'EndsTensorList': ends_tensor} + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends_infer, + 'strides': self.strides, + 'infer_flags': self.infer_flags + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, 0, 0] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 2] + self.infer_flags = [1, -1, 1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + self.ends_infer = [3, 1, 4] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Input'], 'Out', max_relative_error=0.006) + + +class TestStridedSliceOp_starts_Tensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + self.inputs = { + 'Input': self.input, + "StartsTensor": np.array( + self.starts, dtype="int32") + } + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + #'starts': self.starts, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 1] + self.infer_flags = [-1, -1, -1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Input'], 'Out', max_relative_error=0.006) + + +class TestStridedSliceOp_ends_Tensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + self.inputs = { + 'Input': self.input, + "EndsTensor": np.array( + self.ends, dtype="int32") + } + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + #'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 1] + self.infer_flags = [-1, -1, -1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Input'], 'Out', max_relative_error=0.006) + + +class TestStridedSliceOp_listTensor_Tensor(OpTest): + def setUp(self): + self.config() + ends_tensor = [] + for index, ele in enumerate(self.ends): + ends_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.op_type = "strided_slice" + + self.inputs = { + 'Input': self.input, + "StartsTensor": np.array( + self.starts, dtype="int32"), + "EndsTensorList": ends_tensor + } + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + #'starts': self.starts, + #'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 1] + self.infer_flags = [-1, -1, -1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Input'], 'Out', max_relative_error=0.006) + + +class TestStridedSliceOp_strides_Tensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + self.inputs = { + 'Input': self.input, + "StridesTensor": np.array( + self.strides, dtype="int32") + } + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + #'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, -1, 2] + self.ends = [2, 0, 4] + self.axes = [0, 1, 2] + self.strides = [1, -1, 1] + self.infer_flags = [-1, -1, -1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Input'], 'Out', max_relative_error=0.006) + + +# Test python API +class TestSliceAPI(OpTest): + def test_1(self): + input = np.random.random([3, 4, 5, 6]).astype("float32") + minus_1 = fluid.layers.fill_constant([1], "int32", -1) + minus_3 = fluid.layers.fill_constant([1], "int32", -3) + starts = fluid.layers.data( + name='starts', shape=[3], append_batch_size=False) + ends = fluid.layers.data( + name='ends', shape=[3], append_batch_size=False) + strides = fluid.layers.data( + name='strides', shape=[3], append_batch_size=False) + + x = fluid.layers.data( + name="x", + shape=[3, 4, 5, 6], + append_batch_size=False, + dtype="float32") + + out_1 = fluid.layers.strided_slice( + x, + axes=[0, 1, 2], + starts=[-3, 0, 2], + ends=[3, 100, -1], + strides=[1, 1, 1]) + out_2 = fluid.layers.strided_slice( + x, + axes=[0, 1, 3], + starts=[minus_3, 0, 2], + ends=[3, 100, -1], + strides=[1, 1, 1]) + out_3 = fluid.layers.strided_slice( + x, + axes=[0, 1, 3], + starts=[minus_3, 0, 2], + ends=[3, 100, minus_1], + strides=[1, 1, 1]) + out_4 = fluid.layers.strided_slice( + x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides) + + out_5 = x[-3:3, 0:100, 2:-1] + out_6 = x[minus_3:3, 0:100, :, 2:-1] + out_7 = x[minus_1, 0:100, :, 2:minus_1] + + exe = fluid.Executor(place=fluid.CPUPlace()) + res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( + fluid.default_main_program(), + feed={ + "x": input, + 'starts': np.array([-3, 0, 2]).astype("int32"), + 'ends': np.array([3, 100, -1]).astype("int32"), + 'strides': np.array([1, 1, 1]).astype("int32") + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7]) + + assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :]) + assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1]) + assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1]) + assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :]) + assert np.array_equal(res_5, input[-3:3, 0:100, 2:-1, :]) + assert np.array_equal(res_6, input[-3:3, 0:100, :, 2:-1]) + assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1]) if __name__ == "__main__": -- GitLab