// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <glog/logging.h>
#include <stdint.h>

#include <random>
#include <vector>

namespace cinn {
namespace utils {

/**
 * LinearRandomEngine is a random number engine using the linear congruential
 * algorithm. The state transition function is: x(i + 1) = (multiplier * x(i)
 * + increment) mod modulus. Its interface and members are roughly the same as
 * std::linear_congruential_engine, so it can be used with
 * std::xxx_distribution. The difference from std::linear_congruential_engine
 * is that LinearRandomEngine does not own the random seed; it holds a pointer
 * to the random seed and advances that state on behalf of other objects.
 */
class LinearRandomEngine {
 public:
  using StateType = int64_t;
  // the type name "resule_type" is needed by std::xxx_distribution
  using result_type = uint32_t;

  // The minimum possible value of random state
  static constexpr result_type min() { return 0; }
  // The maximum possible value of random state
  static constexpr result_type max() { return modulus - 1; }
  // The multiplier
  static constexpr StateType multiplier = 48271;
  // The increment
  static constexpr StateType increment = 0;
  // The modulus
  static constexpr StateType modulus = 2147483647;

  // Construct a linear random engine with a random state pointer
  LinearRandomEngine(StateType* state) : state_(state) {}

  // operator() is needed by std::xxx_distribution
  result_type operator()() { return Next(); }

  // Get a device random state
58 59 60
  static StateType GetDeviceRandomValue() {
    return (std::random_device()()) % modulus;
  }
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90

  // Normalize the random seed to the range of [1, modulus - 1]
  static StateType NormalizeState(StateType state) {
    if (state == -1) {
      state = GetDeviceRandomValue();
    } else {
      state %= modulus;
    }
    if (state == 0) {
      state = 1;
    }
    CHECK_GE(state, 0) << "Random seed must be greater than 0";

    return state;
  }

  // Fork a new state for another Random Generator from current state
  StateType ForkState() { return (Next() * 32767) % 1999999973; }

 private:
  // Move the state to the next and return the new state
  result_type Next() {
    *state_ = (increment + (*state_) * multiplier) % modulus;
    return *state_;
  }

 private:
  StateType* state_;
};

91 92 93 94
// Fork a new random state for another Random Generator, the original seed will
// be changed to next state.
inline LinearRandomEngine::StateType ForkRandomState(
    LinearRandomEngine::StateType* rand_seed) {
95 96 97 98
  return LinearRandomEngine(rand_seed).ForkState();
}

// Sample Integers from uniform distribution [min, max)
99 100 101
int SampleUniformInt(int min,
                     int max,
                     LinearRandomEngine::StateType* rand_seed);
102 103

// Sample Real Numbers from uniform distribution [min, max)
104 105 106
double SampleUniformDouble(double min,
                           double max,
                           LinearRandomEngine::StateType* rand_seed);
107 108 109

// Sample Integers from distribution of input weights
template <typename T>
110 111
int SampleDiscreteFromDistribution(const std::vector<T>& weights,
                                   LinearRandomEngine::StateType* rand_seed) {
112
  CHECK_GT(weights.size(), 0);
113 114 115 116 117 118 119
  LinearRandomEngine engine(rand_seed);
  std::discrete_distribution<int> dist(weights.begin(), weights.end());
  return dist(engine);
}

}  // namespace utils
}  // namespace cinn