From ab786715fb55cefeec20f7acefeeba896ef7c5e1 Mon Sep 17 00:00:00 2001 From: sprouteer <89541335+sprouteer@users.noreply.github.com> Date: Wed, 28 Dec 2022 16:27:38 +0800 Subject: [PATCH] fix unique_kernel support axis=-1 (#49385) --- paddle/phi/kernels/cpu/unique_kernel.cc | 1 + paddle/phi/kernels/gpu/unique_kernel.cu | 1 + 2 files changed, 2 insertions(+) diff --git a/paddle/phi/kernels/cpu/unique_kernel.cc b/paddle/phi/kernels/cpu/unique_kernel.cc index 834f05f73e2..15c19b24444 100644 --- a/paddle/phi/kernels/cpu/unique_kernel.cc +++ b/paddle/phi/kernels/cpu/unique_kernel.cc @@ -96,6 +96,7 @@ void UniqueRawKernel(const Context& context, return_counts)); } else { int axis_value = axis[0]; + axis_value = (axis_value == -1) ? (x.dims().size() - 1) : axis_value; phi::VisitDataTypeTiny( dtype, phi::funcs::UniqueDimFunctor(context, diff --git a/paddle/phi/kernels/gpu/unique_kernel.cu b/paddle/phi/kernels/gpu/unique_kernel.cu index 9938d949c59..316fe1fae71 100644 --- a/paddle/phi/kernels/gpu/unique_kernel.cu +++ b/paddle/phi/kernels/gpu/unique_kernel.cu @@ -563,6 +563,7 @@ void UniqueRawKernel(const Context& context, } else { // 'axis' is required. int axis_value = axis[0]; + axis_value = (axis_value == -1) ? (x.dims().size() - 1) : axis_value; phi::VisitDataTypeTiny(dtype, UniqueDimsCUDAFunctor(context, x, -- GitLab