未验证 提交 ab786715 编写于 作者: S sprouteer 提交者: GitHub

fix unique_kernel support axis=-1 (#49385)

上级 da357615
...@@ -96,6 +96,7 @@ void UniqueRawKernel(const Context& context, ...@@ -96,6 +96,7 @@ void UniqueRawKernel(const Context& context,
return_counts)); return_counts));
} else { } else {
int axis_value = axis[0]; int axis_value = axis[0];
axis_value = (axis_value == -1) ? (x.dims().size() - 1) : axis_value;
phi::VisitDataTypeTiny( phi::VisitDataTypeTiny(
dtype, dtype,
phi::funcs::UniqueDimFunctor<Context, T>(context, phi::funcs::UniqueDimFunctor<Context, T>(context,
......
...@@ -563,6 +563,7 @@ void UniqueRawKernel(const Context& context, ...@@ -563,6 +563,7 @@ void UniqueRawKernel(const Context& context,
} else { } else {
// 'axis' is required. // 'axis' is required.
int axis_value = axis[0]; int axis_value = axis[0];
axis_value = (axis_value == -1) ? (x.dims().size() - 1) : axis_value;
phi::VisitDataTypeTiny(dtype, phi::VisitDataTypeTiny(dtype,
UniqueDimsCUDAFunctor<Context, T>(context, UniqueDimsCUDAFunctor<Context, T>(context,
x, x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册