generator.cc 6.4 KB
Newer Older
Y
yaoxuefeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

L
Leo Chen 已提交
15 16 17
#include "paddle/fluid/framework/generator.h"

#include <glog/logging.h>
18

Y
yaoxuefeng 已提交
19 20
#include <memory>
#include <utility>
Y
yaoxuefeng 已提交
21

22
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
Y
yaoxuefeng 已提交
23
#include "paddle/fluid/platform/enforce.h"
Y
yaoxuefeng 已提交
24 25 26 27

namespace paddle {
namespace framework {

28
const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id) {
29
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Y
yaoxuefeng 已提交
30 31 32 33 34 35 36

  static int64_t num_cuda_devices = -1;
  static std::once_flag num_devices_init_flag;
  static std::deque<std::once_flag> cuda_device_flags;
  static std::vector<std::shared_ptr<Generator>> default_cuda_generators;

  std::call_once(num_devices_init_flag, []() {
37
    num_cuda_devices = paddle::platform::GetGPUDeviceCount();
Y
yaoxuefeng 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
    cuda_device_flags.resize(num_cuda_devices);
    default_cuda_generators.resize(num_cuda_devices);
  });
  if (device_id < 0) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "cuda device id shoule be greater than 0"));
  }

  std::call_once(cuda_device_flags[device_id], [device_id]() {
    default_cuda_generators[device_id] =
        std::make_shared<Generator>(GetRandomSeed(), device_id);
    VLOG(4) << "initial seed: "
            << default_cuda_generators[device_id]->GetCurrentSeed();
  });
  return default_cuda_generators[device_id];
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "getDefaultCUDAGenerator only support in CUDA place"));
#endif
}

L
Leo Chen 已提交
59 60 61 62 63 64
// Process-wide CPU generator.
//
// Built on the first call with a seed from GetRandomSeed(); every subsequent
// call hands back a reference to that same shared instance.
const std::shared_ptr<Generator>& DefaultCPUGenerator() {
  static std::shared_ptr<Generator> cpu_instance =
      std::make_shared<Generator>(GetRandomSeed());
  return cpu_instance;
}

65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
// Named generators registered through SetRandomSeedGenerator.
using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;

// Returns the file-local registry of named generators, created on first use.
// NOTE(review): no lock guards this map in the functions visible here; callers
// presumably avoid concurrent Set/Get — confirm.
static RNGMap& GetRandomSeedGeneratorMap() {
  static RNGMap named_generators;
  return named_generators;
}

// Registers a new generator under `name`, seeded with `seed`, and returns a
// reference to the stored shared_ptr.
//
// Throws AlreadyExists if `name` is already registered. The second check
// (PermissionDenied) guards the insertion itself.
const std::shared_ptr<Generator>& SetRandomSeedGenerator(
    const std::string& name, uint64_t seed) {
  auto& rng_map = GetRandomSeedGeneratorMap();
  auto iter = rng_map.find(name);
  PADDLE_ENFORCE_EQ(iter == rng_map.end(), true,
                    platform::errors::AlreadyExists(
                        "%s RandomSeedGenerator is already exist", name));

  // The generator is only constructed after the duplicate check above.
  auto emplace_result =
      rng_map.emplace(name, std::make_shared<Generator>(seed));
  bool emplace_success = emplace_result.second;
  PADDLE_ENFORCE_EQ(
      emplace_success, true,
      platform::errors::PermissionDenied(
          "SetRandomSeedGenerator cannot emplace %s RandomSeedGenerator",
          name));
  // Return through the iterator emplace already gave us rather than
  // rng_map[name]. That would be a third hash lookup, and operator[] would
  // silently default-insert a null generator if the key were ever missing.
  return emplace_result.first->second;
}

// Looks up the generator registered under `name`.
//
// Throws NotFound when no generator was registered with that name (see
// SetRandomSeedGenerator). The returned reference points into the registry.
const std::shared_ptr<Generator>& GetRandomSeedGenerator(
    const std::string& name) {
  auto& rng_map = GetRandomSeedGeneratorMap();
  const auto iter = rng_map.find(name);
  PADDLE_ENFORCE_EQ(iter != rng_map.end(), true,
                    platform::errors::NotFound(
                        "%s RandomSeedGenerator is not found, please "
                        "use `set_random_seed_generator` to set rng first",
                        name));
  const std::shared_ptr<Generator>& registered = iter->second;
  return registered;
}

102 103 104 105 106
// There are 3 conditions:
// (1) op seed is set, use op seed.
// (2) op seed is not set, global seed is set, use global seed.
// (3) op seed is not set, global seed is not set too, use random seed from
// RandomGenerator.
L
Leo Chen 已提交
107
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
108
  if (seed == 0) {
L
Leo Chen 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
    VLOG(4) << "Use random engine from generator";
    return DefaultCPUGenerator()->GetCPUEngine();
  } else {
    // NOTE(zhiqiu): creating an engine instance everytime instead of using
    // OpDefaultCPUEngine(), this is the legacy behavior of random operators.
    // The benefit is that when runing PE with fixed-seed in multiple thrads,
    // each thread has their own engine, and doesn't affect each other.
    //
    // And we need to measure the determinacy of Generator in PE.
    auto engine = std::make_shared<std::mt19937_64>();
    static std::mutex mu_;
    {
      std::lock_guard<std::mutex> lock(mu_);
      engine->seed(seed);
    }
    return engine;
  }
}
Y
yaoxuefeng 已提交
127

128
// Returns a copy of this generator's state.
//
// The current CPU engine is written into state_.cpu_engine first, so the
// snapshot reflects all draws made so far. Runs under mu_.
phi::Generator::GeneratorState Generator::GetState() {
  std::lock_guard<std::mutex> guard(mu_);
  state_.cpu_engine = *engine_;
  return state_;
}

134
// Replaces this generator's state with `state`. Runs under mu_.
//
// A new engine object, copy-constructed from state.cpu_engine, is installed.
// shared_ptrs previously handed out by GetCPUEngine() keep pointing at the
// old engine.
void Generator::SetState(const phi::Generator::GeneratorState& state) {
  std::lock_guard<std::mutex> guard(mu_);
  state_ = state;
  engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
}

uint64_t Generator::GetCurrentSeed() {
L
Leo Chen 已提交
141 142
  std::lock_guard<std::mutex> lock(this->mu_);
  return this->state_.current_seed;
Y
yaoxuefeng 已提交
143 144 145
}

uint64_t Generator::Seed() {
L
Leo Chen 已提交
146
  std::lock_guard<std::mutex> lock(this->mu_);
Y
yaoxuefeng 已提交
147 148 149
  uint64_t seed;
  std::random_device de;
  seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
L
Leo Chen 已提交
150
  this->state_.current_seed = seed;
Y
yaoxuefeng 已提交
151
  std::seed_seq seq({seed});
L
Leo Chen 已提交
152
  this->engine_->seed(seq);
Y
yaoxuefeng 已提交
153

L
Leo Chen 已提交
154
  return this->state_.current_seed;
Y
yaoxuefeng 已提交
155 156 157
}

void Generator::SetCurrentSeed(uint64_t seed) {
L
Leo Chen 已提交
158 159
  std::lock_guard<std::mutex> lock(this->mu_);
  this->state_.current_seed = seed;
Y
yaoxuefeng 已提交
160
  this->state_.thread_offset = 0;
Y
yaoxuefeng 已提交
161
  std::seed_seq seq({seed});
L
Leo Chen 已提交
162
  this->engine_->seed(seq);
Y
yaoxuefeng 已提交
163 164
}

L
Leo Chen 已提交
165 166 167
std::shared_ptr<std::mt19937_64> Generator::GetCPUEngine() {
  std::lock_guard<std::mutex> lock(this->mu_);
  return this->engine_;
Y
yaoxuefeng 已提交
168 169
}

L
Leo Chen 已提交
170 171 172
// Installs `engine` as this generator's CPU engine. Runs under mu_.
void Generator::SetCPUEngine(std::shared_ptr<std::mt19937_64> engine) {
  std::lock_guard<std::mutex> lock(this->mu_);
  // `engine` is a by-value sink parameter. Moving it avoids a second atomic
  // refcount increment/decrement that a copy would incur.
  this->engine_ = std::move(engine);
}

// Draws the next 64-bit value from the CPU engine. Runs under mu_.
uint64_t Generator::Random64() {
  std::lock_guard<std::mutex> lock(this->mu_);
  // engine_ cannot be swapped out while mu_ is held, so call through it
  // directly instead of copying the shared_ptr (an atomic inc/dec pair).
  return (*this->engine_)();
}

Y
yaoxuefeng 已提交
181 182
std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
    uint64_t increament_offset) {
183
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Y
yaoxuefeng 已提交
184
  std::lock_guard<std::mutex> lock(this->mu_);
185
  uint64_t cur_offset = this->state_.thread_offset;
Y
yaoxuefeng 已提交
186
  this->state_.thread_offset += increament_offset;
187
  return std::make_pair(this->state_.current_seed, cur_offset);
Y
yaoxuefeng 已提交
188 189 190 191 192 193
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Increment Offset only support in CUDA place"));
#endif
}

Y
yaoxuefeng 已提交
194 195
}  // namespace framework
}  // namespace paddle