未验证 提交 47af618f 编写于 作者: W wangchaochaohu 提交者: GitHub

Strided slice (#19642)

* strided_slice op basic function test=develop

* test=develop rewrite and fix

* fix bug test=develop

* fix for the PADDLE_ENFORCE usage

* add some unit testw

* fix for the aip  test and copright and fix test=develop

* fix API.spec test=develop

* fix API.spec test=develop

* add axis parameter test=develop

* fix for the build error test=develop

* fix python api  test=develop

* fix the build test=develop

* fix build test=develop

* fix API spec test=develop

* test=develop add some comment and single op test

* fix API spece test=develop

* fix test=develop

* fix test=develop

* fix api test=develop

* fix api test=develop

* fix API.spec test=develop

* fix typo test=develop

* fix API.spec test=develop

* fix API typo test=develop

* fix doc and API.spec test=develop
上级 13ca364c
......@@ -244,6 +244,7 @@ paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype
paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2'))
paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310'))
paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff'))
paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', 'a2e5296d34c081f2a67890aaa5f02238'))
paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b'))
paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3'))
paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include <algorithm>
#include <memory>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class StridedSliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input (Input) of slice op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output (Out) of slice op should not be null.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(in_dims.size(), 7,
"The rank of input should be less than 7.");
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
PADDLE_ENFORCE_EQ(starts.size(), ends.size(),
"starts and ends dim size must to be same");
PADDLE_ENFORCE_EQ(ends.size(), strides.size(),
"ends and strides dim size must to be same");
PADDLE_ENFORCE_EQ(ends.size(), axes.size(),
"axes, end and start dim size must to be same");
// we need to analysis strided slice op is valid for
// the parameter that we get from python front
int stride_index, start_index, end_index;
std::vector<int> out_dims_vector(in_dims.size());
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
for (size_t i = 0; i < starts.size(); i++) {
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero");
int axes_index = axes[i];
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
int axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
}
if (start_index < 0) {
start_index = start_index + axis_size;
}
if (end_index < 0) {
end_index = end_index + axis_size;
}
if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
}
bool zero_dim_condition =
((stride_index < 0 && (start_index <= end_index)) ||
(stride_index > 0 && (start_index >= end_index)));
PADDLE_ENFORCE_EQ(zero_dim_condition, false,
"starts and end must meet requirement in different "
"stride conditiont");
int left = std::max(0, std::min(start_index, end_index));
int right = std::min(axis_size, std::max(start_index, end_index));
int step = std::abs(stride_index);
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
out_dims_vector[axes_index] = out_dims_index;
}
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("Input", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.Input<Tensor>("Input")->place());
}
};
class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "Tensor of data to extract slices from.");
AddOutput("Out", "Sliced data tensor.");
AddAttr<std::vector<int>>(
"axes", "(list<int> Axes stride from the start to the end)");
AddAttr<std::vector<int>>(
"starts", "(list<int>) start that the tensor slice start.");
AddAttr<std::vector<int>>("ends",
"(list<int>) end that the tensor slice end");
AddAttr<std::vector<int>>(
"strides", "(list<int> stride stride from the start to the end)");
AddComment(R"DOC(
Strided Slice Operator.
Instead of calling this op directly most users will want to use the
NumPy-style slicing syntax.
For Example:
data = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='int64')
y = fluid.layers.strided_slice(data, [0, 1], [1,0], [2, 3], [1, 1])
)DOC");
}
};
class StridedSliceOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc();
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetInput("Input", Input("Input"));
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
bind->SetAttrMap(Attrs());
bind->SetType("strided_slice_grad");
return std::unique_ptr<framework::OpDesc>(bind);
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
StridedSliceOpGradNoNeedBufferVarsInference, "Input");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker);
REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
ops::StridedSliceOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cstdlib>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
int* reverse_axis, const framework::DDim dims,
const size_t size) {
for (size_t axis = 0; axis < size; axis++) {
int axis_size = dims[axes[axis]];
int axis_index = axis;
if (axis_size < 0) {
starts[axis_index] = 0;
ends[axis_index] = 1;
strides[axis_index] = 1;
}
// stride must not be zero
if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size;
}
if (ends[axis_index] < 0) {
ends[axis_index] = ends[axis_index] + axis_size;
}
if (strides[axis_index] < 0) {
reverse_axis[axis_index] = 1;
strides[axis_index] = -strides[axis_index];
if (starts[axis_index] > ends[axis_index]) {
// swap the reverse
starts[axis_index] = starts[axis_index] + 1;
ends[axis_index] = ends[axis_index] + 1;
}
std::swap(starts[axis_index], ends[axis_index]);
} else {
reverse_axis[axis_index] = 0;
strides[axis_index] = strides[axis_index];
}
}
}
template <typename DeviceContext, typename T>
class StridedSliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
StridedSliceCompute<1>(ctx);
break;
case 2:
StridedSliceCompute<2>(ctx);
break;
case 3:
StridedSliceCompute<3>(ctx);
break;
case 4:
StridedSliceCompute<4>(ctx);
break;
case 5:
StridedSliceCompute<5>(ctx);
break;
case 6:
StridedSliceCompute<6>(ctx);
break;
}
}
private:
template <size_t D>
void StridedSliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out");
auto out_dims = out->dims();
auto in_dims = in->dims();
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes");
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, starts.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts[axis];
ends_indices[axis_index] = ends[axis];
strides_indices[axis_index] = strides[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace());
out->mutable_data<T>(context.GetPlace());
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in);
auto tmp_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
tmp);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, out_dims);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
}
};
template <typename DeviceContext, typename T>
class StridedSliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
size_t rank = ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
StridedSliceGradCompute<1>(ctx);
break;
case 2:
StridedSliceGradCompute<2>(ctx);
break;
case 3:
StridedSliceGradCompute<3>(ctx);
break;
case 4:
StridedSliceGradCompute<4>(ctx);
break;
case 5:
StridedSliceGradCompute<5>(ctx);
break;
case 6:
StridedSliceGradCompute<6>(ctx);
break;
}
}
private:
template <size_t D>
void StridedSliceGradCompute(
const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto* d_input =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_out =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, d_out, static_cast<T>(0));
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes");
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), out_dims, starts.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts[axis];
ends_indices[axis_index] = ends[axis];
strides_indices[axis_index] = strides[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
framework::Tensor reverse_input;
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input);
auto reverse_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
reverse_input);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out, out_dims);
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = reverse_in_t;
}
};
} // namespace operators
} // namespace paddle
......@@ -171,6 +171,7 @@ __all__ = [
'gaussian_random_batch_size_like',
'sum',
'slice',
'strided_slice',
'shape',
'rank',
'size',
......@@ -10792,6 +10793,85 @@ def slice(input, axes, starts, ends):
return out
@templatedoc()
def strided_slice(input, axes, starts, ends, strides):
"""
Strided Slice OP
The conceptualization that really helped me understand this was
that this function emulates the indexing behavior of numpy arrays.
If you're familiar with numpy arrays, you'll know that you can make
slices via input[start1:end1:step1, start2:end2:step2, ... startN:endN:stepN].
Basically, a very succinct way of writing for loops to get certain elements of the array.
strided_slice just allows you to do this fancy indexing without the syntactic sugar.
The numpy (#input[start1:end1:step1, start2:end2:step2, ... startN:endN:stepN])
example from above just becomes fluid.strided_slice(input,[0, 1, ..., N],
[start1, start2, ..., startN], [end1, end2, ..., endN], [strides1, strides2, ..., stridesN]),
the axes which controls the dimension you want to slice makes it more flexible.
.. code-block:: text
Case1:
Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1]
starts = [1, 0]
ends = [2, 3]
strides = [1, 1]
Then:
result = [ [5, 6, 7] ]
Case2:
Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1]
starts = [0, -1]
ends = [-1, 0]
strides = [1, -1]
Then:
result = [ [4, 3, 2] ]
Atrgs:
input (Varibale): the input variable.
axes(List):axis we need to slice
starts (List): the start index in axis
ends (List): the end index in axis
strides (List): the stride length when we do slice operation
Returns
out(Variable): the result by strided_slice Op
Examples:
.. code-block:: python
import paddle.fluid as fluid
starts = [1, 0, 2]
ends = [3, 3, 4]
axes = [0, 1, 2]
strides= [1, 1, 1]
input = fluid.layers.data(
name="input", shape=[3, 4, 5, 6], dtype='float32')
out = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides)
"""
helper = LayerHelper('strided_slice', **locals())
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op(
type='strided_slice',
inputs={'Input': input},
outputs={'Out': out},
attrs={
'axes': axes,
'starts': starts,
'ends': ends,
'strides': strides
})
return out
def shape(input):
"""
**Shape Layer**
......
......@@ -2101,6 +2101,17 @@ class TestBook(LayerTest):
self.assertIsNotNone(data_0)
self.assertIsNotNone(data_1)
def test_stridedslice(self):
axes = [0, 1, 2]
starts = [1, 0, 2]
ends = [3, 3, 4]
strides = [1, 1, 1]
with self.static_graph():
x = layers.data(name="x", shape=[245, 30, 30], dtype="float32")
out = layers.strided_slice(
x, axes=axes, starts=starts, ends=ends, strides=strides)
return out
def test_psroi_pool(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from op_test import OpTest
import numpy as np
import unittest
def strided_slice_native_forward(input, axes, starts, ends, strides):
dim = input.ndim
start = []
end = []
stride = []
for i in range(dim):
start.append(0)
end.append(input.shape[i])
stride.append(1)
for i in range(len(axes)):
start[axes[i]] = starts[i]
end[axes[i]] = ends[i]
stride[axes[i]] = strides[i]
result = {
1: lambda input, start, end, stride: input[start[0]:end[0]:stride[0]],
2: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1]],
3: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1], start[2]:end[2]:stride[2]],
4: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3]],
5: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3], start[4]:end[4]:stride[4]],
6: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3], \
start[4]:end[4]:stride[4], start[5]:end[5]:stride[5]]
}[dim](input, start, end, stride)
return result
class TestStrideSliceOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'strided_slice'
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
self.inputs = {'Input': self.input}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'strides': self.strides
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(set(['Input']), 'Out')
def initTestCase(self):
self.input = np.random.rand(6)
self.axes = [0]
self.starts = [-4]
self.ends = [-3]
self.strides = [1]
class TestStrideSliceOp1(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(6)
self.axes = [0]
self.starts = [3]
self.ends = [8]
self.strides = [1]
class TestStrideSliceOp2(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(6)
self.axes = [0]
self.starts = [5]
self.ends = [0]
self.strides = [-1]
class TestStrideSliceOp3(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(6)
self.axes = [0]
self.starts = [-1]
self.ends = [-3]
self.strides = [-1]
class TestStrideSliceOp4(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 4, 6)
self.axes = [0, 1, 2]
self.starts = [0, -1, 0]
self.ends = [2, -3, 5]
self.strides = [1, -1, 1]
class TestStrideSliceOp5(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3)
self.axes = [0, 1, 2]
self.starts = [1, 0, 0]
self.ends = [2, 1, 3]
self.strides = [1, 1, 1]
class TestStrideSliceOp6(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3)
self.axes = [0, 1, 2]
self.starts = [1, -1, 0]
self.ends = [2, -3, 3]
self.strides = [1, -1, 1]
class TestStrideSliceOp7(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3)
self.axes = [0, 1, 2]
self.starts = [1, 0, 0]
self.ends = [2, 2, 3]
self.strides = [1, 1, 1]
class TestStrideSliceOp8(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(1, 3, 1)
self.axes = [1]
self.starts = [1]
self.ends = [2]
self.strides = [1]
class TestStrideSliceOp9(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(1, 3, 1)
self.axes = [1]
self.starts = [-1]
self.ends = [-2]
self.strides = [-1]
class TestStrideSliceOp10(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3)
self.axes = [0, 1]
self.starts = [1, 0]
self.ends = [2, 2]
self.strides = [1, 1]
class TestStrideSliceOp11(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 4)
self.axes = [0, 1, 2, 3]
self.starts = [1, 0, 0, 0]
self.ends = [2, 2, 3, 4]
self.strides = [1, 1, 1, 2]
class TestStrideSliceOp12(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 4, 5)
self.axes = [0, 1, 2, 3, 4]
self.starts = [1, 0, 0, 0, 0]
self.ends = [2, 2, 3, 4, 4]
self.strides = [1, 1, 1, 1, 1]
class TestStrideSliceOp13(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 6, 7, 8)
self.axes = [0, 1, 2, 3, 4, 5]
self.starts = [1, 0, 0, 0, 1, 2]
self.ends = [2, 2, 3, 1, 2, 8]
self.strides = [1, 1, 1, 1, 1, 2]
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册