未验证 提交 7e940b84 编写于 作者: Z zyfncg 提交者: GitHub

[cherry-pick] Fix bug of strided_slice and slice (#43388, #43443) (#43432)

* fix bug of strided_slice (#43388)

* fix stride_slice bug

* fix bug

* fix bug of infer shape for slice (#43443)
上级 53a7d38b
...@@ -101,7 +101,11 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -101,7 +101,11 @@ class SliceOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes.")); "The size of ends must be equal to the size of axes."));
} }
for (auto &axis : axes) {
if (axis < 0) {
axis = std::max(0, axis + in_dims.size());
}
}
phi::funcs::CheckAndUpdateSliceAttrs<int>(in_dims, axes, &starts, &ends, phi::funcs::CheckAndUpdateSliceAttrs<int>(in_dims, axes, &starts, &ends,
nullptr, &infer_flags); nullptr, &infer_flags);
......
...@@ -74,10 +74,14 @@ static void StridedSliceOutDims(const std::vector<int64_t>& starts, ...@@ -74,10 +74,14 @@ static void StridedSliceOutDims(const std::vector<int64_t>& starts,
if (start_index < 0) { if (start_index < 0) {
start_index = start_index + axis_size; start_index = start_index + axis_size;
start_index = std::max<int64_t>(start_index, 0);
} }
if (end_index < 0) { if (end_index < 0) {
if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition
end_index = end_index + axis_size; end_index = end_index + axis_size;
if (end_index < 0) {
end_index = 0;
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册