未验证 提交 dbfdefa7 编写于 作者: R RedContritio 提交者: GitHub

Fix 堆栈溢出 (stack overflow) of case10: paddle.unique (#49981)

* add axis check in UniqueRawInferMeta

* add unittest for negative axis

* simplify check for unique
上级 82edc65b
...@@ -4648,6 +4648,7 @@ void UniqueRawInferMeta(const MetaTensor& x, ...@@ -4648,6 +4648,7 @@ void UniqueRawInferMeta(const MetaTensor& x,
if (axis_value < 0) { if (axis_value < 0) {
axis_value += x.dims().size(); axis_value += x.dims().size();
} }
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
axis_value, axis_value,
x.dims().size(), x.dims().size(),
...@@ -4655,6 +4656,14 @@ void UniqueRawInferMeta(const MetaTensor& x, ...@@ -4655,6 +4656,14 @@ void UniqueRawInferMeta(const MetaTensor& x,
"the dimension size(%d) of x.", "the dimension size(%d) of x.",
axis_value, axis_value,
x.dims().size())); x.dims().size()));
PADDLE_ENFORCE_GE(
axis_value,
0,
phi::errors::InvalidArgument(
"The axis(%d) + rank(x) (%d) should be greater than or equal to 0.",
axis_value,
-x.dims().size()));
auto out_dims = x.dims(); auto out_dims = x.dims();
out_dims[axis_value] = -1; out_dims[axis_value] = -1;
out->set_dims(out_dims); out->set_dims(out_dims);
......
...@@ -190,6 +190,32 @@ class TestUniqueOpAxisNone(TestUniqueOp): ...@@ -190,6 +190,32 @@ class TestUniqueOpAxisNone(TestUniqueOp):
} }
class TestUniqueOpAxisNeg(TestUniqueOp):
def init_config(self):
self.inputs = {'X': np.random.random((6, 1, 8)).astype('float64')}
unique, indices, inverse, counts = np.unique(
self.inputs['X'],
return_index=True,
return_inverse=True,
return_counts=True,
axis=-1,
)
self.attrs = {
'dtype': int(core.VarDesc.VarType.INT32),
"return_index": True,
"return_inverse": True,
"return_counts": True,
"axis": [-1],
"is_sorted": True,
}
self.outputs = {
'Out': unique,
'Indices': indices,
"Index": inverse,
"Counts": counts,
}
class TestUniqueOpAxis1(TestUniqueOp): class TestUniqueOpAxis1(TestUniqueOp):
def init_config(self): def init_config(self):
self.inputs = {'X': np.random.random((3, 8, 8)).astype('float64')} self.inputs = {'X': np.random.random((3, 8, 8)).astype('float64')}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册