diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index b6bbb071acc9ad8cda52f55dc9ac7700eba8a34c..7c81d71562f01840c82171daf53acfa80d8a438e 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -42,6 +42,7 @@ class StridedSliceOp : public framework::OperatorWithKernel { auto strides = ctx->Attrs().Get>("strides"); auto axes = ctx->Attrs().Get>("axes"); auto infer_flags = ctx->Attrs().Get>("infer_flags"); + auto decrease_axis = ctx->Attrs().Get>("decrease_axis"); auto starts_size = starts.size(); auto ends_size = ends.size(); @@ -90,10 +91,32 @@ class StridedSliceOp : public framework::OperatorWithKernel { std::vector out_dims_vector(in_dims.size(), -1); if (!tensor_input) { StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, - out_dims_vector.data(), axes.size(), true); + decrease_axis, out_dims_vector.data(), axes.size(), + true); } framework::DDim out_dims(framework::make_ddim(out_dims_vector)); + // generate new shape + if (decrease_axis.size() > 0) { + std::vector new_out_shape; + for (size_t i = 0; i < decrease_axis.size(); ++i) { + if (ctx->IsRuntime() && infer_flags[i] != -1) { + PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1, + "decrease dim should be 1"); + } + out_dims[decrease_axis[i]] = 0; + } + for (int i = 0; i < out_dims.size(); ++i) { + if (out_dims[i] != 0) { + new_out_shape.push_back(out_dims[i]); + } + } + if (new_out_shape.size() == 0) { + new_out_shape.push_back(1); + } + + out_dims = framework::make_ddim(new_out_shape); + } ctx->SetOutputDim("Out", out_dims); ctx->ShareLoD("Input", /*->*/ "Out"); } @@ -177,6 +200,8 @@ class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>( "infer_flags", "(list) Flags of inferring dims in attributes.") .SetDefault({}); + AddAttr>("decrease_axis", "(list) decrease_axis") + .SetDefault({}); AddComment(R"DOC( Strided Slice Operator. Instead of calling this op directly most users will want to use the @@ -212,10 +237,12 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { - if (var_name == "StartsTensor" || var_name == "EndsTensor") { + if (var_name == "StartsTensor" || var_name == "EndsTensor" || + var_name == "StridesTensor") { return expected_kernel_type; } - if (var_name == "StartsTensorList" || var_name == "EndsTensorList") { + if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || + var_name == "StridesTensorList") { return expected_kernel_type; } return framework::OpKernelType(expected_kernel_type.data_type_, diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index 57d33f29d80fedc6bac06b60708268e41c725d26..5baacc7ea1350b9f5b7ba81ff5a31c3e75c46853 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -27,22 +27,34 @@ static void StridedSliceOutDims( const std::vector& starts, const std::vector& ends, const std::vector& strides, const std::vector& axes, const std::vector& infer_flags, const framework::DDim in_dims, - int* out_dims_vector, const size_t size, bool infer_shape) { + const std::vector& decrease_axis, int* out_dims_vector, + const size_t size, bool infer_shape) { for (int i = 0; i < in_dims.size(); i++) { out_dims_vector[i] = in_dims[i]; } int stride_index, start_index, end_index; for (size_t i = 0; i < size; i++) { int axes_index = axes[i]; + start_index = starts[i]; + end_index = ends[i]; + stride_index = strides[i]; + bool decrease_axis_affect = false; + if (start_index == -1 && end_index == 0 && infer_flags[i] == -1) { + auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); + if (ret != decrease_axis.end()) { + decrease_axis_affect = true; + } + } + if (decrease_axis_affect) { + out_dims_vector[axes_index] = 1; + continue; + } if (infer_shape && infer_flags[i] == -1) { out_dims_vector[axes_index] = -1; continue; } - PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero"); - start_index = starts[i]; - end_index = ends[i]; - stride_index = strides[i]; + PADDLE_ENFORCE_NE(stride_index, 0, "stride must not to be zero"); int axis_size = in_dims[axes_index]; if (axis_size < 0) { continue; @@ -77,6 +89,8 @@ static void StridedSliceOutDims( static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, int* reverse_axis, const framework::DDim dims, + const std::vector& infer_flags, + const std::vector& decrease_axis, const size_t size) { for (size_t axis = 0; axis < size; axis++) { int axis_size = dims[axes[axis]]; @@ -86,6 +100,15 @@ static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, ends[axis_index] = 1; strides[axis_index] = 1; } + bool decrease_axis_affect = false; + if (starts[axis_index] == -1 && ends[axis_index] == 0 && + infer_flags[axis_index] == -1) { + auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), + axes[axis_index]); + if (ret != decrease_axis.end()) { + decrease_axis_affect = true; + } + } // stride must not be zero if (starts[axis_index] < 0) { starts[axis_index] = starts[axis_index] + axis_size; @@ -94,6 +117,13 @@ static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, if (ends[axis_index] < 0) { ends[axis_index] = ends[axis_index] + axis_size; } + if (decrease_axis_affect) { + if (strides[axis_index] < 0) { + ends[axis_index] = starts[axis_index] - 1; + } else { + ends[axis_index] = starts[axis_index] + 1; + } + } if (strides[axis_index] < 0) { reverse_axis[axis_index] = 1; strides[axis_index] = -strides[axis_index]; @@ -151,6 +181,7 @@ class StridedSliceKernel : public framework::OpKernel { auto strides = context.Attr>("strides"); auto axes = context.Attr>("axes"); auto infer_flags = context.Attr>("infer_flags"); + auto decrease_axis = context.Attr>("decrease_axis"); auto starts_indices = Eigen::DSizes(); auto ends_indices = Eigen::DSizes(); @@ -187,12 +218,14 @@ class StridedSliceKernel : public framework::OpKernel { std::vector out_dims_vector(in_dims.size(), -1); StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, - out_dims_vector.data(), axes.size(), false); + decrease_axis, out_dims_vector.data(), axes.size(), + false); framework::DDim out_dims(framework::make_ddim(out_dims_vector)); std::vector reverse_vector(starts.size(), 0); StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), - reverse_vector.data(), in_dims, starts.size()); + reverse_vector.data(), in_dims, infer_flags, + decrease_axis, starts.size()); for (size_t axis = 0; axis < D; axis++) { starts_indices[axis] = 0; @@ -209,8 +242,28 @@ class StridedSliceKernel : public framework::OpKernel { } framework::Tensor tmp; - tmp.mutable_data(out_dims, context.GetPlace()); + auto out_dims_origin = out_dims; + if (decrease_axis.size() > 0) { + std::vector new_out_shape; + for (size_t i = 0; i < decrease_axis.size(); ++i) { + PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1, + "decrease dim should be 1"); + out_dims_origin[decrease_axis[i]] = 0; + } + + for (int i = 0; i < out_dims_origin.size(); ++i) { + if (out_dims_origin[i] != 0) { + new_out_shape.push_back(out_dims_origin[i]); + } + } + if (new_out_shape.size() == 0) { + new_out_shape.push_back(1); + } + out_dims_origin = framework::make_ddim(new_out_shape); + } + + tmp.mutable_data(out_dims, context.GetPlace()); out->Resize(out_dims); out->mutable_data(context.GetPlace()); auto in_t = @@ -225,6 +278,10 @@ class StridedSliceKernel : public framework::OpKernel { tmp_t.device(place) = in_t.stridedSlice(starts_indices, ends_indices, strides_indices); out_t.device(place) = tmp_t.reverse(reverse_axis); + + if (decrease_axis.size() > 0) { + out->Resize(out_dims_origin); + } } }; @@ -276,6 +333,8 @@ class StridedSliceGradKernel : public framework::OpKernel { auto ends = context.Attr>("ends"); auto strides = context.Attr>("strides"); auto axes = context.Attr>("axes"); + auto infer_flags = context.Attr>("infer_flags"); + auto decrease_axis = context.Attr>("decrease_axis"); auto list_new_ends_tensor = context.MultiInput("EndsTensorList"); @@ -313,7 +372,8 @@ class StridedSliceGradKernel : public framework::OpKernel { std::vector reverse_vector(starts.size(), 0); StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), - reverse_vector.data(), out_dims, starts.size()); + reverse_vector.data(), out_dims, infer_flags, + decrease_axis, starts.size()); for (size_t axis = 0; axis < D; axis++) { starts_indices[axis] = 0; diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4191ea6ea0fb389133b9db334bc55ff2c776a817..5a68fe449ce54ff506c8833642dfe0599b168823 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1414,9 +1414,11 @@ class Variable(object): slice_axis = [] slice_start = [] slice_end = [] + slice_step = [] + use_strided_slice = False reverse_axis = [] - def fill_constant(shape, dtype, value, force_cpu=False, out=None): + def fill_constant(shape, value, force_cpu=False, out=None): self.block.append_op( type='fill_constant', inputs={}, @@ -1425,7 +1427,7 @@ class Variable(object): 'shape': shape, 'dtype': out.dtype, 'value': float(value), - 'force_cpu': force_cpu or force_init_on_cpu() + 'force_cpu': force_cpu }, stop_gradient=True) out.stop_gradient = True @@ -1435,15 +1437,17 @@ class Variable(object): if isinstance(slice_item, slice): start = slice_item.start end = slice_item.stop - step = slice_item.step if slice_item.step else 1 + step = slice_item.step - assert (step == 1 or step == -1) + if start is None and end is None and step is None: + continue - if step == -1: - reverse_axis.append(dim) - assert (start is None and end is None) + if step is None: + step = 1 if start is None and end is None: + assert (step == -1) + reverse_axis.append(dim) continue if start is None: @@ -1452,16 +1456,21 @@ class Variable(object): if end is None: end = 10000000 + if step != 1: + use_strided_slice = True + slice_axis.append(dim) slice_start.append(start) slice_end.append(end) + slice_step.append(step) else: decrease_axis.append(dim) slice_axis.append(dim) slice_start.append(slice_item) + slice_step.append(1) if isinstance(slice_item, Variable): temp_1 = self.block.create_var(dtype='int32') - fill_constant([1], 'int32', 1, force_cpu=True, out=temp_1) + fill_constant([1], 1, force_cpu=True, out=temp_1) temp_end = self.block.create_var(dtype='int32') self.block.append_op( type='elementwise_add', @@ -1489,8 +1498,7 @@ class Variable(object): else: assert (isinstance(dim, int)) temp_out = self.block.create_var(dtype='int32') - fill_constant( - [1], 'int32', dim, force_cpu=True, out=temp_out) + fill_constant([1], dim, force_cpu=True, out=temp_out) new_list_tensor.append(temp_out) return new_list_tensor @@ -1501,8 +1509,9 @@ class Variable(object): 'ends': [], 'decrease_axis': decrease_axis } + if (use_strided_slice == True): + attrs['strides'] = [] infer_flags = list(1 for i in range(len(slice_axis))) - # starts if not contain_var(slice_start): attrs['starts'] = slice_start @@ -1525,11 +1534,23 @@ class Variable(object): infer_flags[i] = -1 else: attrs['ends'].append(dim) + # strides + if use_strided_slice == True: + if not contain_var(slice_step): + attrs['strides'] = slice_step + else: + inputs['StridesTensorList'] = get_new_list_tensor(slice_step) + for i, dim in enumerate(slice_step): + if isinstance(dim, Variable): + attrs['strides'].append(-1) + infer_flags[i] = -1 + else: + attrs['strides'].append(dim) # infer_flags attrs['infer_flags'] = infer_flags out = self - if len(slice_axis) > 0: + if use_strided_slice == False and len(slice_axis) > 0: # append slice_op here slice_out_var = self.block.create_var( name=unique_name.generate_with_ignorable_key(self.name + @@ -1543,6 +1564,18 @@ class Variable(object): attrs=attrs) out = slice_out_var + elif use_strided_slice == True and len(slice_axis) > 0: + strided_slice_out_var = self.block.create_var( + name=unique_name.generate_with_ignorable_key(self.name + + "_strided_slice"), + dtype=self.dtype) + self.block.append_op( + type="strided_slice", + inputs=inputs, + outputs={'Out': [strided_slice_out_var]}, + attrs=attrs) + + out = strided_slice_out_var if len(reverse_axis) > 0: reverse_out_var = self.block.create_var( diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index bb327a8bd7fe04aa1b6dde2ba3ed8fe03cfc854d..8012f2e80048cf772f23a51e49f59d2fed379b8b 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -438,7 +438,7 @@ class TestStridedSliceOp_strides_Tensor(OpTest): # Test python API -class TestSliceAPI(OpTest): +class TestStridedSliceAPI(OpTest): def test_1(self): input = np.random.random([3, 4, 5, 6]).astype("float32") minus_1 = fluid.layers.fill_constant([1], "int32", -1) @@ -455,7 +455,6 @@ class TestSliceAPI(OpTest): shape=[3, 4, 5, 6], append_batch_size=False, dtype="float32") - out_1 = fluid.layers.strided_slice( x, axes=[0, 1, 2], @@ -477,9 +476,9 @@ class TestSliceAPI(OpTest): out_4 = fluid.layers.strided_slice( x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides) - out_5 = x[-3:3, 0:100, 2:-1] - out_6 = x[minus_3:3, 0:100, :, 2:-1] - out_7 = x[minus_1, 0:100, :, 2:minus_1] + out_5 = x[-3:3, 0:100:2, -1:2:-1] + out_6 = x[minus_3:3:1, 0:100:2, :, minus_1:2:minus_1] + out_7 = x[minus_1, 0:100:2, :, -1:2:-1] exe = fluid.Executor(place=fluid.CPUPlace()) res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( @@ -491,14 +490,13 @@ class TestSliceAPI(OpTest): 'strides': np.array([1, 1, 1]).astype("int32") }, fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7]) - assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :]) assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1]) assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1]) assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :]) - assert np.array_equal(res_5, input[-3:3, 0:100, 2:-1, :]) - assert np.array_equal(res_6, input[-3:3, 0:100, :, 2:-1]) - assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1]) + assert np.array_equal(res_5, input[-3:3, 0:100:2, -1:2:-1, :]) + assert np.array_equal(res_6, input[-3:3, 0:100:2, :, -1:2:-1]) + assert np.array_equal(res_7, input[-1, 0:100:2, :, -1:2:-1]) if __name__ == "__main__":