/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef __NVCC__
#include <curand_kernel.h>
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/hostdevice.h"

#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
#endif

// UNLIKELY(condition): hints to the compiler that `condition` is expected to
// be false. MSVC offers no __builtin_expect equivalent, so on Windows the
// macro simply evaluates to the condition itself.
#if defined(_WIN32)
#define UNLIKELY(condition) (condition)
#else
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#endif

namespace phi {
namespace funcs {

/********************* Transformation Function **********************/
template <typename T>
struct exponential_transform {
  explicit exponential_transform(T lambda) : lambda_(lambda) {}

  // Maps a uniform sample `val` to an exponential sample with rate `lambda`
  // via the inverse CDF, i.e. -log(u) / lambda.
  //  * nvcc / hipcc builds take log of `val` directly. This presumably relies
  //    on the device generator never producing 0 (TODO confirm for every
  //    caller). Non-double types use the fast __logf intrinsic.
  //  * Other builds use log(1 - val), which is finite for val in [0, 1).
  HOSTDEVICE inline T operator()(T val) const {
    const auto neg_inv_lambda = static_cast<T>(-1.0) / lambda_;
#if defined(__NVCC__) || defined(__HIPCC__)
    if (!std::is_same<T, double>::value) {
      return neg_inv_lambda * __logf(val);
    }
    return neg_inv_lambda * log(val);
#else
    return neg_inv_lambda * std::log(static_cast<T>(1.0) - val);
#endif
  }

 private:
  T lambda_;  // rate parameter of the exponential distribution
};

template <typename T>
struct uniform_real_transform {
  // Affinely maps a unit-interval sample `val` onto [min, max).
  explicit uniform_real_transform(T min, T max)
      : range_(max - min), min_(min) {}

  // A sample of exactly 1.0 would land on `max`. It is folded back to `min`
  // so the upper bound is excluded. This presumably compensates for a
  // generator whose range includes 1.0 (TODO confirm for every caller).
  HOSTDEVICE inline T operator()(T val) const {
    if (UNLIKELY(val == static_cast<T>(1.0))) {
      return min_;
    } else {
      return val * range_ + min_;
    }
  }

 private:
  T range_;  // max - min
  T min_;    // lower bound of the output interval
};

template <typename T, typename R>
struct uniform_int_transform {
  // Maps a raw random integer `rand` onto the half-open range [min, max).
  // The span is computed in 64-bit arithmetic. A plain `max - min` in `int`
  // overflows (undefined behaviour) for wide ranges such as INT_MIN..INT_MAX,
  // whose width still fits in uint32_t.
  explicit uniform_int_transform(int min, int max) {
    range_ = static_cast<uint32_t>(static_cast<int64_t>(max) -
                                   static_cast<int64_t>(min));
    min_ = min;
  }

  HOSTDEVICE inline T operator()(R rand) const {
    // Empty range (max == min): `rand % 0` would be undefined, so return the
    // only admissible value instead.
    if (range_ == 0) {
      return static_cast<T>(min_);
    }
    // The offset lies in [0, range_) and may exceed INT_MAX, so add `min_` in
    // 64-bit. The sum then lies in [min, max).
    return static_cast<T>(static_cast<int64_t>(rand % range_) + min_);
  }

 private:
  uint32_t range_;  // max - min, as an unsigned span
  int min_;         // inclusive lower bound
};

template <typename T>
struct normal_transform {
  // Rescales a standard-normal sample to mean `mu` and standard deviation
  // `sigma`.
  explicit normal_transform(T mu, T sigma) : mean_(mu), std_(sigma) {}

  HOSTDEVICE inline T operator()(T val) const {
    // shift-and-scale: val * sigma + mu
    return val * std_ + mean_;
  }

 private:
  T mean_;  // target mean
  T std_;   // target standard deviation
};

#if defined(__NVCC__) || defined(__HIPCC__)

namespace kps = phi::kps;

/*********************** Distribution Function *************************/
// Primary templates are only declared here and never defined. A sample type
// is usable only through one of the explicit specializations in this header.
template <class T>
struct uniform_distribution;

template <class T>
struct normal_distribution;

#if defined(__NVCC__)
template <>
struct uniform_distribution<float> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 4;

  // Draws kReturnsCount uniform floats from the given Philox state.
  __device__ inline float4 operator()(
      curandStatePhilox4_32_10_t *philox_state) const {
    return curand_uniform4(philox_state);
  }
};

template <>
struct uniform_distribution<double> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 2;

  // Draws kReturnsCount uniform doubles from the given Philox state.
  __device__ inline double2 operator()(
      curandStatePhilox4_32_10_t *philox_state) const {
    return curand_uniform2_double(philox_state);
  }
};

template <>
struct uniform_distribution<uint32_t> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 4;

  // Returns the four raw 32-bit words of one curand4 draw.
  __device__ inline uint4 operator()(
      curandStatePhilox4_32_10_t *philox_state) const {
    return curand4(philox_state);
  }
};

template <>
struct uniform_distribution<uint64_t> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 2;

  // Packs the four 32-bit words of one curand4 draw into two 64-bit values.
  // Words (x, y) form the first value and (z, w) the second, high word first.
  __device__ inline ulonglong2 operator()(
      curandStatePhilox4_32_10_t *philox_state) const {
    const uint4 words = curand4(philox_state);
    ulonglong2 packed;
    packed.x = (static_cast<uint64_t>(words.x) << 32) | words.y;
    packed.y = (static_cast<uint64_t>(words.z) << 32) | words.w;
    return packed;
  }
};

template <>
struct normal_distribution<float> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 4;

  // Draws kReturnsCount standard-normal floats from the given Philox state.
  __device__ inline float4 operator()(
      curandStatePhilox4_32_10_t *philox_state) const {
    return curand_normal4(philox_state);
  }
};

template <>
struct normal_distribution<double> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 2;

  // Draws kReturnsCount standard-normal doubles from the given Philox state.
  __device__ inline double2 operator()(
      curandStatePhilox4_32_10_t *philox_state) const {
    return curand_normal2_double(philox_state);
  }
};

#else
template <>
struct uniform_distribution<float> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 4;

  // Draws kReturnsCount uniform floats from the given hipRAND Philox state.
  __device__ inline float4 operator()(
      hiprandStatePhilox4_32_10_t *philox_state) const {
    return hiprand_uniform4(philox_state);
  }
};

template <>
struct uniform_distribution<double> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 2;

  // Draws kReturnsCount uniform doubles from the given hipRAND Philox state.
  __device__ inline double2 operator()(
      hiprandStatePhilox4_32_10_t *philox_state) const {
    return hiprand_uniform2_double(philox_state);
  }
};

template <>
struct uniform_distribution<uint32_t> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 4;

  // Returns the four raw 32-bit words of one hiprand4 draw.
  __device__ inline uint4 operator()(
      hiprandStatePhilox4_32_10_t *philox_state) const {
    return hiprand4(philox_state);
  }
};

template <>
struct uniform_distribution<uint64_t> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 2;

  // Packs the four 32-bit words of one hiprand4 draw into two 64-bit values.
  // Words (x, y) form the first value and (z, w) the second, high word first.
  __device__ inline ulonglong2 operator()(
      hiprandStatePhilox4_32_10_t *philox_state) const {
    const uint4 words = hiprand4(philox_state);
    ulonglong2 packed;
    packed.x = (static_cast<uint64_t>(words.x) << 32) | words.y;
    packed.y = (static_cast<uint64_t>(words.z) << 32) | words.w;
    return packed;
  }
};

template <>
struct normal_distribution<float> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 4;

  // Draws kReturnsCount standard-normal floats from the hipRAND Philox state.
  __device__ inline float4 operator()(
      hiprandStatePhilox4_32_10_t *philox_state) const {
    return hiprand_normal4(philox_state);
  }
};

template <>
struct normal_distribution<double> {
  // Number of samples returned by one call.
  static constexpr int kReturnsCount = 2;

  // Draws kReturnsCount standard-normal doubles from the hipRAND Philox state.
  __device__ inline double2 operator()(
      hiprandStatePhilox4_32_10_t *philox_state) const {
    return hiprand_normal2_double(philox_state);
  }
};
#endif

/******** Launch GPU function of distribution and transformation *********/
template <typename T, typename DistOp, typename TransformOp>
__global__ void DistributionKernel(size_t size,
                                   uint64_t seed,
                                   uint64_t offset,
                                   DistOp dist,
                                   TransformOp trans,
                                   T *out_data,
                                   size_t stride) {
  // Fills out_data[0, size) with trans(dist(state)) samples.
  //
  // Each thread owns a Philox state: subsequence = its global thread index,
  // starting at `offset`. On every loop iteration a block writes a tile
  // starting at `i`. Inside the tile, thread t stores its kCount results at
  // i + t + k * stride (k < kCount). `stride` is expected to be the total
  // number of launched threads, as passed by distribution_and_transform.
  //
  // Widen before multiplying so the block base cannot wrap in 32-bit
  // unsigned arithmetic.
  size_t idx =
      static_cast<size_t>(BLOCK_ID_X) * static_cast<size_t>(BLOCK_NUM_X);
  static constexpr int kCount = DistOp::kReturnsCount;
#if defined(__NVCC__)
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, offset, &state);
  using SType = curandStatePhilox4_32_10_t;
#else
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, offset, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#endif
  size_t total_thread =
      static_cast<size_t>(GRID_NUM_X) * static_cast<size_t>(BLOCK_NUM_X);

  // NOTE(review): kps::WriteData appears to take its element bound as `int`.
  // Passing `size - i` unclamped would wrap for tensors with more than
  // INT_MAX elements. Each thread only writes offsets below
  // kCount * stride inside a tile, so clamping the bound to INT_MAX does not
  // change which elements are written.
  constexpr size_t kMaxWriteBound = 0x7fffffff;

  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
  MT args[kCount];
  T result[kCount];
  // `i` is the block-uniform tile base, so every thread of a block runs the
  // same number of iterations and the __syncthreads() below is uniform.
  for (size_t i = idx; i < size; i += total_thread * kCount) {
    kps::ElementwiseRandom<SType, MT, kCount, 1, DistOp>(
        &args[0], dist, &state);
    kps::ElementwiseUnary<MT, T, kCount, 1, 1, TransformOp>(
        &result[0], &args[0], trans);
    const size_t remaining = size - i;
    const int write_bound = static_cast<int>(
        remaining < kMaxWriteBound ? remaining : kMaxWriteBound);
    kps::WriteData<T, T, kCount, 1, 1, true>(
        out_data + i, &result[0], write_bound, 1, stride, 1);
    __syncthreads();
  }
}

template <typename T, typename DistOp, typename TransformOp>
void distribution_and_transform(const GPUContext &ctx,
                                DenseTensor *out,
                                DistOp dist,
                                TransformOp trans) {
  // Allocates `out` through `ctx`, then launches DistributionKernel on
  // ctx.stream() to fill every element with trans(dist(rng)). The seed and
  // Philox offset come from the context's generator, which is advanced by the
  // amount of randomness each thread may consume.
  static_assert(DistOp::kReturnsCount > 0,
                "DistOp::kReturnsCount must be positive");

  T *out_data = ctx.template Alloc<T>(out);
  auto size = out->numel();
  if (size == 0) return;
  auto gen_cuda = ctx.GetGenerator();

  size_t block_size = 256;
  size_t expect_grid_size = (size + block_size - 1) / block_size;

  int64_t device_id = ctx.GetPlace().GetDeviceId();
  const auto &prop = phi::backends::gpu::GetDeviceProperties(device_id);

  // Cap the grid at the number of 256-thread blocks the device can keep
  // resident at once (by per-SM thread limit). The kernel loops over any
  // remaining elements.
  size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
                         prop.multiProcessorCount;
  size_t grid_size =
      expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;

  size_t total_thread = block_size * grid_size;

  // Each kernel iteration calls `dist` once per thread. That call makes one
  // 4 x 32-bit Philox draw (see e.g. uniform_distribution<uint64_t>, which
  // unpacks a single curand4/hiprand4) and yields DistOp::kReturnsCount
  // outputs. The kernel advances by kReturnsCount * total_thread elements per
  // iteration, so each thread runs at most
  // ceil(size / (kReturnsCount * total_thread)) iterations.
  // A fixed divisor of 4 here under-reserves for 2-output distributions
  // (double, uint64_t), letting the next call reuse part of this stream.
  const size_t outputs_per_iteration =
      static_cast<size_t>(DistOp::kReturnsCount) * total_thread;
  size_t curand4_loop_times =
      (size + outputs_per_iteration - 1) / outputs_per_iteration;
  // 'increment' must be a multiple of 4: one 4-word draw per iteration.
  uint64_t increment = curand4_loop_times * 4;

  auto seed_offset = gen_cuda->IncrementOffset(increment);
  uint64_t seed = seed_offset.first;
  uint64_t offset = seed_offset.second;

  DistributionKernel<T,
                     DistOp,
                     TransformOp><<<grid_size, block_size, 0, ctx.stream()>>>(
      size, seed, offset, dist, trans, out_data, total_thread);
}

#endif

}  // namespace funcs
}  // namespace phi