未验证 提交 1f0f5d3c 编写于 作者: W WeiXin 提交者: GitHub

supplement the function of slice. (#34172)

* supplement the function of slice

* edit unittest

* strided_slice_op support .

* polish error message.

* polish error message.

* polish code.

* polish unittest.

* polish code.

* polish code

* polish error message.
上级 c79fa1c3
......@@ -31,7 +31,13 @@ class StridedSliceOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice");
auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined by Runtime.
return;
}
}
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
......@@ -154,6 +160,27 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *in_var = ctx.InputVar("Input");
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
if (is_in_var_array) {
auto &tensor_array = in_var->Get<framework::LoDTensorArray>();
for (auto &tensor : tensor_array) {
if (!platform::is_cuda_pinned_place(tensor.place())) {
PADDLE_ENFORCE_EQ(
platform::is_same_place(tensor.place(),
ctx.device_context().GetPlace()),
true,
platform::errors::InvalidArgument(
"Place of context is %s. Place of input tensor is %s. They "
"are should be same, but reveived different place.",
string::to_string(ctx.device_context().GetPlace()),
string::to_string(tensor.place())));
}
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
// NOTE: cuda pinned tensor need to copy its data to target place
auto in_tensor = ctx.Input<Tensor>("Input");
if (platform::is_cuda_pinned_place(in_tensor->place())) {
......@@ -179,6 +206,14 @@ class StridedSliceOp : public framework::OperatorWithKernel {
}
};
class StridedSliceOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SetOutputType("Out", ctx->GetInputType("Input"));
ctx->SetOutputDataType("Out", ctx->GetInputDataType("Input"));
}
};
class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -259,6 +294,13 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "StridedSliceGrad");
auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined by Runtime
return;
}
}
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
......@@ -308,6 +350,16 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> {
bind->SetType("strided_slice_grad");
}
};
class StridedSliceGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SetOutputType(framework::GradVarName("Input"),
ctx->GetInputType(framework::GradVarName("Out")));
ctx->SetOutputDataType(
framework::GradVarName("Input"),
ctx->GetInputDataType(framework::GradVarName("Out")));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
"Input");
......@@ -318,9 +370,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
namespace ops = paddle::operators;
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>);
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
ops::StridedSliceOpVarTypeInference);
REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
ops::StridedSliceOpGradNoNeedBufferVarsInferer);
ops::StridedSliceOpGradNoNeedBufferVarsInferer,
ops::StridedSliceGradOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(
strided_slice,
......
......@@ -127,6 +127,9 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
if (!(ends[axis_index] == -1 &&
strides[axis_index] < 0)) { // skip None stop condition
ends[axis_index] = ends[axis_index] + axis_size;
if (ends[axis_index] < 0) {
ends[axis_index] = 0;
}
}
}
if (decrease_axis_affect) {
......@@ -136,14 +139,19 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
ends[axis_index] = starts[axis_index] + 1;
}
}
if ((starts[axis_index] < 0) && (axis_size > 0)) {
starts[axis_index] += axis_size;
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
}
if (strides[axis_index] < 0) {
reverse_axis[axis_index] = 1;
strides[axis_index] = -strides[axis_index];
if (starts[axis_index] > ends[axis_index]) {
// swap the reverse
auto end_dim = dims[axis_index] - 1 < starts[axis_index]
? dims[axis_index] - 1
: starts[axis_index];
auto end_dim = axis_size - 1 < starts[axis_index] ? axis_size - 1
: starts[axis_index];
auto offset = (end_dim - ends[axis_index]) % strides[axis_index];
offset = offset == 0 ? strides[axis_index] : offset;
......@@ -162,7 +170,11 @@ template <typename DeviceContext, typename T>
class StridedSliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Input<framework::Tensor>("Input")->dims().size();
const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<LoDTensorArray>();
int rank = is_tensor_array
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
StridedSliceCompute<1>(ctx);
......@@ -190,9 +202,17 @@ class StridedSliceKernel : public framework::OpKernel<T> {
void StridedSliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out");
auto in_dims = in->dims();
framework::DDim in_dims;
auto* input_var = context.InputVar("Input");
bool is_input_var_array = input_var->IsType<LoDTensorArray>();
if (is_input_var_array) {
const int64_t size = input_var->Get<framework::LoDTensorArray>().size();
in_dims = framework::make_ddim({size});
} else {
in_dims = context.Input<framework::Tensor>("Input")->dims();
}
auto starts_int = context.Attr<std::vector<int>>("starts");
auto ends_int = context.Attr<std::vector<int>>("ends");
......@@ -295,29 +315,97 @@ class StridedSliceKernel : public framework::OpKernel<T> {
}
}
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace());
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, out_dims);
if (need_reverse) {
framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace());
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
if (is_input_var_array) {
PADDLE_ENFORCE_EQ(
starts_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of start index should be 1, but received %d.",
starts_indices.size()));
PADDLE_ENFORCE_EQ(
ends_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of end index should be 1, but received %d.",
ends_indices.size()));
PADDLE_ENFORCE_EQ(
strides_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of stride should be 1, but received %d.",
strides_indices.size()));
auto* output_var = context.OutputVar("Out");
PADDLE_ENFORCE_EQ(
output_var->IsType<LoDTensorArray>(), true,
platform::errors::InvalidArgument(
"When the input of `strided_slice_op` is `TensorArray`. The "
"output is excepted `TensorArray` , but received %s.",
framework::ToTypeName(output_var->Type())));
PADDLE_ENFORCE_EQ(
out_dims_origin.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of Output should be 1, but received %d",
out_dims_origin.size()));
auto& in_array = input_var->Get<framework::LoDTensorArray>();
auto* out_array = context.Output<framework::LoDTensorArray>("Out");
out_array->resize(out_dims_origin[0]);
size_t const in_array_size = in_array.size();
for (size_t i = 0; i < out_array->size(); i++) {
size_t in_offset =
(starts_indices[0] % in_array_size) + i * strides_indices[0];
int64_t out_offset = i;
if (need_reverse) {
out_offset = out_array->size() - i - 1;
}
auto& in_tensor = in_array.at(in_offset);
PADDLE_ENFORCE_GT(
in_tensor.memory_size(), 0,
platform::errors::PreconditionNotMet(
"The input LoDTensorArray Input[%d] holds no memory.",
in_offset));
auto* out_tensor = &out_array->at(out_offset);
out_tensor->set_lod(in_tensor.lod());
TensorCopy(in_tensor, context.GetPlace(), out_tensor);
}
} else {
out_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
}
auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out");
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace());
auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*in);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*out, out_dims);
if (need_reverse) {
framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace());
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
} else {
out_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
}
if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin);
if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin);
}
}
}
};
......@@ -326,7 +414,11 @@ template <typename DeviceContext, typename T>
class StridedSliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
size_t rank = ctx.Input<framework::Tensor>("Input")->dims().size();
const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<LoDTensorArray>();
int rank = is_tensor_array
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
StridedSliceGradCompute<1>(ctx);
......@@ -355,17 +447,27 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto* d_input =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_out =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, d_out, static_cast<T>(0));
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
framework::DDim out_dims;
auto* out_var = context.OutputVar(framework::GradVarName("Input"));
bool is_out_var_array = out_var->IsType<LoDTensorArray>();
if (is_out_var_array) {
// Note(weixin):Since the shape of `framework::GradVarName("Input")` of
// StridedSliceGrad cannot be calculated by
// `framework::GradVarName("Output")`, the dim of "Input" is used to
// calculate the output shape. when set it to inplace OP, there may be
// some problems.
const int64_t size =
context.Input<framework::LoDTensorArray>("Input")->size();
out_dims = framework::make_ddim({size});
} else {
out_dims =
context.Output<framework::Tensor>(framework::GradVarName("Input"))
->dims();
}
auto starts_int = context.Attr<std::vector<int>>("starts");
auto ends_int = context.Attr<std::vector<int>>("ends");
......@@ -438,25 +540,121 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
break;
}
}
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out, out_dims);
if (need_reverse) {
framework::Tensor reverse_input;
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
auto reverse_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(reverse_input);
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = reverse_in_t;
if (is_out_var_array) {
PADDLE_ENFORCE_EQ(
starts_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_grad_op' is `TensorArray`, the "
"dimension of start index should be 1, but received %d.",
starts_indices.size()));
PADDLE_ENFORCE_EQ(
ends_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of end index should be 1, but received %d.",
ends_indices.size()));
PADDLE_ENFORCE_EQ(
strides_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_grad_op' is `TensorArray`, the "
"dimension of stride should be 1, but received %d.",
strides_indices.size()));
auto* d_input_var = context.InputVar(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
d_input_var->IsType<LoDTensorArray>(), true,
platform::errors::InvalidArgument(
"When the output of `strided_slice_grad_op` is "
"`TensorArray`, the input is excepted `TensorArray` , "
"but received %s.",
framework::ToTypeName(d_input_var->Type())));
PADDLE_ENFORCE_EQ(
out_dims.size(), 1,
platform::errors::InvalidArgument(
"When the output of `strided_slice_grad_op` is `TensorArray`, "
"the dimension of output should be 1, but received %d.",
out_dims.size()));
auto& d_in_array = d_input_var->Get<framework::LoDTensorArray>();
auto* d_out_array = context.Output<framework::LoDTensorArray>(
framework::GradVarName("Input"));
d_out_array->resize(out_dims[0]);
auto const d_out_array_size = d_out_array->size();
auto* input_tensor_array =
context.Input<framework::LoDTensorArray>("Input");
for (size_t j = 0; j < d_out_array_size; j++) {
auto& dim = input_tensor_array->at(j).dims();
auto* d_out_tensor = &d_out_array->at(j);
int64_t sub = j - starts_indices[0];
int64_t in_offset = sub / strides_indices[0];
if (need_reverse) {
in_offset = d_in_array.size() - in_offset - 1;
}
if ((sub % strides_indices[0] == 0) && (0 <= in_offset) &&
(static_cast<size_t>(in_offset) < d_in_array.size())) {
auto& in_tensor = d_in_array.at(in_offset);
PADDLE_ENFORCE_GT(
in_tensor.memory_size(), 0,
platform::errors::PreconditionNotMet(
"The input LoDTensorArray Input[%d] holds no memory.",
in_offset));
d_out_tensor->set_lod(in_tensor.lod());
TensorCopy(in_tensor, context.GetPlace(), d_out_tensor);
} else {
d_out_tensor->Resize(dim);
if (!d_out_tensor->IsInitialized()) {
d_out_tensor->mutable_data<T>(context.GetPlace());
}
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, d_out_tensor, static_cast<T>(0));
}
}
} else {
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = in_t;
auto* d_input =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_out =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_out->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, d_out, static_cast<T>(0));
auto in_dims = d_input->dims();
auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*d_input);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*d_out, out_dims);
if (need_reverse) {
framework::Tensor reverse_input;
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
auto reverse_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(reverse_input);
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = reverse_in_t;
} else {
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = in_t;
}
}
}
};
......
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
import paddle
from paddle.static import InputSpec
SEED = 2020
np.random.seed(SEED)
......@@ -176,6 +177,46 @@ class TestSetValueWithLayerAndSave(unittest.TestCase):
output_spec=None)
class TestSliceSupplementSpecialCase(unittest.TestCase):
# unittest for slice index which abs(step)>0. eg: x[::2]
def test_static_slice_step(self):
paddle.enable_static()
array = np.arange(4**3).reshape((4, 4, 4)).astype('int64')
x = paddle.static.data(name='x', shape=[4, 4, 4], dtype='int64')
z1 = x[::2]
z2 = x[::-2]
place = paddle.CPUPlace()
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out = exe.run(prog, feed={'x': array}, fetch_list=[z1, z2])
self.assertTrue(np.array_equal(out[0], array[::2]))
self.assertTrue(np.array_equal(out[1], array[::-2]))
def test_static_slice_step_dygraph2static(self):
paddle.disable_static()
array = np.arange(4**2 * 5).reshape((5, 4, 4)).astype('int64')
inps = paddle.to_tensor(array)
def func(inps):
return inps[::2], inps[::-2]
origin_result = func(inps)
sfunc = paddle.jit.to_static(
func, input_spec=[InputSpec(shape=[None, 4, 4])])
static_result = sfunc(inps)
self.assertTrue(
np.array_equal(origin_result[0].numpy(), static_result[0].numpy()))
self.assertTrue(
np.array_equal(origin_result[1].numpy(), static_result[1].numpy()))
class TestPaddleStridedSlice(unittest.TestCase):
def test_compare_paddle_strided_slice_with_numpy(self):
paddle.disable_static()
......@@ -202,6 +243,20 @@ class TestPaddleStridedSlice(unittest.TestCase):
np.array_equal(sl.numpy(), array[s2[0]:e2[0]:stride2[0], s2[1]:e2[
1]:stride2[1]]))
array = np.arange(6 * 7 * 8).reshape((6, 7, 8))
pt = paddle.to_tensor(array)
s2 = [7, -1]
e2 = [2, -5]
stride2 = [-2, -3]
sl = paddle.strided_slice(
pt, axes=[0, 2], starts=s2, ends=e2, strides=stride2)
array_slice = array[s2[0]:e2[0]:stride2[0], ::, s2[1]:e2[1]:stride2[1]]
self.assertTrue(
np.array_equal(sl.numpy(), array_slice),
msg="paddle.strided_slice:\n {} \n numpy slice:\n{}".format(
sl.numpy(), array_slice))
if __name__ == '__main__':
unittest.main()
......@@ -588,5 +588,331 @@ class TestStridedSliceAPI(unittest.TestCase):
self.assertFalse(y.place.is_cuda_pinned_place())
class ArrayLayer(paddle.nn.Layer):
def __init__(self, input_size=224, output_size=10, array_size=1):
super(ArrayLayer, self).__init__()
self.input_size = input_size
self.output_size = output_size
self.array_size = array_size
for i in range(self.array_size):
setattr(self,
self.create_name(i),
paddle.nn.Linear(input_size, output_size))
def create_name(self, index):
return 'linear_' + str(index)
def forward(self, inps):
array = []
for i in range(self.array_size):
linear = getattr(self, self.create_name(i))
array.append(linear(inps))
tensor_array = self.create_tensor_array(array)
tensor_array = self.array_slice(tensor_array)
array1 = paddle.concat(tensor_array)
array2 = paddle.concat(tensor_array[::-1])
return array1 + array2 * array2
def get_all_grads(self, param_name='weight'):
grads = []
for i in range(self.array_size):
linear = getattr(self, self.create_name(i))
param = getattr(linear, param_name)
g = param.grad
if g is not None:
g = g.numpy()
grads.append(g)
return grads
def clear_all_grad(self):
param_names = ['weight', 'bias']
for i in range(self.array_size):
linear = getattr(self, self.create_name(i))
for p in param_names:
param = getattr(linear, p)
param.clear_gradient()
def array_slice(self, array):
return array
def create_tensor_array(self, tensors):
tensor_array = None
for i, tensor in enumerate(tensors):
index = paddle.full(shape=[1], dtype='int64', fill_value=i)
if tensor_array is None:
tensor_array = paddle.tensor.array_write(tensor, i=index)
else:
paddle.tensor.array_write(tensor, i=index, array=tensor_array)
return tensor_array
class TestStridedSliceTensorArray(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def grad_equal(self, g1, g2):
if g1 is None:
g1 = np.zeros_like(g2)
if g2 is None:
g2 = np.zeros_like(g1)
return np.array_equal(g1, g2)
def is_grads_equal(self, g1, g2):
for i, g in enumerate(g1):
self.assertTrue(
self.grad_equal(g, g2[i]),
msg="gradient_1:\n{} \ngradient_2:\n{}".format(g, g2))
def is_grads_equal_zeros(self, grads):
for g in grads:
self.assertTrue(
self.grad_equal(np.zeros_like(g), g),
msg="The gradient should be zeros, but received \n{}".format(g))
def create_case(self, net):
inps1 = paddle.randn([1, net.input_size], dtype='float32')
inps2 = inps1.detach().clone()
l1 = net(inps1)
s1 = l1.numpy()
l1.sum().backward()
grads_dy = net.get_all_grads()
net.clear_all_grad()
grads_zeros = net.get_all_grads()
self.is_grads_equal_zeros(grads_zeros)
func = paddle.jit.to_static(net.forward)
l2 = func(inps2)
s2 = l2.numpy()
l2.sum().backward()
grads_static = net.get_all_grads()
net.clear_all_grad()
# compare result of dygraph and static
self.is_grads_equal(grads_static, grads_dy)
self.assertTrue(
np.array_equal(s1, s2),
msg="dygraph graph result:\n{} \nstatic dygraph result:\n{}".format(
l1.numpy(), l2.numpy()))
def test_strided_slice_tensor_array_cuda_pinned_place(self):
if paddle.device.is_compiled_with_cuda():
with paddle.fluid.dygraph.guard():
class Simple(paddle.nn.Layer):
def __init__(self):
super(Simple, self).__init__()
def forward(self, inps):
tensor_array = None
for i, tensor in enumerate(inps):
index = paddle.full(
shape=[1], dtype='int64', fill_value=i)
if tensor_array is None:
tensor_array = paddle.tensor.array_write(
tensor, i=index)
else:
paddle.tensor.array_write(
tensor, i=index, array=tensor_array)
array1 = paddle.concat(tensor_array)
array2 = paddle.concat(tensor_array[::-1])
return array1 + array2 * array2
net = Simple()
func = paddle.jit.to_static(net.forward)
inps1 = paddle.to_tensor(
np.random.randn(2, 10),
place=paddle.CUDAPinnedPlace(),
stop_gradient=False)
inps2 = paddle.to_tensor(
np.random.randn(2, 10),
place=paddle.CUDAPinnedPlace(),
stop_gradient=False)
self.assertTrue(inps1.place.is_cuda_pinned_place())
self.assertTrue(inps2.place.is_cuda_pinned_place())
result = func([inps1, inps2])
self.assertFalse(result.place.is_cuda_pinned_place())
def test_strided_slice_tensor_array(self):
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[::-1]
self.create_case(Net(array_size=10))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[::-2]
self.create_case(Net(input_size=112, array_size=11))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[::-3]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[1::-4]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[:7:-4]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[8:0:-4]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[8:1:-4]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[::2]
self.create_case(Net(input_size=112, array_size=11))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[::3]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[1::4]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[:8:4]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[1:8:4]
self.create_case(Net(input_size=112, array_size=9))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[8:10:4]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[3:10:4]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[2:10:4]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[3:10:3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[3:15:3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[0:15:3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-1:-5:-3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-1:-6:-3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-3:-6:-3]
self.create_case(Net(input_size=112, array_size=13))
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-5:-1:3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-6:-1:3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-6:-3:3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[0::3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-60:20:3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-3:-60:-3]
self.create_case(Net(input_size=112, array_size=13))
class Net(ArrayLayer):
def array_slice(self, tensors):
return tensors[-1:-60:-3]
if __name__ == "__main__":
unittest.main()
......@@ -144,13 +144,10 @@ def _getitem_impl_(var, item):
step = 1 if step is None else step
if start is None and end is None:
assert (step == -1)
reverse_axes.append(dim)
continue
start = 0 if start is None else start
end = MAX_INTEGER if end is None else end
if start is None:
start = 0 if step > 0 else MAX_INTEGER
if end is None:
end = MAX_INTEGER if step > 0 else -1
elif isinstance(slice_item, list):
is_bool_list = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册