Unverified commit cefd0fb5, authored by Hongyu Liu, committed by GitHub

Fix slice op shape=-1 bug (#18107)

* fix slice op bug; test=develop

* fix variable test bug; test=develop

* remove slice while true; test=develop
Parent b3cbc5be
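The fix covers two related issues in the slice op: shape inference no longer clamps against dimensions whose size is unknown at compile time (-1), and a new decrease_axis attribute lets integer indexing drop the sliced dimension, matching numpy semantics. The snippet below is a plain-numpy illustration of the target behaviour, not PaddlePaddle code:

    import numpy as np

    x = np.random.random([3, 4, 5]).astype("float32")

    print(x[1:2, :, :].shape)  # (1, 4, 5): a slice keeps the axis
    print(x[1, :, :].shape)    # (4, 5): an integer index drops it ("decrease_axis")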
...@@ -39,21 +39,49 @@ class SliceOp : public framework::OperatorWithKernel {
    auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
    auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
    auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
    auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");

    PADDLE_ENFORCE_EQ(starts.size(), ends.size());
    PADDLE_ENFORCE_EQ(starts.size(), axes.size());
    int dim_value, start, end;
    for (size_t i = 0; i < axes.size(); ++i) {
      dim_value = out_dims[axes[i]];
      if (dim_value > 0) {
        start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
        end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
        start = std::max(start, 0);
        end = std::max(end, 0);
        // start = std::min(start, dim_value);
        end = std::min(end, dim_value);
        // start = std::min(start, end);
        PADDLE_ENFORCE_GT(end, start, "end should greater than start");
        out_dims[axes[i]] = end - start;
      }
    }

    // generate new shape
    if (decrease_axis.size() > 0) {
      std::vector<int> new_out_shape;
      for (size_t i = 0; i < decrease_axis.size(); ++i) {
        if (ctx->IsRuntime()) {
          PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
                            "decrease dim should be 1");
        }

        out_dims[decrease_axis[i]] = 0;
      }

      for (int i = 0; i < out_dims.size(); ++i) {
        if (out_dims[i] != 0) {
          new_out_shape.push_back(out_dims[i]);
        }
      }
      if (new_out_shape.size() == 0) {
        new_out_shape.push_back(1);
      }

      out_dims = framework::make_ddim(new_out_shape);
    }
    ctx->SetOutputDim("Out", out_dims);
    if (axes[0] != 0) {
      ctx->ShareLoD("Input", /*->*/ "Out");
...@@ -84,7 +112,8 @@ class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
    AddAttr<std::vector<int>>(
        "ends",
        "(list<int>) Starting indices of corresponding axis in `axes`.");
    AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
        .SetDefault({});
    AddComment(R"DOC(
Slice Operator.
......
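The InferShape change above can be summarized in plain Python. This is an illustrative sketch only; the helper name infer_slice_shape and the use of -1 to mark a compile-time-unknown dimension are assumptions made for the example. It shows the two behaviours the hunk introduces: dimensions of unknown size are skipped instead of being clamped, and decrease_axis squeezes the indexed dimensions out of the inferred shape, falling back to [1] when everything is squeezed.

    def infer_slice_shape(in_shape, axes, starts, ends, decrease_axis):
        # Illustrative Python analogue of SliceOp::InferShape (assumed helper).
        out = list(in_shape)
        for axis, s, e in zip(axes, starts, ends):
            dim = out[axis]
            if dim > 0:  # unknown (-1) dims are left untouched; this is the fix
                s = s + dim if s < 0 else s
                e = e + dim if e < 0 else e
                s = max(s, 0)
                e = min(max(e, 0), dim)
                assert e > s, "end should greater than start"
                out[axis] = e - s
        if decrease_axis:
            for axis in decrease_axis:
                out[axis] = 0                        # mark decreased axes
            out = [d for d in out if d != 0] or [1]  # squeeze; scalar -> [1]
        return out

    print(infer_slice_shape([3, 4, -1], axes=[0], starts=[1], ends=[2],
                            decrease_axis=[0]))     # [4, -1]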
...@@ -55,17 +55,45 @@ class SliceKernel : public framework::OpKernel<T> {
        *context.template device_context<DeviceContext>().eigen_device();
    auto in = context.Input<framework::Tensor>("Input");
    auto out = context.Output<framework::Tensor>("Out");
    auto out_dims = out->dims();
    auto in_dims = in->dims();

    // resize out_dims
    auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
    if (decrease_axis.size() > 0) {
      if (decrease_axis.size() == (size_t)in_dims.size()) {
        std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1);
        out->Resize(framework::make_ddim(vec_origin_out_shape));
      } else {
        std::vector<int> vec_origin_out_shape(
            out_dims.size() + decrease_axis.size(), -1);

        for (size_t i = 0; i < decrease_axis.size(); ++i) {
          vec_origin_out_shape[decrease_axis[i]] = 1;
        }

        int index = 0;
        for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) {
          if (vec_origin_out_shape[i] == -1) {
            vec_origin_out_shape[i] = out_dims[index];
            ++index;
          }
        }

        out->Resize(framework::make_ddim(vec_origin_out_shape));
      }
    }

    out->mutable_data<T>(context.GetPlace());

    auto axes = context.Attr<std::vector<int>>("axes");
    auto starts = context.Attr<std::vector<int>>("starts");

    auto new_out_dims = out->dims();
    auto offsets = Eigen::array<int, D>();
    auto extents = Eigen::array<int, D>();
    for (size_t i = 0; i < D; ++i) {
      offsets[i] = 0;
      extents[i] = new_out_dims[i];
    }
    int start;
    for (size_t i = 0; i < axes.size(); ++i) {
...@@ -81,8 +109,10 @@ class SliceKernel : public framework::OpKernel<T> {
            *in);
    auto out_t =
        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
            *out, new_out_dims);
    out_t.device(place) = in_t.slice(offsets, extents);

    out->Resize(out_dims);
  }
};
...@@ -90,9 +120,7 @@ template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    size_t rank = ctx.Input<framework::Tensor>("Input")->dims().size();
    switch (rank) {
      case 1:
        SliceCompute<1>(ctx);
...@@ -130,6 +158,32 @@ class SliceGradKernel : public framework::OpKernel<T> {
    auto axes = context.Attr<std::vector<int>>("axes");
    auto starts = context.Attr<std::vector<int>>("starts");

    auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
    if (decrease_axis.size() > 0) {
      if (decrease_axis.size() == (size_t)in_dims.size()) {
        // all dims decrease
        std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1);
        out_dims = framework::make_ddim(vec_origin_out_shape);
      } else {
        std::vector<int> vec_origin_out_shape(
            out_dims.size() + decrease_axis.size(), -1);

        for (size_t i = 0; i < decrease_axis.size(); ++i) {
          vec_origin_out_shape[decrease_axis[i]] = 1;
        }

        int index = 0;
        for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) {
          if (vec_origin_out_shape[i] == -1) {
            vec_origin_out_shape[i] = out_dims[index];
            ++index;
          }
        }

        out_dims = framework::make_ddim(vec_origin_out_shape);
      }
    }

    auto offsets = Eigen::array<int, D>();
    auto extents = Eigen::array<int, D>();
    for (size_t i = 0; i < D; ++i) {
...@@ -155,7 +209,7 @@ class SliceGradKernel : public framework::OpKernel<T> {
            *d_input);
    auto d_out_t =
        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
            *d_out, out_dims);
    d_in_t.device(place) = d_out_t.pad(paddings, 0);
  }
};
......
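In the forward kernel above, the output tensor is temporarily resized back to its full rank (re-inserting size-1 entries at the decrease_axis positions) so the rank-D Eigen slice can write into it, and is resized to the decreased shape again afterwards. A rough numpy analogue of the end result, with an assumed helper name, would be:

    import numpy as np

    def slice_with_decrease(x, axes, starts, ends, decrease_axis):
        # Assumed helper: numpy analogue of SliceKernel's overall behaviour.
        idx = [slice(None)] * x.ndim
        for axis, s, e in zip(axes, starts, ends):
            idx[axis] = slice(s, e)
        full_rank = x[tuple(idx)]           # rank preserved, as the Eigen slice sees it
        if decrease_axis:
            full_rank = np.squeeze(full_rank, axis=tuple(decrease_axis))
        return full_rank

    x = np.random.random([3, 4, 5, 6]).astype("float32")
    y = slice_with_decrease(x, axes=[0], starts=[1], ends=[2], decrease_axis=[0])
    print(y.shape)  # (4, 5, 6), the same as x[1].shape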
...@@ -822,35 +822,84 @@ class Variable(object):
        Returns:
            Sliced variable
        """
        if not isinstance(item, tuple):
            item = [item]

        decrease_axis = []
        slice_axis = []
        slice_start = []
        slice_end = []
        reverse_axis = []

        for dim, slice_item in enumerate(item):
            if isinstance(slice_item, slice):
                start = slice_item.start
                end = slice_item.stop
                step = slice_item.step if slice_item.step else 1

                assert (step == 1 or step == -1)

                if step == -1:
                    reverse_axis.append(dim)
                    assert (start is None and end is None)

                if start is None and end is None:
                    continue

                if start is None:
                    start = 0

                if end is None:
                    end = 10000000

                slice_axis.append(dim)
                slice_start.append(start)
                slice_end.append(end)
            else:
                # int
                decrease_axis.append(dim)
                slice_axis.append(dim)
                slice_start.append(slice_item)
                slice_end.append(slice_item + 1
                                 if slice_item != -1 else 10000000)

        out = self
        if len(slice_axis) > 0:
            # append slice_op here
            slice_out_var = self.block.create_var(
                name=unique_name.generate_with_ignorable_key(self.name +
                                                             "_slice"),
                dtype=self.dtype)

            self.block.append_op(
                type="slice",
                inputs={'Input': [out]},
                outputs={'Out': [slice_out_var]},
                attrs={
                    'axes': slice_axis,
                    'starts': slice_start,
                    'ends': slice_end,
                    'decrease_axis': decrease_axis
                })

            out = slice_out_var

        if len(reverse_axis) > 0:
            reverse_out_var = self.block.create_var(
                name=unique_name.generate_with_ignorable_key(self.name +
                                                             "_slice_reverse"),
                dtype=self.dtype)
            self.block.append_op(
                type="reverse",
                inputs={'X': out},
                outputs={'Out': [reverse_out_var]},
                attrs={'axis': reverse_axis})

            out = reverse_out_var

        return out
def get_all_op_protos():
......
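With the new __getitem__ path, indexing a Variable lowers to a slice op (and, for a bare ::-1 step, a reverse op). A minimal usage sketch, adapted from the updated tests further below; the shapes in the comments follow those test assertions:

    import paddle.fluid as fluid
    from paddle.fluid.framework import default_main_program

    b = default_main_program().current_block()
    w = b.create_var(dtype="float64", shape=[784, 100, 100], lod_level=0)

    nw = w[1]          # integer index: slice op with decrease_axis=[0]
    print(nw.shape)    # (100, 100)
    nw = w[:, :, :-1]  # pure slices keep the rank
    print(nw.shape)    # (784, 100, 99)
    nw = w[::-1]       # step -1 with no bounds: reverse op on axis 0
    print(nw.shape)    # (784, 100, 100)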
...@@ -46,6 +46,146 @@ class TestSliceOp(OpTest):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)

class TestSliceOp_decs_dim(OpTest):
    def setUp(self):
        self.op_type = "slice"
        self.config()
        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.out}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            'decrease_axis': self.decrease_axis,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
        self.starts = [1, 0, 2]
        self.ends = [2, 3, 4]
        self.axes = [0, 1, 2]
        self.decrease_axis = [0]
        self.out = self.input[1, 0:3, 2:4, :]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestSliceOp_decs_dim_2(OpTest):
    def setUp(self):
        self.op_type = "slice"
        self.config()
        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.out}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            'decrease_axis': self.decrease_axis,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
        self.starts = [1, 0, 2]
        self.ends = [2, 1, 4]
        self.axes = [0, 1, 2]
        self.decrease_axis = [0, 1]
        self.out = self.input[1, 0, 2:4, :]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestSliceOp_decs_dim_3(OpTest):
    def setUp(self):
        self.op_type = "slice"
        self.config()
        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.out}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            'decrease_axis': self.decrease_axis,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
        self.starts = [-1, 0, 2]
        self.ends = [1000000, 1, 4]
        self.axes = [0, 1, 2]
        self.decrease_axis = [0, 1]
        self.out = self.input[-1, 0, 2:4, :]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestSliceOp_decs_dim_5(OpTest):
    def setUp(self):
        self.op_type = "slice"
        self.config()
        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.out}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            'decrease_axis': self.decrease_axis,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
        self.starts = [-1]
        self.ends = [1000000]
        self.axes = [3]
        self.decrease_axis = [3]
        self.out = self.input[:, :, :, -1]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestSliceOp_decs_dim_6(OpTest):
    def setUp(self):
        self.op_type = "slice"
        self.config()
        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.out}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            'decrease_axis': self.decrease_axis,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
        self.starts = [0, 1, 2, 3]
        self.ends = [1, 2, 3, 4]
        self.axes = [0, 1, 2, 3]
        self.decrease_axis = [0, 1, 2, 3]
        self.out = self.input[0, 1, 2, 3:4]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)

class TestCase1(TestSliceOp):
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
......
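One subtle case in the tests above: when every axis is decreased (TestSliceOp_decs_dim_6), the op still produces a 1-D tensor of shape (1,) rather than a 0-D scalar, because of the new_out_shape.push_back(1) fallback in InferShape. The reference output in that test is built the same way, as this small numpy check illustrates:

    import numpy as np

    x = np.random.random([3, 4, 5, 6]).astype("float32")
    # compared against x[0, 1, 2, 3:4], which keeps a trailing length-1 axis
    print(x[0, 1, 2, 3:4].shape)  # (1,)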
...@@ -17,6 +17,7 @@ from __future__ import print_function

import unittest
from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import numpy as np
...@@ -67,19 +68,24 @@ class TestVariable(unittest.TestCase):
        for i in range(3):
            nw = w[i]
            self.assertEqual((100, 100), nw.shape)

        nw = w[:]
        self.assertEqual((784, 100, 100), nw.shape)

        nw = w[:, :]
        self.assertEqual((784, 100, 100), nw.shape)

        nw = w[:, :, -1]
        self.assertEqual((784, 100), nw.shape)

        nw = w[1, 1, 1]
        self.assertEqual(len(nw.shape), 1)
        self.assertEqual(nw.shape[0], 1)

        nw = w[:, :, :-1]
        self.assertEqual((784, 100, 99), nw.shape)

        self.assertEqual(0, nw.lod_level)
...@@ -94,18 +100,23 @@ class TestVariable(unittest.TestCase):
            var1 = var[0, 1, 1]
            var2 = var[1:]
            var3 = var[0:1]
            var4 = var[::-1]
            var5 = var[1, 1:, 1:]
            var_reshape = fluid.layers.reshape(var, [3, -1, 3])
            var6 = var_reshape[:, :, -1]
            var7 = var[:, :, :-1]
            var8 = var[:1, :1, :1]
            var9 = var[:-1, :-1, :-1]
            var10 = var[::-1, :1, :-1]
            var11 = var[:-1, ::-1, -1:]
            var12 = var[1:2, 2:, ::-1]
            var13 = var[2:10, 2:, -2:-1]
            var14 = var[1:-1, 0:2, ::-1]
            var15 = var[::-1, ::-1, ::-1]

            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
            y = fluid.layers.fc(input=x, size=1, act=None)
            y_1 = y[:, 0]
            feeder = fluid.DataFeeder(place=place, feed_list=[x])
            data = []
            data.append((np.random.randint(10, size=[13]).astype('float32')))
...@@ -115,28 +126,38 @@ class TestVariable(unittest.TestCase):
                feed=feeder.feed([data]),
                fetch_list=[
                    var, var1, var2, var3, var4, var5, var6,
                    var7, var8, var9, var10, var11, var12,
                    var13, var14, var15
                ])

            self.assertTrue(
                np.array_equal(local_out[1], tensor_array[0, 1, 1:2]))
            self.assertTrue(np.array_equal(local_out[2], tensor_array[1:]))
            self.assertTrue(np.array_equal(local_out[3], tensor_array[0:1]))
            self.assertTrue(np.array_equal(local_out[4], tensor_array[::-1]))
            self.assertTrue(
                np.array_equal(local_out[5], tensor_array[1, 1:, 1:]))
            self.assertTrue(
                np.array_equal(local_out[6],
                               tensor_array.reshape((3, -1, 3))[:, :, -1]))
            self.assertTrue(
                np.array_equal(local_out[7], tensor_array[:, :, :-1]))
            self.assertTrue(
                np.array_equal(local_out[8], tensor_array[:1, :1, :1]))
            self.assertTrue(
                np.array_equal(local_out[9], tensor_array[:-1, :-1, :-1]))
            self.assertTrue(
                np.array_equal(local_out[10], tensor_array[::-1, :1, :-1]))
            self.assertTrue(
                np.array_equal(local_out[11], tensor_array[:-1, ::-1, -1:]))
            self.assertTrue(
                np.array_equal(local_out[12], tensor_array[1:2, 2:, ::-1]))
            self.assertTrue(
                np.array_equal(local_out[13], tensor_array[2:10, 2:, -2:-1]))
            self.assertTrue(
                np.array_equal(local_out[14], tensor_array[1:-1, 0:2, ::-1]))
            self.assertTrue(
                np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
    def test_slice(self):
        place = fluid.CPUPlace()
...@@ -148,12 +169,10 @@ class TestVariable(unittest.TestCase):
    def _tostring(self):
        b = default_main_program().current_block()
        w = b.create_var(dtype="float64", lod_level=0)
        self.assertTrue(isinstance(str(w), str))

        if core.is_compiled_with_cuda():
            wc = b.create_var(dtype="int", lod_level=0)
            self.assertTrue(isinstance(str(wc), str))

    def test_tostring(self):
......