diff --git a/paddle/phi/kernels/gpu/histogram_kernel.cu b/paddle/phi/kernels/gpu/histogram_kernel.cu
index 111b13f11dd0ef462d0ba8cca73f14a4f5cbe46b..aa10aea35f867a1bfb8f0b2592ed43182ba380ff 100644
--- a/paddle/phi/kernels/gpu/histogram_kernel.cu
+++ b/paddle/phi/kernels/gpu/histogram_kernel.cu
@@ -18,8 +18,7 @@
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/eigen/common.h"
-#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
@@ -46,8 +45,8 @@ template <typename T, typename IndexType>
 __global__ void KernelHistogram(const T* input,
                                 const int total_elements,
                                 const int64_t nbins,
-                                const T min_value,
-                                const T max_value,
+                                const T* min_value,
+                                const T* max_value,
                                 int64_t* output) {
   extern __shared__ int64_t buf_hist[];
   for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
@@ -58,9 +57,9 @@ __global__ void KernelHistogram(const T* input,
   CUDA_KERNEL_LOOP(input_index, total_elements) {
     // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
     const auto input_value = input[input_index];
-    if (input_value >= min_value && input_value <= max_value) {
+    if (input_value >= *min_value && input_value <= *max_value) {
       const IndexType output_index =
-          GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
+          GetBin<T, IndexType>(input_value, *min_value, *max_value, nbins);
       phi::CudaAtomicAdd(&buf_hist[output_index], 1);
     }
   }
@@ -71,6 +70,60 @@ __global__ void KernelHistogram(const T* input,
   }
 }
 
+template <typename T>
+__global__ void KernelMinMax(const T* input,
+                             const int numel,
+                             const int block_num,
+                             T* min_ptr,
+                             T* max_ptr) {
+  int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t i = index;
+  T min_value = static_cast<T>(i < numel ? input[i] : input[0]);
+  T max_value = static_cast<T>(i < numel ? input[i] : input[0]);
+
+  for (; i < numel; i += blockDim.x * gridDim.x) {
+    T value = static_cast<T>(input[i]);
+    min_value = value < min_value ? value : min_value;
+    max_value = value > max_value ? value : max_value;
+  }
+  if (max_ptr && min_ptr) {
+    __syncthreads();
+    T block_min_value = phi::funcs::BlockReduceMin(min_value, FINAL_MASK);
+    T block_max_value = phi::funcs::BlockReduceMax(max_value, FINAL_MASK);
+
+    if (threadIdx.x == 0) {
+      min_ptr[blockIdx.x] = block_min_value;
+      max_ptr[blockIdx.x] = block_max_value;
+    }
+  }
+  __syncthreads();
+  if (index == 0) {
+    if (min_ptr && max_ptr) {
+      min_value = min_ptr[0];
+      max_value = max_ptr[0];
+      for (int64_t i = 1; i < block_num; i++) {
+        min_ptr[0] = min_ptr[i] < min_value ? min_ptr[i] : min_value;
+        max_ptr[0] = max_ptr[i] > max_value ? max_ptr[i] : max_value;
+      }
+      if (min_ptr[0] == max_ptr[0]) {
+        min_ptr[0] = min_ptr[0] - 1;
+        max_ptr[0] = max_ptr[0] + 1;
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void KernelMinMax(const T min_value,
+                             const T max_value,
+                             T* min_ptr,
+                             T* max_ptr) {
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    min_ptr[0] = min_value;
+    max_ptr[0] = max_value;
+  }
+}
+
 template <typename T, typename Context>
 void HistogramKernel(const Context& dev_ctx,
                      const DenseTensor& input,
@@ -93,32 +146,20 @@ void HistogramKernel(const Context& dev_ctx,
 
   T output_min = static_cast<T>(minval);
   T output_max = static_cast<T>(maxval);
-
-  if (output_min == output_max) {
-    auto input_x = phi::EigenVector<T>::Flatten(input);
-
-    DenseTensor input_min_t, input_max_t;
-    input_min_t.Resize({1});
-    input_max_t.Resize({1});
-    auto* input_min_data = dev_ctx.template Alloc<T>(&input_min_t);
-    auto* input_max_data = dev_ctx.template Alloc<T>(&input_max_t);
-    auto input_min_scala = phi::EigenScalar<T>::From(input_min_t);
-    auto input_max_scala = phi::EigenScalar<T>::From(input_max_t);
-
-    auto* place = dev_ctx.eigen_device();
-    input_min_scala.device(*place) = input_x.minimum();
-    input_max_scala.device(*place) = input_x.maximum();
-
-    DenseTensor input_min_cpu, input_max_cpu;
-    phi::Copy(dev_ctx, input_min_t, phi::CPUPlace(), true, &input_min_cpu);
-    phi::Copy(dev_ctx, input_max_t, phi::CPUPlace(), true, &input_max_cpu);
-
-    output_min = input_min_cpu.data<T>()[0];
-    output_max = input_max_cpu.data<T>()[0];
-  }
+  DenseTensor min_max;
+  int block_num = GET_BLOCKS(input_numel);
+  min_max.Resize({2 * block_num});
+  auto* min_block_ptr = dev_ctx.template Alloc<T>(&min_max);
+  auto* max_block_ptr = min_block_ptr + block_num;
   if (output_min == output_max) {
-    output_min = output_min - 1;
-    output_max = output_max + 1;
+    KernelMinMax<T><<<GET_BLOCKS(input_numel),
+                      PADDLE_CUDA_NUM_THREADS,
+                      0,
+                      dev_ctx.stream()>>>(
+        input_data, input_numel, block_num, min_block_ptr, max_block_ptr);
+  } else {
+    KernelMinMax<T><<<1, 1, 0, dev_ctx.stream()>>>(
+        output_min, output_max, min_block_ptr, max_block_ptr);
   }
 
   PADDLE_ENFORCE_EQ((std::isinf(static_cast<float>(output_min)) ||
@@ -142,7 +183,7 @@ void HistogramKernel(const Context& dev_ctx,
                         PADDLE_CUDA_NUM_THREADS,
                         nbins * sizeof(int64_t),
                         stream>>>(
-      input_data, input_numel, nbins, output_min, output_max, out_data);
+      input_data, input_numel, nbins, min_block_ptr, max_block_ptr, out_data);
 }
 
 }  // namespace phi
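
The new `KernelMinMax` follows the standard two-stage reduction pattern: each thread folds a grid-stride slice of the input into a private min/max, each block collapses those into one partial per block via `phi::funcs::BlockReduceMin`/`BlockReduceMax`, and the per-block partials are finally folded into `min_ptr[0]`/`max_ptr[0]`, which `KernelHistogram` then reads on-device, replacing the old Eigen reduction plus device-to-host copy. Below is a minimal standalone sketch of that pattern (min only); it is not Paddle code: the warp-shuffle block reduction stands in for Paddle's `BlockReduceMin`, all names are illustrative, and the final fold over the per-block partials is done on the host here rather than by thread 0 of the grid as in the patch.

```
// Standalone sketch of the grid-stride + block-reduce min pattern used by
// KernelMinMax. All names here are illustrative, not Paddle's.
#include <cstdio>
#include <cuda_runtime.h>

__inline__ __device__ float WarpReduceMin(float v) {
  // Fold the 32 lanes of a warp down to lane 0 via shuffles.
  for (int offset = 16; offset > 0; offset >>= 1)
    v = fminf(v, __shfl_down_sync(0xffffffffu, v, offset));
  return v;
}

__inline__ __device__ float BlockReduceMinSketch(float v) {
  __shared__ float warp_min[32];  // one partial per warp
  int lane = threadIdx.x & 31;
  int wid = threadIdx.x >> 5;
  v = WarpReduceMin(v);
  if (lane == 0) warp_min[wid] = v;
  __syncthreads();
  int nwarps = (blockDim.x + 31) >> 5;
  v = (lane < nwarps) ? warp_min[lane] : warp_min[0];
  if (wid == 0) v = WarpReduceMin(v);  // first warp folds the warp partials
  return v;  // result is valid in thread 0 only
}

__global__ void BlockMin(const float* input, int numel, float* min_ptr) {
  int64_t index = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;
  // Out-of-range threads are seeded with input[0], as in KernelMinMax.
  float min_value = input[index < numel ? index : 0];
  for (int64_t i = index; i < numel; i += (int64_t)blockDim.x * gridDim.x)
    min_value = fminf(min_value, input[i]);
  float block_min = BlockReduceMinSketch(min_value);
  if (threadIdx.x == 0) min_ptr[blockIdx.x] = block_min;  // one partial/block
}

int main() {
  const int numel = 1 << 20, threads = 256;
  const int blocks = (numel + threads - 1) / threads;
  float* h = new float[numel];
  for (int i = 0; i < numel; ++i) h[i] = static_cast<float>((i * 37) % 1000);
  float *d_in, *d_min;
  cudaMalloc(&d_in, numel * sizeof(float));
  cudaMalloc(&d_min, blocks * sizeof(float));
  cudaMemcpy(d_in, h, numel * sizeof(float), cudaMemcpyHostToDevice);
  BlockMin<<<blocks, threads>>>(d_in, numel, d_min);
  // Final fold over the per-block partials, done on the host in this sketch;
  // the patch instead lets thread 0 of the grid fold them into min_ptr[0].
  float* partial = new float[blocks];
  cudaMemcpy(partial, d_min, blocks * sizeof(float), cudaMemcpyDeviceToHost);
  float result = partial[0];
  for (int i = 1; i < blocks; ++i) result = fminf(result, partial[i]);
  printf("min = %f\n", result);  // expect 0 for this input
  cudaFree(d_in); cudaFree(d_min); delete[] h; delete[] partial;
  return 0;
}
```

Two design points carried over from the patch itself: the per-block partials for min and max live in one `{2 * block_num}` tensor (`max_block_ptr = min_block_ptr + block_num`), so a single allocation serves both reductions; and when the reduced min equals the max, `KernelMinMax` widens the range to `[min - 1, max + 1]` so every input value still maps to a valid bin. The second, trivial `KernelMinMax` overload exists only to place user-supplied `min`/`max` scalars into device memory, since `KernelHistogram` now takes its range by pointer in both cases.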