未验证 提交 b47fb764 编写于 作者: T TeslaZhao 提交者: GitHub

Keep strided_slice op behavior consistent with slice op when starts input is...

Keep strided_slice op behavior consistent with slice op when starts input is less than -rank (#39066)
上级 fdab43b5
......@@ -121,6 +121,7 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
// stride must not be zero
if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size;
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
}
if (ends[axis_index] < 0) {
if (!(ends[axis_index] == -1 &&
......@@ -139,11 +140,6 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
}
}
if ((starts[axis_index] < 0) && (axis_size > 0)) {
starts[axis_index] += axis_size;
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
}
if (strides[axis_index] < 0) {
reverse_axis[axis_index] = 1;
strides[axis_index] = -strides[axis_index];
......
......@@ -216,6 +216,16 @@ class TestStrideSliceOp13(TestStrideSliceOp):
self.infer_flags = [1, 1, 1, 1, 1]
class TestStrideSliceOp14(TestStrideSliceOp):
def initTestCase(self):
self.input = np.random.rand(4, 4, 4, 4)
self.axes = [1, 2, 3]
self.starts = [-5, 0, -7]
self.ends = [-1, 2, 4]
self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOpBool(TestStrideSliceOp):
def test_check_grad(self):
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册