From 47af618f70c37e81ae5f78c1549a556be6ecd88f Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 19 Sep 2019 07:16:42 +0800 Subject: [PATCH] Strided slice (#19642) * strided_slice op basic function test=develop * test=develop rewrite and fix * fix bug test=develop * fix for the PADDLE_ENFORCE usage * add some unit testw * fix for the aip test and copright and fix test=develop * fix API.spec test=develop * fix API.spec test=develop * add axis parameter test=develop * fix for the build error test=develop * fix python api test=develop * fix the build test=develop * fix build test=develop * fix API spec test=develop * test=develop add some comment and single op test * fix API spece test=develop * fix test=develop * fix test=develop * fix api test=develop * fix api test=develop * fix API.spec test=develop * fix typo test=develop * fix API.spec test=develop * fix API typo test=develop * fix doc and API.spec test=develop --- paddle/fluid/API.spec | 1 + paddle/fluid/operators/strided_slice_op.cc | 195 +++++++++++++++ paddle/fluid/operators/strided_slice_op.cu | 30 +++ paddle/fluid/operators/strided_slice_op.h | 234 ++++++++++++++++++ python/paddle/fluid/layers/nn.py | 80 ++++++ .../fluid/tests/unittests/test_layers.py | 11 + .../tests/unittests/test_strided_slice_op.py | 201 +++++++++++++++ 7 files changed, 752 insertions(+) create mode 100644 paddle/fluid/operators/strided_slice_op.cc create mode 100644 paddle/fluid/operators/strided_slice_op.cu create mode 100644 paddle/fluid/operators/strided_slice_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_strided_slice_op.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 5252ca5d10a..f84a87ab825 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -244,6 +244,7 @@ paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2')) paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310')) paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff')) +paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', 'a2e5296d34c081f2a67890aaa5f02238')) paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b')) paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3')) paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe')) diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc new file mode 100644 index 00000000000..7b0cc432f39 --- /dev/null +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -0,0 +1,195 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/strided_slice_op.h" +#include +#include +#include + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class StridedSliceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, + "Input (Input) of slice op should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output (Out) of slice op should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_LT(in_dims.size(), 7, + "The rank of input should be less than 7."); + auto starts = ctx->Attrs().Get>("starts"); + auto ends = ctx->Attrs().Get>("ends"); + auto strides = ctx->Attrs().Get>("strides"); + auto axes = ctx->Attrs().Get>("axes"); + + PADDLE_ENFORCE_EQ(starts.size(), ends.size(), + "starts and ends dim size must to be same"); + PADDLE_ENFORCE_EQ(ends.size(), strides.size(), + "ends and strides dim size must to be same"); + PADDLE_ENFORCE_EQ(ends.size(), axes.size(), + "axes, end and start dim size must to be same"); + + // we need to analysis strided slice op is valid for + // the parameter that we get from python front + int stride_index, start_index, end_index; + std::vector out_dims_vector(in_dims.size()); + for (int i = 0; i < in_dims.size(); i++) { + out_dims_vector[i] = in_dims[i]; + } + for (size_t i = 0; i < starts.size(); i++) { + PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero"); + int axes_index = axes[i]; + start_index = starts[i]; + end_index = ends[i]; + stride_index = strides[i]; + int axis_size = in_dims[axes_index]; + if (axis_size < 0) { + continue; + } + + if (start_index < 0) { + start_index = start_index + axis_size; + } + if (end_index < 0) { + end_index = end_index + axis_size; + } + + if (stride_index < 0) { + start_index = start_index + 1; + end_index = end_index + 1; + } + + bool zero_dim_condition = + ((stride_index < 0 && (start_index <= end_index)) || + (stride_index > 0 && (start_index >= end_index))); + PADDLE_ENFORCE_EQ(zero_dim_condition, false, + "starts and end must meet requirement in different " + "stride conditiont"); + int left = std::max(0, std::min(start_index, end_index)); + int right = std::min(axis_size, std::max(start_index, end_index)); + int step = std::abs(stride_index); + auto out_dims_index = (std::abs(right - left) + step - 1) / step; + + out_dims_vector[axes_index] = out_dims_index; + } + framework::DDim out_dims(framework::make_ddim(out_dims_vector)); + + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("Input", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.Input("Input")->place()); + } +}; + +class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", "Tensor of data to extract slices from."); + AddOutput("Out", "Sliced data tensor."); + + AddAttr>( + "axes", "(list Axes stride from the start to the end)"); + AddAttr>( + "starts", "(list) start that the tensor slice start."); + AddAttr>("ends", + "(list) end that the tensor slice end"); + AddAttr>( + "strides", "(list stride stride from the start to the end)"); + AddComment(R"DOC( +Strided Slice Operator. +Instead of calling this op directly most users will want to use the +NumPy-style slicing syntax. +For Example: +data = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='int64') +y = fluid.layers.strided_slice(data, [0, 1], [1,0], [2, 3], [1, 1]) +)DOC"); + } +}; + +class StridedSliceOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null"); + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("Input"); + auto x_grad_name = framework::GradVarName("Input"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.GetPlace()); + } +}; + +class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* bind = new framework::OpDesc(); + bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + bind->SetInput("Input", Input("Input")); + bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); + bind->SetAttrMap(Attrs()); + bind->SetType("strided_slice_grad"); + return std::unique_ptr(bind); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE( + StridedSliceOpGradNoNeedBufferVarsInference, "Input"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker, + ops::StridedSliceOpGradMaker); +REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad, + ops::StridedSliceOpGradNoNeedBufferVarsInference); + +REGISTER_OP_CPU_KERNEL( + strided_slice, + ops::StridedSliceKernel, + ops::StridedSliceKernel, + ops::StridedSliceKernel, + ops::StridedSliceKernel); + +REGISTER_OP_CPU_KERNEL( + strided_slice_grad, + ops::StridedSliceGradKernel, + ops::StridedSliceGradKernel, + ops::StridedSliceGradKernel, + ops::StridedSliceGradKernel); diff --git a/paddle/fluid/operators/strided_slice_op.cu b/paddle/fluid/operators/strided_slice_op.cu new file mode 100644 index 00000000000..f0c9d557b9a --- /dev/null +++ b/paddle/fluid/operators/strided_slice_op.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/strided_slice_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + strided_slice, + ops::StridedSliceKernel, + ops::StridedSliceKernel, + ops::StridedSliceKernel, + ops::StridedSliceKernel); + +REGISTER_OP_CUDA_KERNEL( + strided_slice_grad, + ops::StridedSliceGradKernel, + ops::StridedSliceGradKernel, + ops::StridedSliceGradKernel, + ops::StridedSliceGradKernel); diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h new file mode 100644 index 00000000000..ac396869003 --- /dev/null +++ b/paddle/fluid/operators/strided_slice_op.h @@ -0,0 +1,234 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" +namespace paddle { +namespace operators { + +static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, + int* reverse_axis, const framework::DDim dims, + const size_t size) { + for (size_t axis = 0; axis < size; axis++) { + int axis_size = dims[axes[axis]]; + int axis_index = axis; + if (axis_size < 0) { + starts[axis_index] = 0; + ends[axis_index] = 1; + strides[axis_index] = 1; + } + // stride must not be zero + if (starts[axis_index] < 0) { + starts[axis_index] = starts[axis_index] + axis_size; + } + + if (ends[axis_index] < 0) { + ends[axis_index] = ends[axis_index] + axis_size; + } + if (strides[axis_index] < 0) { + reverse_axis[axis_index] = 1; + strides[axis_index] = -strides[axis_index]; + if (starts[axis_index] > ends[axis_index]) { + // swap the reverse + starts[axis_index] = starts[axis_index] + 1; + ends[axis_index] = ends[axis_index] + 1; + } + std::swap(starts[axis_index], ends[axis_index]); + } else { + reverse_axis[axis_index] = 0; + strides[axis_index] = strides[axis_index]; + } + } +} + +template +class StridedSliceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int rank = ctx.Input("Input")->dims().size(); + switch (rank) { + case 1: + StridedSliceCompute<1>(ctx); + break; + case 2: + StridedSliceCompute<2>(ctx); + break; + case 3: + StridedSliceCompute<3>(ctx); + break; + case 4: + StridedSliceCompute<4>(ctx); + break; + case 5: + StridedSliceCompute<5>(ctx); + break; + case 6: + StridedSliceCompute<6>(ctx); + break; + } + } + + private: + template + void StridedSliceCompute(const framework::ExecutionContext& context) const { + auto& place = + *context.template device_context().eigen_device(); + auto in = context.Input("Input"); + auto out = context.Output("Out"); + auto out_dims = out->dims(); + auto in_dims = in->dims(); + + auto starts = context.Attr>("starts"); + auto ends = context.Attr>("ends"); + auto strides = context.Attr>("strides"); + auto axes = context.Attr>("axes"); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + auto reverse_axis = Eigen::array(); + + std::vector reverse_vector(starts.size(), 0); + StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), + reverse_vector.data(), in_dims, starts.size()); + + for (size_t axis = 0; axis < D; axis++) { + starts_indices[axis] = 0; + ends_indices[axis] = out_dims[axis]; + strides_indices[axis] = 1; + } + for (size_t axis = 0; axis < axes.size(); axis++) { + int axis_index = axes[axis]; + starts_indices[axis_index] = starts[axis]; + ends_indices[axis_index] = ends[axis]; + strides_indices[axis_index] = strides[axis]; + reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; + } + + framework::Tensor tmp; + tmp.mutable_data(out_dims, context.GetPlace()); + + out->mutable_data(context.GetPlace()); + auto in_t = + framework::EigenTensor::From( + *in); + auto tmp_t = + framework::EigenTensor::From( + tmp); + auto out_t = + framework::EigenTensor::From( + *out, out_dims); + tmp_t.device(place) = + in_t.stridedSlice(starts_indices, ends_indices, strides_indices); + out_t.device(place) = tmp_t.reverse(reverse_axis); + } +}; + +template +class StridedSliceGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + size_t rank = ctx.Input("Input")->dims().size(); + switch (rank) { + case 1: + StridedSliceGradCompute<1>(ctx); + break; + case 2: + StridedSliceGradCompute<2>(ctx); + break; + case 3: + StridedSliceGradCompute<3>(ctx); + break; + case 4: + StridedSliceGradCompute<4>(ctx); + break; + case 5: + StridedSliceGradCompute<5>(ctx); + break; + case 6: + StridedSliceGradCompute<6>(ctx); + break; + } + } + + private: + template + void StridedSliceGradCompute( + const framework::ExecutionContext& context) const { + auto& place = + *context.template device_context().eigen_device(); + auto* d_input = + context.Input(framework::GradVarName("Out")); + auto* d_out = + context.Output(framework::GradVarName("Input")); + d_out->mutable_data(context.GetPlace()); + + auto& dev_ctx = context.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, d_out, static_cast(0)); + auto out_dims = d_out->dims(); + auto in_dims = d_input->dims(); + auto starts = context.Attr>("starts"); + auto ends = context.Attr>("ends"); + auto strides = context.Attr>("strides"); + auto axes = context.Attr>("axes"); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + + auto reverse_axis = Eigen::array(); + std::vector reverse_vector(starts.size(), 0); + + StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), + reverse_vector.data(), out_dims, starts.size()); + + for (size_t axis = 0; axis < D; axis++) { + starts_indices[axis] = 0; + ends_indices[axis] = out_dims[axis]; + strides_indices[axis] = 1; + } + for (size_t axis = 0; axis < axes.size(); axis++) { + int axis_index = axes[axis]; + starts_indices[axis_index] = starts[axis]; + ends_indices[axis_index] = ends[axis]; + strides_indices[axis_index] = strides[axis]; + reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; + } + + framework::Tensor reverse_input; + reverse_input.mutable_data(in_dims, context.GetPlace()); + + auto in_t = + framework::EigenTensor::From( + *d_input); + auto reverse_in_t = + framework::EigenTensor::From( + reverse_input); + auto out_t = + framework::EigenTensor::From( + *d_out, out_dims); + + reverse_in_t.device(place) = in_t.reverse(reverse_axis); + out_t.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(place) = reverse_in_t; + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a4bf1378316..15ab6610c6b 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -171,6 +171,7 @@ __all__ = [ 'gaussian_random_batch_size_like', 'sum', 'slice', + 'strided_slice', 'shape', 'rank', 'size', @@ -10792,6 +10793,85 @@ def slice(input, axes, starts, ends): return out +@templatedoc() +def strided_slice(input, axes, starts, ends, strides): + """ + Strided Slice OP + + The conceptualization that really helped me understand this was + that this function emulates the indexing behavior of numpy arrays. + If you're familiar with numpy arrays, you'll know that you can make + slices via input[start1:end1:step1, start2:end2:step2, ... startN:endN:stepN]. + Basically, a very succinct way of writing for loops to get certain elements of the array. + strided_slice just allows you to do this fancy indexing without the syntactic sugar. + The numpy (#input[start1:end1:step1, start2:end2:step2, ... startN:endN:stepN]) + example from above just becomes fluid.strided_slice(input,[0, 1, ..., N], + [start1, start2, ..., startN], [end1, end2, ..., endN], [strides1, strides2, ..., stridesN]), + the axes which controls the dimension you want to slice makes it more flexible. + + .. code-block:: text + + Case1: + Given: + data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] + axes = [0, 1] + starts = [1, 0] + ends = [2, 3] + strides = [1, 1] + Then: + result = [ [5, 6, 7] ] + + Case2: + Given: + data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] + axes = [0, 1] + starts = [0, -1] + ends = [-1, 0] + strides = [1, -1] + Then: + result = [ [4, 3, 2] ] + Atrgs: + input (Varibale): the input variable. + axes(List):axis we need to slice + starts (List): the start index in axis + ends (List): the end index in axis + strides (List): the stride length when we do slice operation + Returns + out(Variable): the result by strided_slice Op + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + starts = [1, 0, 2] + ends = [3, 3, 4] + axes = [0, 1, 2] + strides= [1, 1, 1] + + input = fluid.layers.data( + name="input", shape=[3, 4, 5, 6], dtype='float32') + + out = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides) + """ + helper = LayerHelper('strided_slice', **locals()) + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype('input')) + + helper.append_op( + type='strided_slice', + inputs={'Input': input}, + outputs={'Out': out}, + attrs={ + 'axes': axes, + 'starts': starts, + 'ends': ends, + 'strides': strides + }) + + return out + + def shape(input): """ **Shape Layer** diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index ad8a42700e3..5ed50db2c43 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -2101,6 +2101,17 @@ class TestBook(LayerTest): self.assertIsNotNone(data_0) self.assertIsNotNone(data_1) + def test_stridedslice(self): + axes = [0, 1, 2] + starts = [1, 0, 2] + ends = [3, 3, 4] + strides = [1, 1, 1] + with self.static_graph(): + x = layers.data(name="x", shape=[245, 30, 30], dtype="float32") + out = layers.strided_slice( + x, axes=axes, starts=starts, ends=ends, strides=strides) + return out + def test_psroi_pool(self): # TODO(minqiyang): dygraph do not support lod now with self.static_graph(): diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py new file mode 100644 index 00000000000..d7e79a91ed7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -0,0 +1,201 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from op_test import OpTest +import numpy as np +import unittest + + +def strided_slice_native_forward(input, axes, starts, ends, strides): + dim = input.ndim + start = [] + end = [] + stride = [] + for i in range(dim): + start.append(0) + end.append(input.shape[i]) + stride.append(1) + + for i in range(len(axes)): + start[axes[i]] = starts[i] + end[axes[i]] = ends[i] + stride[axes[i]] = strides[i] + + result = { + 1: lambda input, start, end, stride: input[start[0]:end[0]:stride[0]], + 2: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1]], + 3: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1], start[2]:end[2]:stride[2]], + 4: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3]], + 5: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3], start[4]:end[4]:stride[4]], + 6: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3], \ + start[4]:end[4]:stride[4], start[5]:end[5]:stride[5]] + }[dim](input, start, end, stride) + + return result + + +class TestStrideSliceOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'strided_slice' + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + self.inputs = {'Input': self.input} + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + 'strides': self.strides + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(set(['Input']), 'Out') + + def initTestCase(self): + self.input = np.random.rand(6) + self.axes = [0] + self.starts = [-4] + self.ends = [-3] + self.strides = [1] + + +class TestStrideSliceOp1(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(6) + self.axes = [0] + self.starts = [3] + self.ends = [8] + self.strides = [1] + + +class TestStrideSliceOp2(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(6) + self.axes = [0] + self.starts = [5] + self.ends = [0] + self.strides = [-1] + + +class TestStrideSliceOp3(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(6) + self.axes = [0] + self.starts = [-1] + self.ends = [-3] + self.strides = [-1] + + +class TestStrideSliceOp4(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 4, 6) + self.axes = [0, 1, 2] + self.starts = [0, -1, 0] + self.ends = [2, -3, 5] + self.strides = [1, -1, 1] + + +class TestStrideSliceOp5(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3) + self.axes = [0, 1, 2] + self.starts = [1, 0, 0] + self.ends = [2, 1, 3] + self.strides = [1, 1, 1] + + +class TestStrideSliceOp6(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3) + self.axes = [0, 1, 2] + self.starts = [1, -1, 0] + self.ends = [2, -3, 3] + self.strides = [1, -1, 1] + + +class TestStrideSliceOp7(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3) + self.axes = [0, 1, 2] + self.starts = [1, 0, 0] + self.ends = [2, 2, 3] + self.strides = [1, 1, 1] + + +class TestStrideSliceOp8(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(1, 3, 1) + self.axes = [1] + self.starts = [1] + self.ends = [2] + self.strides = [1] + + +class TestStrideSliceOp9(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(1, 3, 1) + self.axes = [1] + self.starts = [-1] + self.ends = [-2] + self.strides = [-1] + + +class TestStrideSliceOp10(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3) + self.axes = [0, 1] + self.starts = [1, 0] + self.ends = [2, 2] + self.strides = [1, 1] + + +class TestStrideSliceOp11(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 4) + self.axes = [0, 1, 2, 3] + self.starts = [1, 0, 0, 0] + self.ends = [2, 2, 3, 4] + self.strides = [1, 1, 1, 2] + + +class TestStrideSliceOp12(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 4, 5) + self.axes = [0, 1, 2, 3, 4] + self.starts = [1, 0, 0, 0, 0] + self.ends = [2, 2, 3, 4, 4] + self.strides = [1, 1, 1, 1, 1] + + +class TestStrideSliceOp13(TestStrideSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 6, 7, 8) + self.axes = [0, 1, 2, 3, 4, 5] + self.starts = [1, 0, 0, 0, 1, 2] + self.ends = [2, 2, 3, 1, 2, 8] + self.strides = [1, 1, 1, 1, 1, 2] + + +if __name__ == "__main__": + unittest.main() -- GitLab