未验证 提交 9ce8cfcf 编写于 作者: R RedContritio 提交者: GitHub

Fix UFA非法地址访问(UFA illegal address access) of case4: paddle.unbind (#49995)

* add axis check for unbind

* add axis range check for unbind

* update unittest and axis validation for unbind

* add unittest invalid axis for unbind

* restore axis extract for unbind
上级 7f1a1570
......@@ -4253,7 +4253,20 @@ void UnbindInferMeta(const MetaTensor& x,
std::vector<MetaTensor*> outs) {
auto in_dims = x.dims();
std::vector<int> out_dim;
PADDLE_ENFORCE_GE(
axis,
-in_dims.size(),
phi::errors::InvalidArgument(
"axis must be in range(%d, %d).", -in_dims.size(), in_dims.size()));
PADDLE_ENFORCE_LT(
axis,
in_dims.size(),
phi::errors::InvalidArgument(
"axis must be in range(%d, %d).", -in_dims.size(), in_dims.size()));
axis = axis < 0 ? in_dims.size() + axis : axis;
for (int i = 0; i < in_dims.size(); ++i) {
if (i != axis) out_dim.push_back(in_dims[i]);
}
......
......@@ -25,6 +25,7 @@ from paddle.fluid import Program, program_guard
class TestUnbind(unittest.TestCase):
def test_unbind(self):
paddle.enable_static()
x_1 = fluid.data(shape=[2, 3], dtype='float32', name='x_1')
[out_0, out_1] = tensor.unbind(input=x_1, axis=0)
......@@ -59,6 +60,7 @@ class TestUnbind(unittest.TestCase):
class TestLayersUnbind(unittest.TestCase):
def test_layers_unbind(self):
paddle.enable_static()
x_1 = fluid.data(shape=[2, 3], dtype='float32', name='x_1')
[out_0, out_1] = paddle.unbind(input=x_1, axis=0)
......@@ -214,6 +216,11 @@ class TestUnbindAxisError(unittest.TestCase):
self.assertRaises(TypeError, test_table_Variable)
def test_invalid_axis():
tensor.unbind(input=x, axis=2)
self.assertRaises(ValueError, test_invalid_axis)
if __name__ == '__main__':
unittest.main()
......@@ -2755,14 +2755,19 @@ def unbind(input, axis=0):
# x2.shape [3, 5]
# x3.shape [3, 5]
"""
if not isinstance(axis, (int)):
raise TypeError(
"The type of 'axis' must be int, but received %s." % (type(axis))
)
if axis not in range(-input.ndim, input.ndim):
raise ValueError(
f'The axis must in range({-input.ndim}, {input.ndim}).'
)
if in_dygraph_mode():
return _C_ops.unbind(input, axis)
else:
if not isinstance(axis, (int)):
raise TypeError(
"The type of 'axis' must be int, but received %s."
% (type(axis))
)
if isinstance(axis, np.generic):
axis = np.asscalar(axis)
input_shape = input.shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册