Unverified commit c1a61fc0, authored by Zero Rains and committed by GitHub

[PaddlePaddle Hackathon 4 No.33] Optimize the GPU computational performance of the Histogram op for Paddle (#53112)

* create KernelMinMax to optimize the performance of the histogram op on GPU

* change to block- and warp-wise operations (see the sketch below)

* remove the DtoH (device-to-host) copy time

* fix a bug
Parent 22e96bde
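
The diff below replaces the Eigen-based min/max path, which had to copy the reduced values back to the host before launching the histogram kernel, with a KernelMinMax that reduces the input range block- and warp-wise and leaves the result in device memory, so KernelHistogram can read it through device pointers with no DtoH transfer. As a rough illustration of that reduction pattern, here is a minimal standalone CUDA sketch; it is not the Paddle source, it assumes a non-empty input and a block size that is a multiple of 32, only the min side is shown (max is symmetric), and warpReduceMin, blockReduceMin, and MinMaxPartialKernel are hypothetical names standing in for phi::funcs::BlockReduceMin and the PR's KernelMinMax.

```cuda
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffffu

// Shuffle-down reduction within one warp; lane 0 ends up holding the warp min.
__device__ float warpReduceMin(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    v = fminf(v, __shfl_down_sync(FULL_MASK, v, offset));
  }
  return v;
}

// Block-wide min built from warp reductions; the result is valid in thread 0.
__device__ float blockReduceMin(float v) {
  __shared__ float warp_min[32];  // one partial result per warp
  const int lane = threadIdx.x & 31;
  const int wid = threadIdx.x >> 5;
  v = warpReduceMin(v);
  if (lane == 0) warp_min[wid] = v;
  __syncthreads();
  // Warp 0 reduces the per-warp partials; surplus lanes reuse warp_min[0],
  // which is harmless for a min.
  const int nwarps = (blockDim.x + 31) >> 5;
  v = (lane < nwarps) ? warp_min[lane] : warp_min[0];
  if (wid == 0) v = warpReduceMin(v);
  return v;
}

// Each thread folds a grid-stride slice into a local min, then the block
// reduction collapses it; one partial per block is written to device memory.
__global__ void MinMaxPartialKernel(const float* in, int n, float* block_min) {
  float lo = in[0];  // assumes n >= 1, mirroring the PR's initialization
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    lo = fminf(lo, in[i]);
  }
  lo = blockReduceMin(lo);
  if (threadIdx.x == 0) block_min[blockIdx.x] = lo;
}

int main() {
  const int n = 1 << 20, threads = 256, blocks = 64;
  float* h_in = new float[n];
  for (int i = 0; i < n; ++i) h_in[i] = static_cast<float>((i * 37 + 11) % 9973);

  float *d_in = nullptr, *d_block_min = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_block_min, blocks * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  MinMaxPartialKernel<<<blocks, threads>>>(d_in, n, d_block_min);

  // Host-side fold of the per-block partials, only to check the sketch; the
  // real kernel folds them into block_min[0] on the device instead.
  float h_block_min[blocks];
  cudaMemcpy(h_block_min, d_block_min, blocks * sizeof(float),
             cudaMemcpyDeviceToHost);
  float global_min = h_block_min[0];
  for (int b = 1; b < blocks; ++b) global_min = fminf(global_min, h_block_min[b]);
  std::printf("min = %f\n", global_min);

  cudaFree(d_in);
  cudaFree(d_block_min);
  delete[] h_in;
  return 0;
}
```

The per-block partials written to block_min stay on the GPU; in the PR the first thread of the grid folds them into element 0 so the histogram launch can take the device pointer directly on the same stream, while the copy back to the host in main above exists only to check the sketch.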
@@ -18,8 +18,7 @@
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/eigen/common.h"
-#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
@@ -46,8 +45,8 @@ template <typename T, typename IndexType>
 __global__ void KernelHistogram(const T* input,
                                 const int total_elements,
                                 const int64_t nbins,
-                                const T min_value,
-                                const T max_value,
+                                const T* min_value,
+                                const T* max_value,
                                 int64_t* output) {
   extern __shared__ int64_t buf_hist[];
   for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
@@ -58,9 +57,9 @@ __global__ void KernelHistogram(const T* input,
   CUDA_KERNEL_LOOP(input_index, total_elements) {
     // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
     const auto input_value = input[input_index];
-    if (input_value >= min_value && input_value <= max_value) {
+    if (input_value >= *min_value && input_value <= *max_value) {
       const IndexType output_index =
-          GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
+          GetBin<T, IndexType>(input_value, *min_value, *max_value, nbins);
       phi::CudaAtomicAdd(&buf_hist[output_index], 1);
     }
   }
@@ -71,6 +70,60 @@ __global__ void KernelHistogram(const T* input,
   }
 }
 
+template <typename T>
+__global__ void KernelMinMax(const T* input,
+                             const int numel,
+                             const int block_num,
+                             T* min_ptr,
+                             T* max_ptr) {
+  int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t i = index;
+  T min_value = static_cast<T>(i < numel ? input[i] : input[0]);
+  T max_value = static_cast<T>(i < numel ? input[i] : input[0]);
+  for (; i < numel; i += blockDim.x * gridDim.x) {
+    T value = static_cast<T>(input[i]);
+    min_value = value < min_value ? value : min_value;
+    max_value = value > max_value ? value : max_value;
+  }
+
+  if (max_ptr && min_ptr) {
+    __syncthreads();
+    T block_min_value = phi::funcs::BlockReduceMin<T>(min_value, FINAL_MASK);
+    T block_max_value = phi::funcs::BlockReduceMax<T>(max_value, FINAL_MASK);
+    if (threadIdx.x == 0) {
+      min_ptr[blockIdx.x] = block_min_value;
+      max_ptr[blockIdx.x] = block_max_value;
+    }
+  }
+  __syncthreads();
+
+  if (index == 0) {
+    if (min_ptr && max_ptr) {
+      min_value = min_ptr[0];
+      max_value = max_ptr[0];
+      for (int64_t i = 1; i < block_num; i++) {
+        min_ptr[0] = min_ptr[i] < min_value ? min_ptr[i] : min_value;
+        max_ptr[0] = max_ptr[i] > max_value ? max_ptr[i] : max_value;
+      }
+      if (min_ptr[0] == max_ptr[0]) {
+        min_ptr[0] = min_ptr[0] - 1;
+        max_ptr[0] = max_ptr[0] + 1;
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void KernelMinMax(const T min_value,
+                             const T max_value,
+                             T* min_ptr,
+                             T* max_ptr) {
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    min_ptr[0] = min_value;
+    max_ptr[0] = max_value;
+  }
+}
+
 template <typename T, typename Context>
 void HistogramKernel(const Context& dev_ctx,
                      const DenseTensor& input,
@@ -93,32 +146,20 @@ void HistogramKernel(const Context& dev_ctx,
   T output_min = static_cast<T>(minval);
   T output_max = static_cast<T>(maxval);
 
-  if (output_min == output_max) {
-    auto input_x = phi::EigenVector<T>::Flatten(input);
-
-    DenseTensor input_min_t, input_max_t;
-    input_min_t.Resize({1});
-    input_max_t.Resize({1});
-    auto* input_min_data = dev_ctx.template Alloc<T>(&input_min_t);
-    auto* input_max_data = dev_ctx.template Alloc<T>(&input_max_t);
-    auto input_min_scala = phi::EigenScalar<T>::From(input_min_t);
-    auto input_max_scala = phi::EigenScalar<T>::From(input_max_t);
-
-    auto* place = dev_ctx.eigen_device();
-    input_min_scala.device(*place) = input_x.minimum();
-    input_max_scala.device(*place) = input_x.maximum();
-
-    DenseTensor input_min_cpu, input_max_cpu;
-    phi::Copy(dev_ctx, input_min_t, phi::CPUPlace(), true, &input_min_cpu);
-    phi::Copy(dev_ctx, input_max_t, phi::CPUPlace(), true, &input_max_cpu);
-
-    output_min = input_min_cpu.data<T>()[0];
-    output_max = input_max_cpu.data<T>()[0];
-  }
-
+  DenseTensor min_max;
+  int block_num = GET_BLOCKS(input_numel);
+  min_max.Resize({2 * block_num});
+  auto* min_block_ptr = dev_ctx.template Alloc<T>(&min_max);
+  auto* max_block_ptr = min_block_ptr + block_num;
   if (output_min == output_max) {
-    output_min = output_min - 1;
-    output_max = output_max + 1;
+    KernelMinMax<T><<<GET_BLOCKS(input_numel),
+                      PADDLE_CUDA_NUM_THREADS,
+                      0,
+                      dev_ctx.stream()>>>(
+        input_data, input_numel, block_num, min_block_ptr, max_block_ptr);
+  } else {
+    KernelMinMax<T><<<1, 1, 0, dev_ctx.stream()>>>(
+        output_min, output_max, min_block_ptr, max_block_ptr);
   }
 
   PADDLE_ENFORCE_EQ((std::isinf(static_cast<float>(output_min)) ||
@@ -142,7 +183,7 @@ void HistogramKernel(const Context& dev_ctx,
                                   PADDLE_CUDA_NUM_THREADS,
                                   nbins * sizeof(int64_t),
                                   stream>>>(
-      input_data, input_numel, nbins, output_min, output_max, out_data);
+      input_data, input_numel, nbins, min_block_ptr, max_block_ptr, out_data);
 }
 
 }  // namespace phi
...