未验证 提交 e788c7b5 编写于 作者: Z zyfncg 提交者: GitHub

Support zero value in dimension for slice (#37313)

* support zero dim for slice op

* support zero dim Tensor in set_value op

* polish some debug log
上级 de0cb386
...@@ -260,6 +260,9 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -260,6 +260,9 @@ class SetValueKernel : public framework::OpKernel<T> {
starts_indices[axis_index] = starts[i]; starts_indices[axis_index] = starts[i];
ends_indices[axis_index] = ends[i]; ends_indices[axis_index] = ends[i];
strides_indices[axis_index] = steps[i]; strides_indices[axis_index] = steps[i];
if (starts[i] == ends[i]) { // slice is empty, data will not be changed
return;
}
} }
out_e.stridedSlice(starts_indices, ends_indices, strides_indices) out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
......
...@@ -30,12 +30,20 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -30,12 +30,20 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
std::vector<T>* infer_flags = nullptr) { std::vector<T>* infer_flags = nullptr) {
for (size_t i = 0; i < axes.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i]; T axis = axes[i];
PADDLE_ENFORCE_LT(
axis, in_dims.size(),
platform::errors::InvalidArgument(
"The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d.",
i, axis, in_dims.size()));
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}
T dim_value = in_dims[axis]; T dim_value = in_dims[axis];
if (dim_value > 0) { if (dim_value > 0) {
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
}
T step = steps == nullptr ? 1 : (*steps)[i]; T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument( step, 0, platform::errors::InvalidArgument(
...@@ -51,7 +59,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -51,7 +59,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
if (step > 0) { if (step > 0) {
start = std::min(start, dim_value); start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0)); end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GE(
end, start, end, start,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but " "When step > 0, end should be greater than start, but "
...@@ -63,7 +71,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -63,7 +71,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
// "end is -1" means contain the 0-th element of this axis. // "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1); start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1)); end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GE(
start, end, start, end,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but " "When step < 0, start should be greater than end, but "
...@@ -73,6 +81,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -73,6 +81,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
(*starts)[i] = start; (*starts)[i] = start;
(*ends)[i] = end; (*ends)[i] = end;
} else if (dim_value == 0) {
(*starts)[i] = 0;
(*ends)[i] = 0;
} }
} }
} }
...@@ -111,20 +122,22 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, ...@@ -111,20 +122,22 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
const std::vector<T>& decrease_axes, const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) { std::vector<T>* infer_flags = nullptr) {
framework::DDim decreased_dims(slice_dims); framework::DDim decreased_dims(slice_dims);
std::vector<uint8_t> decrease_flag(slice_dims.size(), 0);
if (decrease_axes.size() > 0) { if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) { for (size_t i = 0; i < decrease_axes.size(); ++i) {
T axis = decrease_axes[i]; T axis = decrease_axes[i];
decrease_flag[axis] = 1;
if (infer_flags && (*infer_flags)[i] != -1) { if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(decreased_dims[axis], 1,
decreased_dims[axis], 1, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("decrease dim should be 1")); "Decrease dim should be 1, but now received %d",
decreased_dims[axis]));
} }
decreased_dims[axis] = 0;
} }
std::vector<T> new_shape; std::vector<T> new_shape;
for (int i = 0; i < decreased_dims.size(); ++i) { for (int i = 0; i < decreased_dims.size(); ++i) {
if (decreased_dims[i] != 0) { if (decrease_flag[i] == 0) {
new_shape.push_back(decreased_dims[i]); new_shape.push_back(decreased_dims[i]);
} }
} }
......
...@@ -77,10 +77,9 @@ static void StridedSliceOutDims( ...@@ -77,10 +77,9 @@ static void StridedSliceOutDims(
end_index = end_index + 1; end_index = end_index + 1;
} }
bool zero_dim_condition = bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
((stride_index < 0 && (start_index <= end_index)) || (stride_index > 0 && (start_index > end_index)));
(stride_index > 0 && (start_index >= end_index))); PADDLE_ENFORCE_EQ(neg_dim_condition, false,
PADDLE_ENFORCE_EQ(zero_dim_condition, false,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The start index and end index are invalid for their " "The start index and end index are invalid for their "
"corresponding stride.")); "corresponding stride."));
......
...@@ -127,6 +127,14 @@ class TestSetValueItemSlice4(TestSetValueApi): ...@@ -127,6 +127,14 @@ class TestSetValueItemSlice4(TestSetValueApi):
self.data[0:, 1:2, :] = self.value self.data[0:, 1:2, :] = self.value
class TestSetValueItemSlice5(TestSetValueApi):
def _call_setitem(self, x):
x[0:, 1:1, :] = self.value
def _get_answer(self):
self.data[0:, 1:1, :] = self.value
class TestSetValueItemSliceInWhile(TestSetValueApi): class TestSetValueItemSliceInWhile(TestSetValueApi):
def _call_setitem(self, x): def _call_setitem(self, x):
def cond(i, x): def cond(i, x):
......
...@@ -568,10 +568,12 @@ class TestVarBase(unittest.TestCase): ...@@ -568,10 +568,12 @@ class TestVarBase(unittest.TestCase):
var14 = var[1:-1, 0:2, ::-1] var14 = var[1:-1, 0:2, ::-1]
var15 = var[::-1, ::-1, ::-1] var15 = var[::-1, ::-1, ::-1]
var16 = var[-4:4] var16 = var[-4:4]
var17 = var[:, 0, 0:0]
var18 = var[:, 1:1:2]
vars = [ vars = [
var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10, var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10,
var11, var12, var13, var14, var15, var16 var11, var12, var13, var14, var15, var16, var17, var18
] ]
local_out = [var.numpy() for var in vars] local_out = [var.numpy() for var in vars]
...@@ -600,6 +602,8 @@ class TestVarBase(unittest.TestCase): ...@@ -600,6 +602,8 @@ class TestVarBase(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4])) self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
self.assertTrue(np.array_equal(local_out[17], tensor_array[:, 0, 0:0]))
self.assertTrue(np.array_equal(local_out[18], tensor_array[:, 1:1:2]))
def _test_slice_for_tensor_attr(self): def _test_slice_for_tensor_attr(self):
tensor_array = np.array( tensor_array = np.array(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册