// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/random.h>
#include <thrust/transform.h>
#ifdef __NVCC__
#include <curand_kernel.h>
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>
#endif

#include <algorithm>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/bernoulli_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/transform.h"

// Flag declared here and defined in another translation unit. When true,
// BernoulliKernel samples with the Philox-based bernoulli_cuda_kernel;
// otherwise it uses the thrust-based BernoulliCudaFunctor path.
DECLARE_bool(use_curand);

namespace phi {

// Element-wise Bernoulli sampler used by the thrust (non-curand) path.
//
// For an element index `n` and probability `p`, the functor builds a fresh
// thrust::minstd_rand engine seeded with `seed_`, skips `n + offset_` draws
// (32-bit unsigned arithmetic), takes one sample from
// thrust::uniform_real_distribution<T>(0.0, 1.0) and returns T(1) when the
// sample is strictly less than `p`, T(0) otherwise. The engine is rebuilt on
// every call, so the sample for a given (seed_, offset_, n) does not depend
// on evaluation order.
template <typename T>
struct BernoulliCudaFunctor {
  unsigned int seed_;
  unsigned int offset_;

  __host__ __device__ BernoulliCudaFunctor(unsigned int seed,
                                           unsigned int offset)
      : seed_(seed), offset_(offset) {}

  __host__ __device__ T operator()(const unsigned int n, const T p) const {
    // NOTE(zhiqiu): PADDLE_ENFORCE inside a CUDA kernel may currently print
    // the error message several times; this should be refined.
    PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
                   "The probability should be >=0 and <= 1, but got %f",
                   p);

    // Position the engine at this element's draw before sampling.
    thrust::minstd_rand engine(seed_);
    engine.discard(n + offset_);

    thrust::uniform_real_distribution<T> unit_uniform(0.0, 1.0);
    const T sample = unit_uniform(engine);
    return sample < p ? static_cast<T>(1) : static_cast<T>(0);
  }
};

// Bernoulli sampling kernel for the curand/hiprand path.
//
// Launch: 1-D grid of 1-D blocks (only the .x dimensions are read). Each
// thread handles groups of 4 consecutive elements. It starts at element
// 4 * (global thread index) and advances by 4 * (total threads in the grid)
// until `size` is exhausted, so any grid size covers the whole input.
//
// RNG: every thread initialises its own Philox4_32_10 state with
// (seed, subsequence = global thread index, offset). Each loop iteration
// draws one float4 from that state, because
// 'curand_uniform4/hiprand_uniform4' generate 4 random numbers per call.
// out_data[i] is set to 1 when its uniform sample is <= x_data[i], else 0.
//
// x_data is expected to hold probabilities in [0, 1]; this kernel does not
// validate them.
template <typename T>
__global__ void bernoulli_cuda_kernel(
    size_t size, uint64_t seed, uint64_t offset, const T* x_data, T* out_data) {
  // Widen to size_t before multiplying: the products of the 32-bit built-in
  // index variables can otherwise overflow for very large grids.
  const size_t thread_idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t total_thread = static_cast<size_t>(gridDim.x) * blockDim.x;

#if defined(__NVCC__)
  curandStatePhilox4_32_10_t state;
  curand_init(seed, thread_idx, offset, &state);
#else
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, thread_idx, offset, &state);
#endif

  for (size_t i = 4 * thread_idx; i < size; i += total_thread * 4) {
    funcs::uniform_distribution<float> dist;
    float4 rand = dist(&state);
#pragma unroll
    for (size_t j = 0; j < 4; j++) {
      size_t idx = i + j;
      // The last group may extend past `size`; skip the excess lanes.
      if (idx < size) {
        out_data[idx] = static_cast<T>((&rand.x)[j] <= x_data[idx]);
      }
    }
  }
}

// Fills `out` with Bernoulli samples whose per-element probabilities are
// read from `x` (out[i] is 1 with probability x[i], else 0).
//
// The sampling path is selected by FLAGS_use_curand:
//  * true:  launches bernoulli_cuda_kernel, which uses a per-thread Philox
//           state. The generator offset is advanced by the number of Philox
//           values any single thread may consume.
//  * false: runs paddle::platform::Transform over element indices with
//           BernoulliCudaFunctor (one thrust::minstd_rand per element).
// `out` is allocated through ctx with element type T; its shape is
// presumably set by the caller (e.g. InferMeta) to match `x` — confirm.
template <typename T, typename Context>
void BernoulliKernel(const Context& ctx,
                     const DenseTensor& x,
                     DenseTensor* out) {
  const T* x_data = x.data<T>();
  T* out_data = ctx.template Alloc<T>(out);
  auto numel = x.numel();
  // An empty input needs no samples. Returning early also avoids launching
  // a zero-sized grid.
  if (numel == 0) {
    return;
  }

  auto gen_cuda = ctx.GetGenerator();

  if (FLAGS_use_curand) {
    // bernoulli_cuda_kernel draws one float4 (4 Philox values) per
    // iteration of its grid-stride loop.
    constexpr int kVecSize = 4;
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, kVecSize);
    size_t grid_size = gpu_config.GetGridSize();
    size_t block_size = gpu_config.GetBlockSize();

    // Reserve enough of the Philox offset for the largest number of loop
    // iterations any thread performs, so the random streams of consecutive
    // calls never overlap. The historical increment of 12 is kept as a lower
    // bound, so results for a given seed are unchanged whenever three or
    // fewer iterations are needed.
    const uint64_t total_thread =
        static_cast<uint64_t>(grid_size) * block_size;
    const uint64_t per_iteration = total_thread * kVecSize;
    uint64_t loop_times = 0;
    if (per_iteration > 0) {
      loop_times =
          (static_cast<uint64_t>(numel) + per_iteration - 1) / per_iteration;
    }
    const uint64_t increment = std::max<uint64_t>(12, loop_times * kVecSize);

    auto seed_offset = gen_cuda->IncrementOffset(increment);
    uint64_t seed = seed_offset.first;
    uint64_t offset = seed_offset.second;

    bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
        numel, seed, offset, x_data, out_data);
  } else {
    // Each element skips (index + numel * generator offset) engine draws
    // inside BernoulliCudaFunctor.
    auto seed_offset = gen_cuda->IncrementOffset(1);
    int64_t gen_offset = numel * seed_offset.second;
    paddle::platform::Transform<phi::GPUContext> trans;
    thrust::counting_iterator<int64_t> index_sequence_begin(0);
    // BernoulliCudaFunctor stores 32-bit seed/offset. The cast makes the
    // narrowing explicit; the values are the same as a narrowing routed
    // through int64_t.
    trans(ctx,
          index_sequence_begin,
          index_sequence_begin + numel,
          x_data,
          out_data,
          BernoulliCudaFunctor<T>(static_cast<unsigned int>(seed_offset.first),
                                  static_cast<unsigned int>(gen_offset)));
  }
}

}  // namespace phi

// Registers phi::BernoulliKernel as the GPU "bernoulli" kernel for float and
// double inputs (all layouts).
PD_REGISTER_KERNEL(
    bernoulli, GPU, ALL_LAYOUT, phi::BernoulliKernel, float, double) {}