From 1354652be009d09ab13837bda5ff538f1d0991ff Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Thu, 17 Feb 2022 09:43:42 +0800 Subject: [PATCH] Modified distribution kernel with Kernel Primitive API (#39563) --- paddle/fluid/operators/distribution_helper.h | 35 +++++++----- .../kernels/primitive/compute_primitives.h | 53 +++++++++++++++++++ 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/distribution_helper.h b/paddle/fluid/operators/distribution_helper.h index a13ae57090..f3bce38e3a 100644 --- a/paddle/fluid/operators/distribution_helper.h +++ b/paddle/fluid/operators/distribution_helper.h @@ -28,6 +28,10 @@ limitations under the License. */ #include "paddle/fluid/platform/for_range.h" #include "paddle/pten/core/hostdevice.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/pten/kernels/primitive/kernel_primitives.h" +#endif + #if !defined(_WIN32) #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #else @@ -91,6 +95,8 @@ struct normal_transform { #if defined(__NVCC__) || defined(__HIPCC__) +namespace kps = pten::kps; + /*********************** Distribution Function *************************/ template struct uniform_distribution; @@ -176,25 +182,26 @@ template __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset, DistOp dist, TransformOp trans, T *out_data) { - size_t idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - int32_t returns_count = DistOp::kReturnsCount; + size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); + static constexpr int kCount = DistOp::kReturnsCount; #if defined(__NVCC__) curandStatePhilox4_32_10_t state; - curand_init(seed, idx, offset, &state); + curand_init(seed, idx + THREAD_ID_X, offset, &state); + using SType = curandStatePhilox4_32_10_t; #else hiprandStatePhilox4_32_10_t state; - hiprand_init(seed, idx, offset, &state); + hiprand_init(seed, idx + THREAD_ID_X, offset, &state); + using SType = hiprandStatePhilox4_32_10_t; #endif - size_t total_thread = gridDim.x * blockDim.x; - for (size_t i = idx; i < size; i += total_thread * returns_count) { - auto random_tuple = dist(&state); - for (size_t j = 0; j < returns_count; j++) { - size_t index = i + j * total_thread; - if (index < size) { - auto random = (&random_tuple.x)[j]; - out_data[index] = static_cast(trans(random)); - } - } + size_t total_thread = GRID_NUM_X * BLOCK_NUM_X; + T args[kCount]; + T result[kCount]; + for (size_t i = idx; i < size; i += total_thread * kCount) { + kps::ElementwiseRandom(&args[0], dist, &state); + kps::ElementwiseUnary(&result[0], &args[0], + trans); + kps::WriteData(out_data + i, &result[0], size - i, + 1, total_thread, 1); } } diff --git a/paddle/pten/kernels/primitive/compute_primitives.h b/paddle/pten/kernels/primitive/compute_primitives.h index a8ed081622..02a2f7baf7 100644 --- a/paddle/pten/kernels/primitive/compute_primitives.h +++ b/paddle/pten/kernels/primitive/compute_primitives.h @@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { } } +template +__device__ __forceinline__ void ElementwiseRandom(OutT* out, + OpFunc compute, + StateType* state) { + auto random_tuple = compute(state); +#pragma unroll + for (int i = 0; i < ReturnsCount; i++) { + out[i] = static_cast((&random_tuple.x)[i]); + } +} + +// attention please set share_size = blockDim.x; +// data and b are the register pointer +#define shared_size 64 +template +__device__ __forceinline__ void Cumsum(OutT* out, + const InT* in, + OpFunc compute) { + __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32]; + int tidx = threadIdx.x; + temp[tidx + tidx / 32] = in[0]; + temp[shared_size + tidx + (shared_size + tidx) / 32] = in[1]; + for (int stride = 1; stride <= blockDim.x; stride *= 2) { + __syncthreads(); + int index = (tidx + 1) * 2 * stride - 1; + if (index < (blockDim.x * 2)) { + temp[index + index / 32] += temp[index - stride + (index - stride) / 32]; + } + } + for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) { + __syncthreads(); + int index = (tidx + 1) * 2 * stride - 1; + if ((index + stride) < (blockDim.x * 2)) { + temp[index + stride + (stride + index) / 32] += + temp[index + (index) / 32]; + } + } + + __syncthreads(); + out[0] = static_cast(temp[tidx + tidx / 32]); + out[1] = + static_cast(temp[tidx + shared_size + (tidx + shared_size) / 32]); +} + } // namespace kps } // namespace pten -- GitLab