未验证 提交 0687bcd6 编写于 作者: W wangchaochaohu 提交者: GitHub

Refine getitem of Variable (#20729)

* add support for __get_item__ of Variable test=develop
上级 72d1d72c
......@@ -42,6 +42,7 @@ class StridedSliceOp : public framework::OperatorWithKernel {
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
auto starts_size = starts.size();
auto ends_size = ends.size();
......@@ -90,10 +91,32 @@ class StridedSliceOp : public framework::OperatorWithKernel {
std::vector<int> out_dims_vector(in_dims.size(), -1);
if (!tensor_input) {
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
out_dims_vector.data(), axes.size(), true);
decrease_axis, out_dims_vector.data(), axes.size(),
true);
}
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
"decrease dim should be 1");
}
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = framework::make_ddim(new_out_shape);
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("Input", /*->*/ "Out");
}
......@@ -177,6 +200,8 @@ class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>(
"infer_flags", "(list<int>) Flags of inferring dims in attributes.")
.SetDefault({});
AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
.SetDefault({});
AddComment(R"DOC(
Strided Slice Operator.
Instead of calling this op directly most users will want to use the
......@@ -212,10 +237,12 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor") {
if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
var_name == "StridesTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList") {
if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StridesTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
......
......@@ -27,22 +27,34 @@ static void StridedSliceOutDims(
const std::vector<int>& starts, const std::vector<int>& ends,
const std::vector<int>& strides, const std::vector<int>& axes,
const std::vector<int>& infer_flags, const framework::DDim in_dims,
int* out_dims_vector, const size_t size, bool infer_shape) {
const std::vector<int>& decrease_axis, int* out_dims_vector,
const size_t size, bool infer_shape) {
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
int stride_index, start_index, end_index;
for (size_t i = 0; i < size; i++) {
int axes_index = axes[i];
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
bool decrease_axis_affect = false;
if (start_index == -1 && end_index == 0 && infer_flags[i] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
}
}
if (decrease_axis_affect) {
out_dims_vector[axes_index] = 1;
continue;
}
if (infer_shape && infer_flags[i] == -1) {
out_dims_vector[axes_index] = -1;
continue;
}
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero");
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
PADDLE_ENFORCE_NE(stride_index, 0, "stride must not to be zero");
int axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
......@@ -77,6 +89,8 @@ static void StridedSliceOutDims(
static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
int* reverse_axis, const framework::DDim dims,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
const size_t size) {
for (size_t axis = 0; axis < size; axis++) {
int axis_size = dims[axes[axis]];
......@@ -86,6 +100,15 @@ static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
ends[axis_index] = 1;
strides[axis_index] = 1;
}
bool decrease_axis_affect = false;
if (starts[axis_index] == -1 && ends[axis_index] == 0 &&
infer_flags[axis_index] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(),
axes[axis_index]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
}
}
// stride must not be zero
if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size;
......@@ -94,6 +117,13 @@ static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
if (ends[axis_index] < 0) {
ends[axis_index] = ends[axis_index] + axis_size;
}
if (decrease_axis_affect) {
if (strides[axis_index] < 0) {
ends[axis_index] = starts[axis_index] - 1;
} else {
ends[axis_index] = starts[axis_index] + 1;
}
}
if (strides[axis_index] < 0) {
reverse_axis[axis_index] = 1;
strides[axis_index] = -strides[axis_index];
......@@ -151,6 +181,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
......@@ -187,12 +218,14 @@ class StridedSliceKernel : public framework::OpKernel<T> {
std::vector<int> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
out_dims_vector.data(), axes.size(), false);
decrease_axis, out_dims_vector.data(), axes.size(),
false);
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, starts.size());
reverse_vector.data(), in_dims, infer_flags,
decrease_axis, starts.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
......@@ -209,8 +242,28 @@ class StridedSliceKernel : public framework::OpKernel<T> {
}
framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace());
auto out_dims_origin = out_dims;
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
"decrease dim should be 1");
out_dims_origin[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims_origin.size(); ++i) {
if (out_dims_origin[i] != 0) {
new_out_shape.push_back(out_dims_origin[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims_origin = framework::make_ddim(new_out_shape);
}
tmp.mutable_data<T>(out_dims, context.GetPlace());
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace());
auto in_t =
......@@ -225,6 +278,10 @@ class StridedSliceKernel : public framework::OpKernel<T> {
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin);
}
}
};
......@@ -276,6 +333,8 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
......@@ -313,7 +372,8 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), out_dims, starts.size());
reverse_vector.data(), out_dims, infer_flags,
decrease_axis, starts.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
......
......@@ -1414,9 +1414,11 @@ class Variable(object):
slice_axis = []
slice_start = []
slice_end = []
slice_step = []
use_strided_slice = False
reverse_axis = []
def fill_constant(shape, dtype, value, force_cpu=False, out=None):
def fill_constant(shape, value, force_cpu=False, out=None):
self.block.append_op(
type='fill_constant',
inputs={},
......@@ -1425,7 +1427,7 @@ class Variable(object):
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu or force_init_on_cpu()
'force_cpu': force_cpu
},
stop_gradient=True)
out.stop_gradient = True
......@@ -1435,15 +1437,17 @@ class Variable(object):
if isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step if slice_item.step else 1
step = slice_item.step
assert (step == 1 or step == -1)
if start is None and end is None and step is None:
continue
if step == -1:
reverse_axis.append(dim)
assert (start is None and end is None)
if step is None:
step = 1
if start is None and end is None:
assert (step == -1)
reverse_axis.append(dim)
continue
if start is None:
......@@ -1452,16 +1456,21 @@ class Variable(object):
if end is None:
end = 10000000
if step != 1:
use_strided_slice = True
slice_axis.append(dim)
slice_start.append(start)
slice_end.append(end)
slice_step.append(step)
else:
decrease_axis.append(dim)
slice_axis.append(dim)
slice_start.append(slice_item)
slice_step.append(1)
if isinstance(slice_item, Variable):
temp_1 = self.block.create_var(dtype='int32')
fill_constant([1], 'int32', 1, force_cpu=True, out=temp_1)
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = self.block.create_var(dtype='int32')
self.block.append_op(
type='elementwise_add',
......@@ -1489,8 +1498,7 @@ class Variable(object):
else:
assert (isinstance(dim, int))
temp_out = self.block.create_var(dtype='int32')
fill_constant(
[1], 'int32', dim, force_cpu=True, out=temp_out)
fill_constant([1], dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
......@@ -1501,8 +1509,9 @@ class Variable(object):
'ends': [],
'decrease_axis': decrease_axis
}
if (use_strided_slice == True):
attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis)))
# starts
if not contain_var(slice_start):
attrs['starts'] = slice_start
......@@ -1525,11 +1534,23 @@ class Variable(object):
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
# strides
if use_strided_slice == True:
if not contain_var(slice_step):
attrs['strides'] = slice_step
else:
inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
for i, dim in enumerate(slice_step):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
# infer_flags
attrs['infer_flags'] = infer_flags
out = self
if len(slice_axis) > 0:
if use_strided_slice == False and len(slice_axis) > 0:
# append slice_op here
slice_out_var = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name +
......@@ -1543,6 +1564,18 @@ class Variable(object):
attrs=attrs)
out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0:
strided_slice_out_var = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name +
"_strided_slice"),
dtype=self.dtype)
self.block.append_op(
type="strided_slice",
inputs=inputs,
outputs={'Out': [strided_slice_out_var]},
attrs=attrs)
out = strided_slice_out_var
if len(reverse_axis) > 0:
reverse_out_var = self.block.create_var(
......
......@@ -438,7 +438,7 @@ class TestStridedSliceOp_strides_Tensor(OpTest):
# Test python API
class TestSliceAPI(OpTest):
class TestStridedSliceAPI(OpTest):
def test_1(self):
input = np.random.random([3, 4, 5, 6]).astype("float32")
minus_1 = fluid.layers.fill_constant([1], "int32", -1)
......@@ -455,7 +455,6 @@ class TestSliceAPI(OpTest):
shape=[3, 4, 5, 6],
append_batch_size=False,
dtype="float32")
out_1 = fluid.layers.strided_slice(
x,
axes=[0, 1, 2],
......@@ -477,9 +476,9 @@ class TestSliceAPI(OpTest):
out_4 = fluid.layers.strided_slice(
x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides)
out_5 = x[-3:3, 0:100, 2:-1]
out_6 = x[minus_3:3, 0:100, :, 2:-1]
out_7 = x[minus_1, 0:100, :, 2:minus_1]
out_5 = x[-3:3, 0:100:2, -1:2:-1]
out_6 = x[minus_3:3:1, 0:100:2, :, minus_1:2:minus_1]
out_7 = x[minus_1, 0:100:2, :, -1:2:-1]
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run(
......@@ -491,14 +490,13 @@ class TestSliceAPI(OpTest):
'strides': np.array([1, 1, 1]).astype("int32")
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_5, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_6, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1])
assert np.array_equal(res_5, input[-3:3, 0:100:2, -1:2:-1, :])
assert np.array_equal(res_6, input[-3:3, 0:100:2, :, -1:2:-1])
assert np.array_equal(res_7, input[-1, 0:100:2, :, -1:2:-1])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册