未验证 提交 5fe3da39 编写于 作者: L liym27 提交者: GitHub

[cherry-pick 2.0] Fix bug: In dynamic mode, if start or end is negetive,...

[cherry-pick 2.0] Fix bug: In dynamic mode, if start or end is negetive, __getitem__  return wrong result(#30003) (#30146)

1. when slice_item is a slice:
 1) the start of __getitem__ should be std::max(start, 0) if slice
 2) the start of __getitem__ should be std::min(end, dim)
2. when slice_item is an integer, it should be in [-dim_len, dim_len)
3. Fix error message to use accurate data
上级 f46ddc0e
...@@ -121,8 +121,11 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -121,8 +121,11 @@ class SliceOp : public framework::OperatorWithKernel {
start = std::max(start, 0); start = std::max(start, 0);
end = std::max(end, 0); end = std::max(end, 0);
end = std::min(end, dim_value); end = std::min(end, dim_value);
PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(end, start,
"end should greater than start")); platform::errors::InvalidArgument(
"end should greater than start, but received "
"end = %d, start = %d.",
ends[i], starts[i]));
out_dims[axes[i]] = end - start; out_dims[axes[i]] = end - start;
} }
} }
......
...@@ -122,8 +122,8 @@ class SliceKernel : public framework::OpKernel<T> { ...@@ -122,8 +122,8 @@ class SliceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GT(end, start, PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in " "Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d.", "slice op. But received end = %d, start = %d.",
end, start)); ends[0], starts[0]));
int64_t out_size = end - start; int64_t out_size = end - start;
if (out_is_tensor_array) { if (out_is_tensor_array) {
...@@ -181,8 +181,8 @@ class SliceKernel : public framework::OpKernel<T> { ...@@ -181,8 +181,8 @@ class SliceKernel : public framework::OpKernel<T> {
end, start, end, start,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in " "Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d.", "slice op. But received end = %d, start = %d.",
end, start)); ends[i], starts[i]));
out_dims[axes[i]] = end - start; out_dims[axes[i]] = end - start;
} }
} }
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <algorithm>
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
...@@ -322,6 +323,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length, ...@@ -322,6 +323,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
std::string(Py_TYPE(r->start)->tp_name))); std::string(Py_TYPE(r->start)->tp_name)));
} }
if (*start < 0) *start += length; if (*start < 0) *start += length;
*start = std::max(*start, static_cast<Py_ssize_t>(0));
} }
if (r->stop == Py_None) { if (r->stop == Py_None) {
*stop = *step < 0 ? -1 : length; *stop = *step < 0 ? -1 : length;
...@@ -335,6 +337,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length, ...@@ -335,6 +337,7 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
std::string(Py_TYPE(r->stop)->tp_name))); std::string(Py_TYPE(r->stop)->tp_name)));
} }
if (*stop < 0) *stop += length; if (*stop < 0) *stop += length;
*stop = std::min(*stop, length);
} }
if (*stop > length) return -1; if (*stop > length) return -1;
if (*start >= length) return -1; if (*start >= length) return -1;
...@@ -380,7 +383,7 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, ...@@ -380,7 +383,7 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
int start = static_cast<int>(PyLong_AsLong(slice_item)); int start = static_cast<int>(PyLong_AsLong(slice_item));
auto s_t = start; auto s_t = start;
start = start < 0 ? start + dim_len : start; start = start < 0 ? start + dim_len : start;
if (start >= dim_len) { if (start >= dim_len || start < 0) {
std::string str_error_message = std::string str_error_message =
"The starting index " + std::to_string(s_t) + "The starting index " + std::to_string(s_t) +
" of slice is out of bounds in tensor " + std::to_string(dim) + " of slice is out of bounds in tensor " + std::to_string(dim) +
......
...@@ -413,10 +413,11 @@ class TestVarBase(unittest.TestCase): ...@@ -413,10 +413,11 @@ class TestVarBase(unittest.TestCase):
var13 = var[2:10, 2:, -2:-1] var13 = var[2:10, 2:, -2:-1]
var14 = var[1:-1, 0:2, ::-1] var14 = var[1:-1, 0:2, ::-1]
var15 = var[::-1, ::-1, ::-1] var15 = var[::-1, ::-1, ::-1]
var16 = var[-4:4]
vars = [ vars = [
var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10, var, var1, var2, var3, var4, var5, var6, var7, var8, var9, var10,
var11, var12, var13, var14, var15 var11, var12, var13, var14, var15, var16
] ]
local_out = [var.numpy() for var in vars] local_out = [var.numpy() for var in vars]
...@@ -444,6 +445,7 @@ class TestVarBase(unittest.TestCase): ...@@ -444,6 +445,7 @@ class TestVarBase(unittest.TestCase):
np.array_equal(local_out[14], tensor_array[1:-1, 0:2, ::-1])) np.array_equal(local_out[14], tensor_array[1:-1, 0:2, ::-1]))
self.assertTrue( self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4]))
def _test_for_var(self): def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32') np_value = np.random.random((30, 100, 100)).astype('float32')
...@@ -464,6 +466,9 @@ class TestVarBase(unittest.TestCase): ...@@ -464,6 +466,9 @@ class TestVarBase(unittest.TestCase):
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
y = var[self.shape[0]] y = var[self.shape[0]]
with self.assertRaises(IndexError):
y = var[0 - self.shape[0] - 1]
def test_var_base_to_np(self): def test_var_base_to_np(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册