Unverified commit 7f3e6ca5 authored by yaoxuefeng, committed by GitHub

add cuda generator (#26786)

Parent c4846196
......@@ -272,7 +272,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib
cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
-cc_library(generator SRCS generator.cc)
cc_library(generator SRCS generator.cc DEPS enforce place)
# Get the current working branch
execute_process(
......
......@@ -21,10 +21,46 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) {
#ifdef PADDLE_WITH_CUDA
static int64_t num_cuda_devices = -1;
static std::once_flag num_devices_init_flag;
static std::deque<std::once_flag> cuda_device_flags;
static std::vector<std::shared_ptr<Generator>> default_cuda_generators;
std::call_once(num_devices_init_flag, []() {
num_cuda_devices = paddle::platform::GetCUDADeviceCount();
cuda_device_flags.resize(num_cuda_devices);
default_cuda_generators.resize(num_cuda_devices);
});
if (device_id < 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuda device id shoule be greater than 0"));
}
std::call_once(cuda_device_flags[device_id], [device_id]() {
default_cuda_generators[device_id] =
std::make_shared<Generator>(GetRandomSeed(), device_id);
VLOG(4) << "initial seed: "
<< default_cuda_generators[device_id]->GetCurrentSeed();
});
return default_cuda_generators[device_id];
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"getDefaultCUDAGenerator only support in CUDA place"));
#endif
}
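A minimal usage sketch of this accessor, assuming an op kernel whose ExecutionContext is named `ctx` (the GPU kernels later in this diff follow the same pattern):
// Fetch the per-device generator and reserve one Philox increment.
// IncrementOffset returns {current_seed, previous offset}.
int device_id =
    BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy()) {
  auto seed_offset = gen_cuda->IncrementOffset(1);
  // seed_offset.first seeds curand; seed_offset.second is the offset.
}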
const std::shared_ptr<Generator>& DefaultCPUGenerator() {
static auto default_cpu_generator =
std::make_shared<Generator>(GetRandomSeed());
......@@ -103,6 +139,7 @@ uint64_t Generator::Seed() {
void Generator::SetCurrentSeed(uint64_t seed) {
std::lock_guard<std::mutex> lock(this->mu_);
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
std::seed_seq seq({seed});
this->engine_->seed(seq);
}
......@@ -123,6 +160,22 @@ uint64_t Generator::Random64() {
return (*engine)();
}
std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
uint64_t increament_offset) {
uint64_t cur_offset = this->state_.thread_offset;
#ifdef PADDLE_WITH_CUDA
std::lock_guard<std::mutex> lock(this->mu_);
this->state_.thread_offset += increament_offset;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Increment Offset only support in CUDA place"));
#endif
return std::make_pair(static_cast<int>(this->state_.current_seed),
cur_offset);
}
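To make the seed/offset contract concrete, here is a minimal self-contained sketch of a kernel consuming the returned pair; `SampleUniform` is a hypothetical name, not an op in this commit:
#include <curand_kernel.h>

// Each thread uses Philox subsequence `idx`; `increment` is the previous
// thread_offset, so consecutive launches continue the stream instead of
// replaying it.
__global__ void SampleUniform(const size_t n, uint64_t seed,
                              uint64_t increment, float* dst) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    curand_init(seed, idx, increment, &state);
    dst[idx] = curand_uniform(&state);
  }
}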
void Generator::SetIsInitPy(bool is_init_py) {
this->is_init_py_ = is_init_py;
VLOG(4) << "SetIsInitPy:" << this->is_init_py_;
......
......@@ -38,6 +38,7 @@ static uint64_t GetRandomSeed() {
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};
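The new thread_offset field counts how many Philox increments the generator has handed out: IncrementOffset returns the old value for kernels to pass to curand_init and then advances it under the generator's lock, so a state saved here and restored later replays the same stream positions.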
......@@ -49,6 +50,7 @@ struct Generator {
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
......@@ -59,11 +61,25 @@ struct Generator {
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = true; // TODO(zhiqiu): remove it in future
}
Generator(uint64_t seed, uint64_t device_id) {
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->state_.device = device_id;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = false; // TODO(zhiqiu): remove it in future
}
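Note the asymmetry between the constructors: the seed-only constructor marks the generator as Python-initialized, while the per-device constructor leaves is_init_py_ false, so GPU ops keep their legacy seed path until Python explicitly seeds the generator (see random.py below).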
Generator(const Generator& other) = delete;
// get random state
......@@ -83,8 +99,11 @@ struct Generator {
uint64_t Random64();
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t increament_offset);
void SetIsInitPy(bool);
bool GetIsInitPy() const;
int64_t get_device_id() { return this->state_.device; }
private:
GeneratorState state_;
......@@ -105,5 +124,8 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);
} // namespace framework
} // namespace paddle
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/bernoulli_op.h"
......
......@@ -96,6 +96,42 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
}
}
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
const float dropout_prob,
const T* src, MaskType* mask_data,
T* dst, bool is_upscale_in_train,
uint64_t increment) {
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
// Philox init per element: subsequence idx, advanced by `increment`
// so successive launches do not replay the same numbers.
curand_init(seed, idx, increment, &state);
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
} else {
mask = 1;
if (is_upscale_in_train) {
dest = s / static_cast<T>(1.0f - dropout_prob);
} else {
dest = s;
}
}
mask_data[idx] = mask;
dst[idx] = dest;
}
}
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
......@@ -150,6 +186,17 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
}
int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
.GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
auto seed_offset = gen_cuda->IncrementOffset(1);
RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, seed_offset.second);
return;
}
RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
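The dispatch above: when the Python side has seeded this device's generator (GetIsInitPy) and the op does not pin fix_seed, the kernel reserves one offset increment and takes the replayable Philox path; otherwise it falls back to the legacy seed-only RandomGenerator.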
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
......@@ -24,15 +25,20 @@ template <typename T>
struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
unsigned int offset_ = 0;
__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}
__host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
: mean_(mean), std_(std), seed_(seed), offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
rng.discard(n);
unsigned int new_n = n + offset_;
rng.discard(new_n);
return dist(rng);
}
};
......@@ -43,9 +49,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
......@@ -56,9 +64,27 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
T* data = tensor->mutable_data<T>(context.GetPlace());
int64_t size = tensor->numel();
-thrust::transform(index_sequence_begin, index_sequence_begin + size,
-thrust::device_ptr<T>(data),
-GaussianGenerator<T>(mean, std, seed));
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): we keep the offset step fixed for now to avoid
// unexpected results that could make unit tests fail; this will be
// revisited in the future.
int gen_offset = offset_step * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
}
}
};
......@@ -69,17 +95,37 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
-thrust::transform(index_sequence_begin, index_sequence_begin + size,
-thrust::device_ptr<T>(data),
-GaussianGenerator<T>(mean, std, seed));
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): we keep the offset step fixed for now to avoid
// unexpected results that could make unit tests fail; this will be
// revisited in the future.
int gen_offset = offset_step * seed_offset.second;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first,
gen_offset));
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
}
}
};
} // namespace operators
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/uniform_random_op.h"
......@@ -49,15 +50,23 @@ class GPURandintKernel : public framework::OpKernel<T> {
int64_t size = out->numel();
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
/*
std::minstd_rand engine;
if (seed == 0) {
std::random_device rd;
seed = rd();
}
engine.seed(seed);
*/
std::uniform_int_distribution<> dist(context.Attr<int>("low"),
context.Attr<int>("high") - 1);
for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
auto engine = framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
if (platform::is_gpu_place(context.GetPlace())) {
// Copy tensor to out
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include <limits>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
......@@ -46,6 +47,37 @@ struct TruncatedNormal {
}
};
template <typename T>
struct TruncatedNormalOffset {
T mean, std;
T a_normal_cdf;
T b_normal_cdf;
unsigned int seed;
T numeric_min;
int offset_;
__host__ __device__ TruncatedNormalOffset(T mean, T std, T numeric_min,
int seed, int offset)
: mean(mean),
std(std),
seed(seed),
numeric_min(numeric_min),
offset_(offset) {
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<T> dist(numeric_min, 1);
rng.discard(n + offset_);
T value = dist(rng);
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
}
};
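For reference, the functor above performs inverse-transform sampling of a normal truncated to roughly [mean - 2*std, mean + 2*std]: with u drawn uniformly, it forms p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * u, a CDF value pinned between Phi(-2) and Phi(2), then inverts the normal CDF via x = sqrt(2) * erfinv(2*p - 1) * std + mean.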
template <typename T>
class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
public:
......@@ -54,14 +86,35 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): we keep the offset step fixed for now to avoid
// unexpected results that could make unit tests fail; this will be
// revisited in the future.
int gen_offset = offset_step * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
seed_offset.first, gen_offset));
return;
}
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
......
......@@ -51,6 +51,39 @@ struct UniformGenerator {
}
};
template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
......@@ -89,10 +122,11 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
}
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T min = static_cast<T>(context.Attr<float>("min"));
......@@ -104,10 +138,27 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
T diag_val = static_cast<T>(context.Attr<float>("diag_val"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
-thrust::transform(
-index_sequence_begin, index_sequence_begin + size,
-thrust::device_ptr<T>(data),
-UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): we keep the offset step fixed for now to avoid
// unexpected results that could make unit tests fail; this will be
// revisited in the future.
int gen_offset = offset_step * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset));
} else {
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
}
}
};
......
......@@ -59,6 +59,7 @@ void BindGenerator(py::module* m_ptr) {
.def_property("_is_init_py", &framework::Generator::GetIsInitPy,
&framework::Generator::SetIsInitPy);
m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
-} // end Generator
-} // end namespace pybind
m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator);
}
} // namespace pybind
} // namespace paddle
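These bindings are what python/paddle/framework/random.py drives below: setting _is_init_py marks a generator as Python-seeded (the flag the GPU kernels test via GetIsInitPy), and default_cuda_generator(i) selects the per-device instance.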
......@@ -217,6 +217,8 @@ from .tensor.search import index_select #DEFINE_ALIAS
from .tensor.search import nonzero #DEFINE_ALIAS
from .tensor.search import sort #DEFINE_ALIAS
from .framework.random import manual_seed #DEFINE_ALIAS
from .framework.random import get_cuda_rng_state #DEFINE_ALIAS
from .framework.random import set_cuda_rng_state #DEFINE_ALIAS
from .framework import Variable #DEFINE_ALIAS
from .framework import ParamAttr #DEFINE_ALIAS
from .framework import create_global_var #DEFINE_ALIAS
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cloud role maker."""
from __future__ import print_function
import os
import unittest
import paddle.fluid.generator as generator
import paddle.fluid as fluid
import numpy as np
import paddle
import paddle.fluid.core as core
class TestGeneratorSeed(unittest.TestCase):
"""
Test cases for cuda generator seed.
"""
def test_gen_dropout_dygraph(self):
gen = paddle.manual_seed(12343)
fluid.enable_dygraph()
gen.manual_seed(111111111)
st = paddle.get_cuda_rng_state()
x = fluid.layers.uniform_random(
[2, 10], dtype="float32", min=0.0, max=1.0)
x_again = fluid.layers.uniform_random(
[2, 10], dtype="float32", min=0.0, max=1.0)
x_third = fluid.layers.uniform_random(
[2, 10], dtype="float32", min=0.0, max=1.0)
print("x: {}".format(x.numpy()))
print("x_again: {}".format(x_again.numpy()))
x = x + x_again + x_third
y = fluid.layers.dropout(x, 0.5)
paddle.set_cuda_rng_state(st)
x1 = fluid.layers.uniform_random(
[2, 10], dtype="float32", min=0.0, max=1.0)
x1_again = fluid.layers.uniform_random(
[2, 10], dtype="float32", min=0.0, max=1.0)
x1_third = fluid.layers.uniform_random(
[2, 10], dtype="float32", min=0.0, max=1.0)
x1 = x1 + x1_again + x1_third
y1 = fluid.layers.dropout(x1, 0.5)
y_np = y.numpy()
y1_np = y1.numpy()
if core.is_compiled_with_cuda():
print(">>>>>>> dropout dygraph >>>>>>>")
self.assertTrue(np.allclose(y_np, y1_np))
def test_generator_gaussian_random_dygraph(self):
"""Test Generator seed."""
fluid.enable_dygraph()
paddle.manual_seed(12312321111)
x = fluid.layers.gaussian_random([120], dtype="float32")
st1 = paddle.get_cuda_rng_state()
x1 = fluid.layers.gaussian_random([120], dtype="float32")
paddle.set_cuda_rng_state(st1)
x2 = fluid.layers.gaussian_random([120], dtype="float32")
paddle.manual_seed(12312321111)
x3 = fluid.layers.gaussian_random([120], dtype="float32")
x_np = x.numpy()
x1_np = x1.numpy()
x2_np = x2.numpy()
x3_np = x3.numpy()
if core.is_compiled_with_cuda():
print(">>>>>>> gaussian random dygraph >>>>>>>")
self.assertTrue(np.allclose(x1_np, x2_np))
self.assertTrue(np.allclose(x_np, x3_np))
def test_generator_randint_dygraph(self):
"""Test Generator seed."""
fluid.enable_dygraph()
gen = paddle.manual_seed(12312321111)
x = paddle.randint(low=10, shape=[10], dtype="int32")
st1 = gen.get_state()
x1 = paddle.randint(low=10, shape=[10], dtype="int32")
gen.set_state(st1)
x2 = paddle.randint(low=10, shape=[10], dtype="int32")
paddle.manual_seed(12312321111)
x3 = paddle.randint(low=10, shape=[10], dtype="int32")
x_np = x.numpy()
x1_np = x1.numpy()
x2_np = x2.numpy()
x3_np = x3.numpy()
if core.is_compiled_with_cuda():
print(">>>>>>> randint dygraph >>>>>>>")
self.assertTrue(np.allclose(x1_np, x2_np))
self.assertTrue(np.allclose(x_np, x3_np))
def test_gen_TruncatedNormal_initializer(self):
fluid.disable_dygraph()
gen = paddle.manual_seed(123123143)
cur_state = paddle.get_cuda_rng_state()
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
x = fluid.layers.uniform_random(shape=[2, 10])
result_1 = fluid.layers.fc(
input=x,
size=10,
param_attr=fluid.initializer.TruncatedNormal(
loc=0.0, scale=2.0))
result_2 = fluid.layers.fc(
input=x,
size=10,
param_attr=fluid.initializer.TruncatedNormal(
loc=0.0, scale=2.0))
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
out1 = exe.run(train_program,
feed={},
fetch_list=[result_1, result_2])
paddle.manual_seed(123123143)
with fluid.program_guard(train_program, startup_program):
exe.run(startup_program)
out2 = exe.run(train_program,
feed={},
fetch_list=[result_1, result_2])
out1_res1 = np.array(out1[0])
out1_res2 = np.array(out1[1])
out2_res1 = np.array(out2[0])
out2_res2 = np.array(out2[1])
if core.is_compiled_with_cuda():
print(">>>>>>> truncated normal static >>>>>>>")
self.assertTrue(np.allclose(out1_res1, out2_res1))
self.assertTrue(np.allclose(out1_res2, out2_res2))
self.assertTrue(not np.allclose(out1_res2, out1_res1))
if __name__ == "__main__":
unittest.main()
......@@ -16,7 +16,7 @@
import paddle.fluid as fluid
from paddle.fluid import core
-__all__ = ['manual_seed']
__all__ = ['manual_seed', 'get_cuda_rng_state', 'set_cuda_rng_state']
def manual_seed(seed):
......@@ -42,10 +42,69 @@ def manual_seed(seed):
seed = int(seed)
if core.is_compiled_with_cuda():
for i in range(core.get_cuda_device_count()):
core.default_cuda_generator(i)._is_init_py = True
core.default_cuda_generator(i).manual_seed(seed)
core.default_cpu_generator()._is_init_py = True
return core.default_cpu_generator().manual_seed(seed)
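With this change, a single paddle.manual_seed call reseeds every cuda generator and flags it as Python-initialized, which is what switches the GPU kernels above onto the generator-driven path.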
def get_cuda_rng_state():
"""
Get the random states of all cuda generators.
Args:
None
Returns:
list of GeneratorState: one state object per visible cuda device.
Examples:
.. code-block:: python
import paddle
sts = paddle.get_cuda_rng_state()
"""
state_list = []
if core.is_compiled_with_cuda():
for i in range(core.get_cuda_device_count()):
state_list.append(core.default_cuda_generator(i).get_state())
return state_list
def set_cuda_rng_state(state_list):
"""
Set the states of all cuda generators.
Args:
state_list(list): The per-device generator states to restore, as obtained from get_cuda_rng_state().
Returns:
None
Examples:
.. code-block:: python
import paddle
sts = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(sts)
"""
if core.is_compiled_with_cuda():
if not len(state_list) == core.get_cuda_device_count():
raise ValueError(
"Length of cuda state list shoule be equal to the cuda device count"
)
for i in range(core.get_cuda_device_count()):
core.default_cuda_generator(i).set_state(state_list[i])
def _manual_program_seed(seed):
"""
Sets global seed for generating random numbers.
......