Unverified commit 1354652b, authored by niuliling123, committed by GitHub

Modified distribution kernel with Kernel Primitive API (#39563)

Parent commit: a909bdf1
......@@ -28,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/core/hostdevice.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/pten/kernels/primitive/kernel_primitives.h"
#endif
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
......@@ -91,6 +95,8 @@ struct normal_transform {
#if defined(__NVCC__) || defined(__HIPCC__)
namespace kps = pten::kps;
/*********************** Distribution Function *************************/
template <typename T>
struct uniform_distribution;
......@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp>
// NOTE(review): this span is a unified diff rendered without its +/- markers,
// so the pre-change (removed) and post-change (added) versions of several
// statements appear back to back. The annotations below mark which lines
// belong to which version; the code itself is left byte-identical.
//
// Kernel purpose (from the visible code): fill out_data[0..size) with values
// drawn via `dist` (a Philox curand/hiprand state consumer producing
// DistOp::kReturnsCount values per call) and mapped through `trans`.
__global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
DistOp dist, TransformOp trans,
T *out_data) {
// OLD (removed): per-thread flat global index, runtime element count.
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
int32_t returns_count = DistOp::kReturnsCount;
// NEW (added): idx is now the block's base offset only (the per-thread
// offset is added at *_init below), and the count is a compile-time
// constant so it can size the register arrays used by the kps helpers.
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
static constexpr int kCount = DistOp::kReturnsCount;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t state;
// OLD seeding (idx already included the thread id):
curand_init(seed, idx, offset, &state);
// NEW seeding (idx is block-based, so the thread id is added here):
curand_init(seed, idx + THREAD_ID_X, offset, &state);
using SType = curandStatePhilox4_32_10_t;
#else
hiprandStatePhilox4_32_10_t state;
// OLD seeding:
hiprand_init(seed, idx, offset, &state);
// NEW seeding:
hiprand_init(seed, idx + THREAD_ID_X, offset, &state);
using SType = hiprandStatePhilox4_32_10_t;
#endif
// OLD loop body (removed): manual unpack of the random tuple with an
// explicit per-element bounds check.
size_t total_thread = gridDim.x * blockDim.x;
for (size_t i = idx; i < size; i += total_thread * returns_count) {
auto random_tuple = dist(&state);
for (size_t j = 0; j < returns_count; j++) {
size_t index = i + j * total_thread;
if (index < size) {
auto random = (&random_tuple.x)[j];
out_data[index] = static_cast<T>(trans(random));
}
}
// NEW loop body (added): same work expressed with Kernel Primitive API
// helpers: generate kCount values, transform them, then store them.
// Presumably the per-thread offset and the `size - i` tail bound are
// handled inside WriteData (its `true` template flag and the stride
// arguments suggest a boundary-aware strided store) — TODO confirm
// against the WriteData definition.
size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
T args[kCount];
T result[kCount];
for (size_t i = idx; i < size; i += total_thread * kCount) {
kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
trans);
kps::WriteData<T, T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
1, total_thread, 1);
}
}
......
......@@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
}
}
// Draws ReturnsCount random values with a single call to `compute` on the
// per-thread RNG state and stores them, converted to OutT, into the register
// array `out` (which must hold at least ReturnsCount elements).
//
// NOTE(review): the unpacking below assumes the struct returned by
// `compute(state)` lays its members out contiguously starting at `.x`
// (true for the curand/hiprand Philox vector return types) — confirm for
// any new DistOp.
template <typename StateType,
          typename OutT,
          int ReturnsCount,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseRandom(OutT* out,
                                                  OpFunc compute,
                                                  StateType* state) {
  auto rand_vec = compute(state);
  const auto* values = &rand_vec.x;
#pragma unroll
  for (int k = 0; k < ReturnsCount; ++k) {
    out[k] = static_cast<OutT>(values[k]);
  }
}
// Block-wide inclusive prefix sum (up-sweep / down-sweep scan) over the
// 2 * blockDim.x values held by the block, two per thread: on entry each
// thread supplies in[0] and in[1] (register values); on exit out[0]/out[1]
// receive the scanned values for the same two positions.
//
// Preconditions (NOTE(review): not checked at runtime):
//   * blockDim.x must equal shared_size (64) — the shared buffer and the
//     final read-back indices are sized with the constant, not blockDim.x,
//     so any other block width silently corrupts the result.
//   * `in` and `out` point to register arrays of at least 2 elements.
//
// The `x + x / 32` addressing inserts one padding slot every 32 elements to
// spread consecutive indices across shared-memory banks (bank-conflict
// avoidance). `compute` is currently unused by this implementation.
//
// Was `#define shared_size 64`: a typed constant is scoped and does not leak
// a lowercase macro into every translation unit that includes this header.
constexpr int shared_size = 64;
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32];
  int tidx = threadIdx.x;
  // Stage this thread's two inputs into the padded shared buffer.
  temp[tidx + tidx / 32] = in[0];
  temp[shared_size + tidx + (shared_size + tidx) / 32] = in[1];
  // Up-sweep: build partial sums over doubling strides.
  for (int stride = 1; stride <= blockDim.x; stride *= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if (index < (blockDim.x * 2)) {
      temp[index + index / 32] += temp[index - stride + (index - stride) / 32];
    }
  }
  // Down-sweep: propagate the partial sums to the remaining positions.
  for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if ((index + stride) < (blockDim.x * 2)) {
      temp[index + stride + (stride + index) / 32] +=
          temp[index + (index) / 32];
    }
  }
  // Barrier before read-back: the last down-sweep writes touch slots other
  // threads are about to read.
  __syncthreads();
  out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
  out[1] =
      static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
}
} // namespace kps
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register