未验证 提交 06c86d4c 编写于 作者: 1 123malin 提交者: GitHub

test=develop, bug fix for index_select and roll op (#25251) (#25360)

上级 bddfa218
...@@ -52,6 +52,23 @@ void IndexSelectInner(const framework::ExecutionContext& context, ...@@ -52,6 +52,23 @@ void IndexSelectInner(const framework::ExecutionContext& context,
TensorToVector(index, context.device_context(), &index_vec); TensorToVector(index, context.device_context(), &index_vec);
std::vector<T> out_vec(output->numel()); std::vector<T> out_vec(output->numel());
for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE(
index_vec[i], 0,
platform::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim], index_vec[i]));
PADDLE_ENFORCE_LT(
index_vec[i], input_dim[dim],
platform::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim], index_vec[i]));
}
VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width << "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width << "; output_width: " << output_width
......
...@@ -49,7 +49,7 @@ class TestRollOp(OpTest): ...@@ -49,7 +49,7 @@ class TestRollOp(OpTest):
class TestRollOpCase2(TestRollOp): class TestRollOpCase2(TestRollOp):
def init_dtype_type(self): def init_dtype_type(self):
self.dtype = np.float32 self.dtype = np.float32
self.x_shape = (100, 100, 5) self.x_shape = (100, 10, 5)
self.shifts = [8, -1] self.shifts = [8, -1]
self.dims = [-1, -2] self.dims = [-1, -2]
...@@ -59,7 +59,7 @@ class TestRollAPI(unittest.TestCase): ...@@ -59,7 +59,7 @@ class TestRollAPI(unittest.TestCase):
self.data_x = np.array( self.data_x = np.array(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
def test_roll_api(self): def test_roll_op_api(self):
self.input_data() self.input_data()
# case 1: # case 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册