bernoulli_kernel.cu 2.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/phi/kernels/bernoulli_kernel.h"

17 18 19 20 21 22 23
#ifdef __NVCC__
#include <curand_kernel.h>
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>
#endif

24 25
#include <algorithm>
#include <vector>
26

27
#include "paddle/phi/backends/gpu/gpu_context.h"
28
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
29 30
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
31
#include "paddle/phi/kernels/funcs/distribution_helper.h"
32

33
namespace phi {

// Writes one Bernoulli sample per element of `x_data` into `out_data`.
//
// Launch layout: 1D grid of 1D blocks, no shared memory. Each thread seeds a
// Philox4_32_10 state with (seed, subsequence = thread_idx, offset) and then
// walks the input in chunks of 4 consecutive elements. The first chunk starts
// at 4 * thread_idx. Each iteration advances by 4 * (gridDim.x * blockDim.x),
// i.e. a grid-stride loop over 4-element chunks.
//
// Every iteration makes one float4 draw ('curand_uniform4/hiprand_uniform4'
// generate 4 random numbers per call), so a thread consumes 4 Philox values
// per iteration. The host must reserve at least that many offset values per
// iteration.
//
// out[idx] = 1 if u <= x[idx], else 0, where u is the uniform draw assigned to
// idx.
// NOTE(review): x is presumably a probability in [0, 1]; this is not checked
// here.
template <typename T>
__global__ void bernoulli_cuda_kernel(
    size_t size, uint64_t seed, uint64_t offset, const T* x_data, T* out_data) {
  // Widen before multiplying. blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and can wrap for very large launches.
  const size_t thread_idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t total_thread = static_cast<size_t>(gridDim.x) * blockDim.x;

#if defined(__NVCC__)
  curandStatePhilox4_32_10_t state;
  curand_init(seed, thread_idx, offset, &state);
#else
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, thread_idx, offset, &state);
#endif

  // Loop-invariant: construct the distribution functor once per thread.
  funcs::uniform_distribution<float> dist;
  for (size_t i = 4 * thread_idx; i < size; i += total_thread * 4) {
    const float4 rand = dist(&state);
    // Address the four lanes through a local array. This avoids pointer
    // arithmetic across float4 members.
    const float lanes[4] = {rand.x, rand.y, rand.z, rand.w};
#pragma unroll
    for (int j = 0; j < 4; ++j) {
      const size_t idx = i + j;
      // The last chunk may be partial when size is not a multiple of 4.
      if (idx < size) {
        out_data[idx] = static_cast<T>(lanes[j] <= x_data[idx]);
      }
    }
  }
}

// Fills `out` with Bernoulli samples, using `x` element-wise as the success
// threshold.
//
// Allocates `out` on `ctx` and reserves a (seed, offset) pair from the
// context's generator. It then launches bernoulli_cuda_kernel on ctx.stream().
// The launch is asynchronous; this function neither synchronizes nor checks
// for launch errors.
template <typename T, typename Context>
void BernoulliKernel(const Context& ctx,
                     const DenseTensor& x,
                     DenseTensor* out) {
  T* out_data = ctx.template Alloc<T>(out);
  const int64_t numel = x.numel();
  // An empty tensor needs no samples. Returning here avoids building a 1D
  // launch config for 0 elements and launching an empty grid.
  if (numel == 0) {
    return;
  }
  const T* x_data = x.data<T>();

  // vec_size 4: each thread handles 4 elements per iteration (one float4 draw).
  auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4);
  size_t grid_size = gpu_config.GetGridSize();
  size_t block_size = gpu_config.GetBlockSize();

  // Reserve enough Philox offset to cover every draw a thread makes. This
  // keeps the next consumer of this generator from reusing these numbers.
  // Each loop iteration in the kernel consumes 4 values. The previously
  // hard-coded 12 is kept as a floor, so the generator advances exactly as
  // before whenever 12 suffices.
  const uint64_t total_thread = static_cast<uint64_t>(grid_size) * block_size;
  const uint64_t chunk = 4 * total_thread;
  const uint64_t loop_times =
      (static_cast<uint64_t>(numel) + chunk - 1) / chunk;
  const uint64_t increment = std::max<uint64_t>(12, loop_times * 4);

  auto gen_cuda = ctx.GetGenerator();
  auto seed_offset = gen_cuda->IncrementOffset(increment);
  uint64_t seed = seed_offset.first;
  uint64_t offset = seed_offset.second;

  bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
      numel, seed, offset, x_data, out_data);
}

}  // namespace phi
87

88
// Registers phi::BernoulliKernel as the GPU implementation of the `bernoulli`
// op for any layout, with float and double instantiations.
PD_REGISTER_KERNEL(bernoulli,
                   GPU,
                   ALL_LAYOUT,
                   phi::BernoulliKernel,
                   float,
                   double) {}