未验证 提交 e788c7b5 编写于 作者: Z zyfncg 提交者: GitHub

Support zero value in dimension for slice (#37313)

* support zero dim for slice op

* support zero dim Tensor in set_value op

* polish some debug log
上级 de0cb386
......@@ -260,6 +260,9 @@ class SetValueKernel : public framework::OpKernel<T> {
starts_indices[axis_index] = starts[i];
ends_indices[axis_index] = ends[i];
strides_indices[axis_index] = steps[i];
if (starts[i] == ends[i]) { // slice is empty, data will not be changed
return;
}
}
out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
......
......@@ -30,12 +30,20 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
std::vector<T>* infer_flags = nullptr) {
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
T dim_value = in_dims[axis];
PADDLE_ENFORCE_LT(
axis, in_dims.size(),
platform::errors::InvalidArgument(
"The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d.",
i, axis, in_dims.size()));
if (dim_value > 0) {
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}
T dim_value = in_dims[axis];
if (dim_value > 0) {
T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
......@@ -51,7 +59,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GE(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
......@@ -63,7 +71,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GE(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
......@@ -73,6 +81,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
(*starts)[i] = start;
(*ends)[i] = end;
} else if (dim_value == 0) {
(*starts)[i] = 0;
(*ends)[i] = 0;
}
}
}
......@@ -111,20 +122,22 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
framework::DDim decreased_dims(slice_dims);
std::vector<uint8_t> decrease_flag(slice_dims.size(), 0);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
T axis = decrease_axes[i];
decrease_flag[axis] = 1;
if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(
decreased_dims[axis], 1,
platform::errors::InvalidArgument("decrease dim should be 1"));
PADDLE_ENFORCE_EQ(decreased_dims[axis], 1,
platform::errors::InvalidArgument(
"Decrease dim should be 1, but now received %d",
decreased_dims[axis]));
}
decreased_dims[axis] = 0;
}
std::vector<T> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) {
if (decrease_flag[i] == 0) {
new_shape.push_back(decreased_dims[i]);
}
}
......
......@@ -77,10 +77,9 @@ static void StridedSliceOutDims(
end_index = end_index + 1;
}
bool zero_dim_condition =
((stride_index < 0 && (start_index <= end_index)) ||
(stride_index > 0 && (start_index >= end_index)));
PADDLE_ENFORCE_EQ(zero_dim_condition, false,
bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
(stride_index > 0 && (start_index > end_index)));
PADDLE_ENFORCE_EQ(neg_dim_condition, false,
platform::errors::InvalidArgument(
"The start index and end index are invalid for their "
"corresponding stride."));
......
......@@ -127,6 +127,14 @@ class TestSetValueItemSlice4(TestSetValueApi):
self.data[0:, 1:2, :] = self.value
class TestSetValueItemSlice5(TestSetValueApi):
def _call_setitem(self, x):
x[0:, 1:1, :] = self.value
def _get_answer(self):
self.data[0:, 1:1, :] = self.value
class TestSetValueItemSliceInWhile(TestSetValueApi):
def _call_setitem(self, x):
def cond(i, x):
......
......@@ -568,10 +568,12 @@ class TestVarBase(unittest.TestCase):
var14 = var[1:-1, 0:2, ::-1]
var15 = var[::-1, ::-1, ::-1]
var16 = var[-4:4]
var17 = var[:, 0, 0:0]
var18 = var[:, 1:1:2]
vars = [
var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10,
var11, var12, var13, var14, var15, var16
var11, var12, var13, var14, var15, var16, var17, var18
]
local_out = [var.numpy() for var in vars]
......@@ -600,6 +602,8 @@ class TestVarBase(unittest.TestCase):
self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
self.assertTrue(np.array_equal(local_out[17], tensor_array[:, 0, 0:0]))
self.assertTrue(np.array_equal(local_out[18], tensor_array[:, 1:1:2]))
def _test_slice_for_tensor_attr(self):
tensor_array = np.array(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册