From 06c86d4c23db42cb60b8decfa7752ee471fdddf1 Mon Sep 17 00:00:00 2001 From: 123malin Date: Fri, 3 Jul 2020 17:38:19 +0800 Subject: [PATCH] test=develop, bug fix for index_select and roll op (#25251) (#25360) --- paddle/fluid/operators/index_select_op.h | 17 +++++++++++++++++ .../fluid/tests/unittests/test_roll_op.py | 4 ++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/index_select_op.h b/paddle/fluid/operators/index_select_op.h index 96ec18d9a0d..70714b7f3a0 100644 --- a/paddle/fluid/operators/index_select_op.h +++ b/paddle/fluid/operators/index_select_op.h @@ -52,6 +52,23 @@ void IndexSelectInner(const framework::ExecutionContext& context, TensorToVector(index, context.device_context(), &index_vec); std::vector out_vec(output->numel()); + for (int i = 0; i < index_size; i++) { + PADDLE_ENFORCE_GE( + index_vec[i], 0, + platform::errors::InvalidArgument( + "Variable value (index) of OP(index_select) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + input_dim[dim], index_vec[i])); + PADDLE_ENFORCE_LT( + index_vec[i], input_dim[dim], + platform::errors::InvalidArgument( + "Variable value (index) of OP(index_select) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + input_dim[dim], index_vec[i])); + } + VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums << "; slice_size: " << slice_size << "; input_width: " << input_width << "; output_width: " << output_width diff --git a/python/paddle/fluid/tests/unittests/test_roll_op.py b/python/paddle/fluid/tests/unittests/test_roll_op.py index d03b4e273cd..a6afc64aa18 100644 --- a/python/paddle/fluid/tests/unittests/test_roll_op.py +++ b/python/paddle/fluid/tests/unittests/test_roll_op.py @@ -49,7 +49,7 @@ class TestRollOp(OpTest): class TestRollOpCase2(TestRollOp): def init_dtype_type(self): self.dtype = np.float32 - self.x_shape = (100, 100, 5) + self.x_shape = (100, 10, 5) self.shifts = [8, -1] self.dims = [-1, -2] @@ -59,7 +59,7 @@ class TestRollAPI(unittest.TestCase): self.data_x = np.array( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) - def test_roll_api(self): + def test_roll_op_api(self): self.input_data() # case 1: -- GitLab