未验证 提交 dc1b6511 编写于 作者: R RedContritio 提交者: GitHub

support empty input for unique_consecutive (#49978)

上级 111075a3
...@@ -51,9 +51,11 @@ static void UniqueConsecutiveFlattenedTensor(const Context& context, ...@@ -51,9 +51,11 @@ static void UniqueConsecutiveFlattenedTensor(const Context& context,
} }
} }
int64_t output_size = p - out_vec.data() + 1; bool is_empty = in.numel() == 0;
int64_t output_size = is_empty ? 0 : (p - out_vec.data() + 1);
if (return_counts) { if (return_counts) {
*q = in.numel() - last; if (!is_empty) *q = in.numel() - last;
counts_vec.resize(output_size); counts_vec.resize(output_size);
} }
out_vec.resize(output_size); out_vec.resize(output_size);
......
...@@ -32,12 +32,14 @@ def reference_unique_consecutive(X, return_inverse=False, return_counts=False): ...@@ -32,12 +32,14 @@ def reference_unique_consecutive(X, return_inverse=False, return_counts=False):
return_counts(bool, optional): If True, also return the counts for each unique consecutive element. return_counts(bool, optional): If True, also return the counts for each unique consecutive element.
""" """
X = list(X) X = list(X)
is_empty = len(X) == 0
counts_vec = [1] * len(X) counts_vec = [1] * len(X)
i = 0 i = 0
counts = 1 counts = 1
last = 0 last = 0
inverse_vec = [0] * len(X) inverse_vec = [0] * len(X)
inverse_vec[last] = i if not is_empty:
inverse_vec[last] = i
cnt = 0 cnt = 0
while i < len(X) - 1: while i < len(X) - 1:
if X[i] == X[i + 1]: if X[i] == X[i + 1]:
...@@ -271,6 +273,40 @@ class TestUniqueConsecutiveCase2API(unittest.TestCase): ...@@ -271,6 +273,40 @@ class TestUniqueConsecutiveCase2API(unittest.TestCase):
) )
class TestUniqueConsecutiveEmptyInput(OpTest):
"""empty input"""
def config(self):
self.return_inverse = True
self.return_counts = True
self.python_api = paddle.unique_consecutive
def init_kernel_type(self):
self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
def setUp(self):
self.init_kernel_type()
self.config()
self.op_type = "unique_consecutive"
x = np.array([]).astype(self.dtype)
result = reference_unique_consecutive(
x, self.return_inverse, self.return_counts
)
out = reference_unique_consecutive(x)
out = np.array(out).astype(self.dtype)
self.inputs = {
'X': x,
}
self.python_out_sig = ["Out"]
self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)}
self.outputs = {
'Out': out,
}
def test_check_output(self):
self.check_output(check_eager=True)
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册