Unverified commit 1354652b, authored by niuliling123, committed via GitHub

Modified distribution kernel with Kernel Primitive API (#39563)

Parent: a909bdf1
...@@ -28,6 +28,10 @@ limitations under the License. */ ...@@ -28,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/core/hostdevice.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/pten/kernels/primitive/kernel_primitives.h"
#endif
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
...@@ -91,6 +95,8 @@ struct normal_transform { ...@@ -91,6 +95,8 @@ struct normal_transform {
#if defined(__NVCC__) || defined(__HIPCC__)
namespace kps = pten::kps;
/*********************** Distribution Function *************************/
template <typename T>
struct uniform_distribution;
...@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp> ...@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp>
// Fills `out_data` with `size` random values: each thread draws
// DistOp::kReturnsCount values per generator call, transforms them with
// `trans`, and writes them to global memory via the Kernel Primitive (kps)
// helpers.
//
// Launch layout: 1-D grid of 1-D blocks; the grid-stride loop makes the
// kernel correct for any grid size. `seed` and `offset` feed the per-thread
// Philox state, so results are reproducible for a fixed launch config.
template <typename T, typename DistOp, typename TransformOp>
__global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
                                   DistOp dist, TransformOp trans,
                                   T *out_data) {
  // Base element index of this block; the thread id is added when seeding
  // the RNG so every thread gets an independent subsequence.
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  // Number of random values produced by a single call to `dist`.
  static constexpr int kCount = DistOp::kReturnsCount;
#if defined(__NVCC__)
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, offset, &state);
  using SType = curandStatePhilox4_32_10_t;
#else
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, offset, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#endif
  size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
  T args[kCount];    // raw random values from `dist`
  T result[kCount];  // transformed values to be written out
  for (size_t i = idx; i < size; i += total_thread * kCount) {
    kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
                                                           trans);
    // Tail-safe store: `size - i` bounds the write so the last iteration
    // does not run past the end of the output buffer.
    kps::WriteData<T, T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
                                             1, total_thread, 1);
  }
}
......
...@@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { ...@@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
} }
} }
// Draws ReturnsCount random values with a single call to `compute` and
// stores them, cast to OutT, into the per-thread register array `out`.
//
// `compute(state)` returns a vector-like struct (e.g. float4) whose first
// member is `x`; the loop below indexes from `&x`, assuming the remaining
// members are laid out contiguously after it.
// `BlockSize` is unused here; it is kept for signature consistency with the
// other kps primitives.
template <typename StateType,
          typename OutT,
          int ReturnsCount,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseRandom(OutT* out,
                                                  OpFunc compute,
                                                  StateType* state) {
  auto rand_vec = compute(state);
  const auto* src = &rand_vec.x;
#pragma unroll
  for (int k = 0; k < ReturnsCount; ++k) {
    out[k] = static_cast<OutT>(src[k]);
  }
}
// NOTE: the caller must launch Cumsum with blockDim.x == shared_size.
// `in` and `out` point to per-thread register arrays of two elements each.
#define shared_size 64
// Block-wide cumulative sum (scan) over 2 * blockDim.x elements held two per
// thread: on entry thread t contributes in[0] and in[1]; on exit out[0]/out[1]
// hold the running sums at this thread's two positions.
//
// Implementation: classic two-phase shared-memory scan — an up-sweep
// (reduction) loop followed by a down-sweep that propagates partial sums.
// Every shared-memory index is offset by `index / 32`, presumably padding to
// avoid shared-memory bank conflicts — TODO confirm against the kps docs.
//
// Preconditions (from the author's note above): blockDim.x must equal
// shared_size (64), since the shared buffer is statically sized from it.
// NOTE(review): the template parameters NX, NY, BlockSize and the `compute`
// functor are not used in this body — they appear to exist only for
// signature consistency with the other kps primitives; verify with callers.
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  // 2 * shared_size payload slots plus one padding slot per 32 elements.
  __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32];
  int tidx = threadIdx.x;
  // Stage this thread's two inputs into (padded) shared memory.
  temp[tidx + tidx / 32] = in[0];
  temp[shared_size + tidx + (shared_size + tidx) / 32] = in[1];
  // Up-sweep: build partial sums over spans of doubling stride.
  for (int stride = 1; stride <= blockDim.x; stride *= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if (index < (blockDim.x * 2)) {
      temp[index + index / 32] += temp[index - stride + (index - stride) / 32];
    }
  }
  // Down-sweep: add each span total into the element half a span to its
  // right, filling in the remaining prefix sums.
  for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if ((index + stride) < (blockDim.x * 2)) {
      temp[index + stride + (stride + index) / 32] +=
          temp[index + (index) / 32];
    }
  }
  // Barrier before reading: the last down-sweep writes must be visible to
  // every thread in the block.
  __syncthreads();
  out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
  out[1] =
      static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
}
}  // namespace kps
}  // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册