random.cuh 2.7 KB
Newer Older
S
Siming Dai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
/*
 * Copyright (c) 2019-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#ifdef __NVCC__
#include <cuda_runtime_api.h>  // NOLINT
#endif

class RandomNumGen {
 public:
  __host__ __device__ __forceinline__ RandomNumGen(int gid, unsigned long long seed) {
    next_random = seed + gid;
    next_random ^= next_random >> 33U;
    next_random *= 0xff51afd7ed558ccdUL;
    next_random ^= next_random >> 33U;
    next_random *= 0xc4ceb9fe1a85ec53UL;
    next_random ^= next_random >> 33U;
  }
  __host__ __device__ __forceinline__ ~RandomNumGen() = default;
  __host__ __device__ __forceinline__ void SetSeed(int seed) {
    next_random = seed;
    NextValue();
  }
  __host__ __device__ __forceinline__ unsigned long long SaveState() const {
    return next_random;
  }
  __host__ __device__ __forceinline__ void LoadState(unsigned long long state) {
    next_random = state;
  }
  __host__ __device__ __forceinline__ int Random() {
    int ret_value = (int) (next_random & 0x7fffffffULL);
    NextValue();
    return ret_value;
  }
  __host__ __device__ __forceinline__ int RandomMod(int mod) {
    return Random() % mod;
  }
  __host__ __device__ __forceinline__ int64_t Random64() {
    int64_t ret_value = (next_random & 0x7FFFFFFFFFFFFFFFLL);
    NextValue();
    return ret_value;
  }
  __host__ __device__ __forceinline__ int64_t RandomMod64(int64_t mod) {
    return Random64() % mod;
  }
  __host__ __device__ __forceinline__ float RandomUniformFloat(float max = 1.0f, float min = 0.0f) {
    int value = (int) (next_random & 0xffffff);
    auto ret_value = (float) value;
    ret_value /= 0xffffffL;
    ret_value *= (max - min);
    ret_value += min;
    NextValue();
    return ret_value;
  }
  __host__ __device__ __forceinline__ bool RandomBool(float true_prob) {
    float value = RandomUniformFloat();
    return value <= true_prob;
  }
  __host__ __device__ __forceinline__ void NextValue() {
    //next_random = next_random * (unsigned long long)0xc4ceb9fe1a85ec53UL + generator_id;
    //next_random = next_random * (unsigned long long)25214903917ULL + 11;
    next_random = next_random * (unsigned long long) 13173779397737131ULL + 1023456798976543201ULL;
  }

 private:
  unsigned long long next_random = 1;
};