Unverified commit fa6c59a4 authored by Bo Liu, committed by GitHub

[NPU] Support npu kernel for StridedSlice op without grad (#34601)

Parent ac33c0ca
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/slice_op.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using Variable = framework::Variable;
using LoDTensorArray = framework::LoDTensorArray;
template <typename DeviceContext, typename T>
class StridedSliceNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<LoDTensorArray>();
PADDLE_ENFORCE_EQ(is_tensor_array, false,
platform::errors::InvalidArgument(
"Tensor array as input is not supported."));
int rank = ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
StridedSliceCompute<1>(ctx);
break;
case 2:
StridedSliceCompute<2>(ctx);
break;
case 3:
StridedSliceCompute<3>(ctx);
break;
case 4:
StridedSliceCompute<4>(ctx);
break;
case 5:
StridedSliceCompute<5>(ctx);
break;
case 6:
StridedSliceCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input is supported up to 6."));
break;
}
}
private:
template <size_t D>
void StridedSliceCompute(const framework::ExecutionContext& ctx) const {
auto place = ctx.GetPlace();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
auto in = ctx.Input<framework::Tensor>("Input");
auto out = ctx.Output<framework::Tensor>("Out");
auto in_dims = in->dims();
// list<int>
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto ends_int = ctx.Attr<std::vector<int>>("ends");
auto strides_int = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = ctx.Attr<std::vector<int>>("axes");
auto infer_flags = ctx.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
// vector<Tensor<int32>>
auto list_new_ends_tensor =
ctx.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
ctx.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
ctx.MultiInput<framework::Tensor>("StridesTensorList");
// Tensor<int32>
if (list_new_starts_tensor.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (ctx.HasInput("StartsTensor")) {
auto* starts_tensor = ctx.Input<framework::Tensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (ctx.HasInput("EndsTensor")) {
auto* ends_tensor = ctx.Input<framework::Tensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
} else if (ctx.HasInput("StridesTensor")) {
auto* strides_tensor = ctx.Input<framework::Tensor>("StridesTensor");
strides = GetDataFromTensor<int64_t>(strides_tensor);
}
// out dims calculation
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
    // Mark axes that must be reversed afterwards (an entry of reverse_vector
    // is 1 when that axis was requested with stride < 0).
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, infer_flags,
decrease_axis, starts.size());
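    // StridedSliceFunctor normalizes negative starts/ends and rewrites every
    // negative-stride range into an equivalent positive-stride slice, flagging
    // the affected axis in reverse_vector; flagged axes are restored with a
    // ReverseV2 pass after the slice below.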
// construct the starts_indices, ends_indices and strides_indices tensor for
// calling StridedSlice op
std::vector<int64_t> starts_indices_vector(D, 0);
std::vector<int64_t> ends_indices_vector(out_dims_vector.begin(),
out_dims_vector.end());
std::vector<int64_t> strides_indices_vector(D, 1);
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices_vector[axis_index] = starts[axis];
ends_indices_vector[axis_index] = ends[axis];
strides_indices_vector[axis_index] = strides[axis];
}
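    // Illustrative example: for D = 3, axes = {1}, starts = {2}, ends = {4},
    // strides = {1} on an input of shape [5, 6, 7], this yields
    // starts_indices = {0, 2, 0}, ends_indices = {5, 4, 7} and
    // strides_indices = {1, 1, 1}.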
Tensor starts_indices_tensor;
Tensor ends_indices_tensor;
Tensor strides_indices_tensor;
starts_indices_tensor.mutable_data<int64_t>({D}, place);
ends_indices_tensor.mutable_data<int64_t>({D}, place);
strides_indices_tensor.mutable_data<int64_t>({D}, place);
TensorFromVector(starts_indices_vector, ctx.device_context(),
&starts_indices_tensor);
TensorFromVector(ends_indices_vector, ctx.device_context(),
&ends_indices_tensor);
TensorFromVector(strides_indices_vector, ctx.device_context(),
&strides_indices_tensor);
auto out_dims_origin = out_dims;
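    // decrease_axis lists axes whose size-1 dimension is dropped from the
    // output, e.g. out_dims = [1, 3] with decrease_axis = [0] produces the
    // final shape [3]; dropped axes are zeroed first and filtered out below.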
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument(
"the size of decrease dimension should be 1, but received %d.",
out_dims[decrease_axis[i]]));
out_dims_origin[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims_origin.size(); ++i) {
if (out_dims_origin[i] != 0) {
new_out_shape.push_back(out_dims_origin[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims_origin = framework::make_ddim(new_out_shape);
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
out->Resize(out_dims);
out->mutable_data<T>(place);
const auto& runner = NpuOpRunner(
"StridedSlice", {*in, starts_indices_tensor, ends_indices_tensor,
strides_indices_tensor},
{*out}, {{"begin_mask", 0},
{"end_mask", 0},
{"ellipsis_mask", 0},
{"new_axis_mask", 0},
{"shrink_axis_mask", 0}});
runner.Run(stream);
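    // The slice above ran with the non-negative strides produced by
    // StridedSliceFunctor, so axes that originally had a negative stride still
    // need their element order flipped; do this with ReverseV2 on a copy of
    // the sliced result.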
if (need_reverse) {
Tensor out_tmp;
out_tmp.mutable_data<T>(out_dims, place);
TensorCopy(*out, place,
ctx.template device_context<platform::DeviceContext>(),
&out_tmp);
Tensor reverse_axis;
std::vector<int> reverse_axis_vector;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
reverse_axis_vector.push_back(axes[axis]);
}
}
reverse_axis.mutable_data<int>(
{static_cast<int>(reverse_axis_vector.size())}, place);
TensorFromVector(reverse_axis_vector, ctx.device_context(),
&reverse_axis);
const auto& runner_reverse =
NpuOpRunner("ReverseV2", {out_tmp, reverse_axis}, {*out});
runner_reverse.Run(stream);
}
if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
strided_slice,
ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, bool>,
ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, double>);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import unittest
import paddle.fluid as fluid
import paddle
paddle.enable_static()
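# The OpTest-based cases below exercise the operator through the static-graph
# executor, so switch to static mode up front.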
def strided_slice_native_forward(input, axes, starts, ends, strides):
dim = input.ndim
start = []
end = []
stride = []
for i in range(dim):
start.append(0)
end.append(input.shape[i])
stride.append(1)
for i in range(len(axes)):
start[axes[i]] = starts[i]
end[axes[i]] = ends[i]
stride[axes[i]] = strides[i]
result = {
1: lambda input, start, end, stride: input[start[0]:end[0]:stride[0]],
2: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1]],
3: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1], start[2]:end[2]:stride[2]],
4: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3]],
5: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3], start[4]:end[4]:stride[4]],
6: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \
start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3], \
start[4]:end[4]:stride[4], start[5]:end[5]:stride[5]]
}[dim](input, start, end, stride)
return result
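# For example, strided_slice_native_forward(np.arange(10), [0], [2], [7], [1])
# is equivalent to np.arange(10)[2:7:1].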
@skip_check_grad_ci(
    reason='''forward only, so there is no need to call check_grad.''')
class TestStridedSliceOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'strided_slice'
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
self.inputs = {'Input': self.input}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags
}
def test_check_output(self):
place = paddle.NPUPlace(0)
self.check_output_with_place(place)
def initTestCase(self):
self.input = np.random.rand(10)
self.axes = [0]
self.starts = [2]
self.ends = [7]
self.strides = [1]
self.infer_flags = [1]
class TestStridedSliceOp1(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(100)
self.axes = [0]
self.starts = [3]
self.ends = [8]
self.strides = [1]
self.infer_flags = [1]
class TestStridedSliceOp2(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(100)
self.axes = [0]
self.starts = [5]
self.ends = [0]
self.strides = [-1]
self.infer_flags = [1]
class TestStridedSliceOp3(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(100)
self.axes = [0]
self.starts = [-1]
self.ends = [-3]
self.strides = [-1]
self.infer_flags = [1]
class TestStridedSliceOp4(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 4, 10)
self.axes = [0, 1, 2]
self.starts = [0, -1, 0]
self.ends = [2, -3, 5]
self.strides = [1, -1, 1]
self.infer_flags = [1, 1, 1]
class TestStridedSliceOp5(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(5, 5, 5)
self.axes = [0, 1, 2]
self.starts = [1, 0, 0]
self.ends = [2, 1, 3]
self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class TestStridedSliceOp6(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(5, 5, 5)
self.axes = [0, 1, 2]
self.starts = [1, -1, 0]
self.ends = [2, -3, 3]
self.strides = [1, -1, 1]
self.infer_flags = [1, 1, 1]
class TestStridedSliceOp7(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(5, 5, 5)
self.axes = [0, 1, 2]
self.starts = [1, 0, 0]
self.ends = [2, 2, 3]
self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class TestStridedSliceOp8(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(1, 100, 1)
self.axes = [1]
self.starts = [1]
self.ends = [2]
self.strides = [1]
self.infer_flags = [1]
class TestStridedSliceOp9(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(1, 100, 1)
self.axes = [1]
self.starts = [-1]
self.ends = [-2]
self.strides = [-1]
self.infer_flags = [1]
class TestStridedSliceOp10(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(10, 10)
self.axes = [0, 1]
self.starts = [1, 0]
self.ends = [2, 2]
self.strides = [1, 1]
self.infer_flags = [1, 1]
class TestStridedSliceOp11(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 4)
self.axes = [0, 1, 2, 3]
self.starts = [1, 0, 0, 0]
self.ends = [2, 2, 3, 4]
self.strides = [1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1]
class TestStridedSliceOp12(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 4, 5)
self.axes = [0, 1, 2, 3, 4]
self.starts = [1, 0, 0, 0, 0]
self.ends = [2, 2, 3, 4, 4]
self.strides = [1, 1, 1, 1, 1]
        self.infer_flags = [1, 1, 1, 1, 1]
class TestStridedSliceOp13(TestStridedSliceOp):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 6, 7, 8)
self.axes = [0, 1, 2, 3, 4, 5]
self.starts = [1, 0, 0, 0, 1, 2]
self.ends = [2, 2, 3, 1, 2, 8]
self.strides = [1, 1, 1, 1, 1, 2]
        self.infer_flags = [1, 1, 1, 1, 1, 1]
class TestStridedSliceOpBool(TestStridedSliceOp):
def test_check_grad(self):
pass
class TestStridedSliceOpBool1D(TestStridedSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(100).astype("bool")
self.axes = [0]
self.starts = [3]
self.ends = [8]
self.strides = [1]
self.infer_flags = [1]
class TestStridedSliceOpBool2D(TestStridedSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(10, 10).astype("bool")
self.axes = [0, 1]
self.starts = [1, 0]
self.ends = [2, 2]
self.strides = [1, 1]
self.infer_flags = [1, 1]
class TestStridedSliceOpBool3D(TestStridedSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(3, 4, 10).astype("bool")
self.axes = [0, 1, 2]
self.starts = [0, -1, 0]
self.ends = [2, -3, 5]
self.strides = [1, -1, 1]
self.infer_flags = [1, 1, 1]
class TestStridedSliceOpBool4D(TestStridedSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 4).astype("bool")
self.axes = [0, 1, 2, 3]
self.starts = [1, 0, 0, 0]
self.ends = [2, 2, 3, 4]
self.strides = [1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1]
class TestStridedSliceOpBool5D(TestStridedSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 4, 5).astype("bool")
self.axes = [0, 1, 2, 3, 4]
self.starts = [1, 0, 0, 0, 0]
self.ends = [2, 2, 3, 4, 4]
self.strides = [1, 1, 1, 1, 1]
        self.infer_flags = [1, 1, 1, 1, 1]
class TestStridedSliceOpBool6D(TestStridedSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 6, 7, 8).astype("bool")
self.axes = [0, 1, 2, 3, 4, 5]
self.starts = [1, 0, 0, 0, 1, 2]
self.ends = [2, 2, 3, 1, 2, 8]
self.strides = [1, 1, 1, 1, 1, 2]
        self.infer_flags = [1, 1, 1, 1, 1, 1]
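# The *_ListTensor cases below feed every start/end value as a separate
# one-element int32 tensor through StartsTensorList/EndsTensorList; the
# matching attribute (e.g. starts_infer, which carries a placeholder 10 for
# the axis whose infer_flag is -1) is then only used for compile-time shape
# inference.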
@skip_check_grad_ci(
    reason='''forward only, so there is no need to call check_grad.''')
class TestStridedSliceOp_starts_ListTensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
starts_tensor = []
for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts_infer,
'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [1, -1, 1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
self.starts_infer = [1, 10, 2]
def test_check_output(self):
place = paddle.NPUPlace(0)
self.check_output_with_place(place)
@skip_check_grad_ci(
    reason='''forward only, so there is no need to call check_grad.''')
class TestStridedSliceOp_ends_ListTensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
ends_tensor = []
for index, ele in enumerate(self.ends):
ends_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'EndsTensorList': ends_tensor}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends_infer,
'strides': self.strides,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 0]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 2]
self.infer_flags = [1, -1, 1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
self.ends_infer = [3, 1, 4]
def test_check_output(self):
place = paddle.NPUPlace(0)
self.check_output_with_place(place)
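# The *_Tensor cases below pass all starts/ends/strides in a single int32
# tensor; with every infer_flag set to -1 the shape is resolved at runtime,
# so the corresponding attribute can be omitted (left commented out in attrs).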
@skip_check_grad_ci(
    reason='''forward only, so there is no need to call check_grad.''')
class TestStridedSliceOp_starts_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
place = paddle.NPUPlace(0)
self.check_output_with_place(place)
@skip_check_grad_ci(
    reason='''forward only, so there is no need to call check_grad.''')
class TestStridedSliceOp_ends_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"EndsTensor": np.array(
self.ends, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
#'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
place = paddle.NPUPlace(0)
self.check_output_with_place(place)
@skip_check_grad_ci(
    reason='''forward only, so there is no need to call check_grad.''')
class TestStridedSliceOp_listTensor_Tensor(OpTest):
def setUp(self):
self.config()
ends_tensor = []
for index, ele in enumerate(self.ends):
ends_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.op_type = "strided_slice"
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32"),
"EndsTensorList": ends_tensor
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
#'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
place = paddle.NPUPlace(0)
self.check_output_with_place(place)
@skip_check_grad_ci(
    reason='''forward only, so there is no need to call check_grad.''')
class TestStridedSliceOp_strides_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"StridesTensor": np.array(
self.strides, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
#'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, -1, 2]
self.ends = [2, 0, 4]
self.axes = [0, 1, 2]
self.strides = [1, -1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
place = paddle.NPUPlace(0)
self.check_output_with_place(place)
# Test the Python API
class TestStridedSliceAPI(unittest.TestCase):
def test_1(self):
input = np.random.random([3, 4, 5, 6]).astype("float64")
minus_1 = fluid.layers.fill_constant([1], "int32", -1)
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
starts = fluid.layers.data(
name='starts', shape=[3], dtype='int32', append_batch_size=False)
        ends = fluid.layers.data(
            name='ends', shape=[3], dtype='int64', append_batch_size=False)
strides = fluid.layers.data(
name='strides', shape=[3], dtype='int32', append_batch_size=False)
x = fluid.layers.data(
name="x",
shape=[3, 4, 5, 6],
append_batch_size=False,
dtype="float64")
out_1 = fluid.layers.strided_slice(
x,
axes=[0, 1, 2],
starts=[-3, 0, 2],
ends=[3, 100, -1],
strides=[1, 1, 1])
out_2 = fluid.layers.strided_slice(
x,
axes=[0, 1, 3],
starts=[minus_3, 0, 2],
ends=[3, 100, -1],
strides=[1, 1, 1])
out_3 = fluid.layers.strided_slice(
x,
axes=[0, 1, 3],
starts=[minus_3, 0, 2],
ends=[3, 100, minus_1],
strides=[1, 1, 1])
out_4 = fluid.layers.strided_slice(
x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides)
out_5 = x[-3:3, 0:100:2, -1:2:-1]
out_6 = x[minus_3:3:1, 0:100:2, :, minus_1:2:minus_1]
out_7 = x[minus_1, 0:100:2, :, -1:2:-1]
exe = fluid.Executor(place=paddle.NPUPlace(0))
res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run(
fluid.default_main_program(),
feed={
"x": input,
'starts': np.array([-3, 0, 2]).astype("int32"),
'ends': np.array([3, 2147483648, -1]).astype("int64"),
'strides': np.array([1, 1, 1]).astype("int32")
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_5, input[-3:3, 0:100:2, -1:2:-1, :])
assert np.array_equal(res_6, input[-3:3, 0:100:2, :, -1:2:-1])
assert np.array_equal(res_7, input[-1, 0:100:2, :, -1:2:-1])
def test_dygraph_op(self):
x = paddle.zeros(shape=[3, 4, 5, 6], dtype="float32")
axes = [1, 2, 3]
starts = [-3, 0, 2]
ends = [3, 2, 4]
strides_1 = [1, 1, 1]
sliced_1 = paddle.strided_slice(
x, axes=axes, starts=starts, ends=ends, strides=strides_1)
assert sliced_1.shape == (3, 2, 2, 2)
if __name__ == "__main__":
unittest.main()
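
For reference, a minimal end-to-end sketch of exercising the new kernel through the Python API (an illustrative example, assuming a PaddlePaddle build with Ascend NPU support and at least one NPU device):

import numpy as np
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[3, 4], dtype="float32")
    # Express x[1:3, 0:4:2] through the strided_slice op.
    out = paddle.strided_slice(
        x, axes=[0, 1], starts=[1, 0], ends=[3, 4], strides=[1, 2])

exe = paddle.static.Executor(paddle.NPUPlace(0))
x_np = np.arange(12).reshape(3, 4).astype("float32")
res = exe.run(main_prog, feed={"x": x_np}, fetch_list=[out])[0]
assert np.array_equal(res, x_np[1:3, 0:4:2])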