diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 9d5c0cc7048f7db539c090d28c6184ac6d72d75a..bb5e2e1369a8478b500572106f9d11dff12e0189 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -272,7 +272,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib
 cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
 cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
 
-cc_library(generator SRCS generator.cc)
+cc_library(generator SRCS generator.cc DEPS enforce place)
 
 # Get the current working branch
 execute_process(
diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc
index 9bde9e20b19a0b14ce4489b91d9ab3d5273f7f9a..d51e97d98e902a87cd2a44d2019e93e8dfc30fc8 100644
--- a/paddle/fluid/framework/generator.cc
+++ b/paddle/fluid/framework/generator.cc
@@ -21,10 +21,46 @@ limitations under the License. */
 #include
 #include
 #include
+#include <deque>
+
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/place.h"
 
 namespace paddle {
 namespace framework {
 
+const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) {
+#ifdef PADDLE_WITH_CUDA
+
+  static int64_t num_cuda_devices = -1;
+  static std::once_flag num_devices_init_flag;
+  static std::deque<std::once_flag> cuda_device_flags;
+  static std::vector<std::shared_ptr<Generator>> default_cuda_generators;
+
+  std::call_once(num_devices_init_flag, []() {
+    num_cuda_devices = paddle::platform::GetCUDADeviceCount();
+    cuda_device_flags.resize(num_cuda_devices);
+    default_cuda_generators.resize(num_cuda_devices);
+  });
+  if (device_id < 0) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "cuda device id should not be negative"));
+  }
+
+  std::call_once(cuda_device_flags[device_id], [device_id]() {
+    default_cuda_generators[device_id] =
+        std::make_shared<Generator>(GetRandomSeed(), device_id);
+    VLOG(4) << "initial seed: "
+            << default_cuda_generators[device_id]->GetCurrentSeed();
+  });
+  return default_cuda_generators[device_id];
+#else
+  PADDLE_THROW(platform::errors::PermissionDenied(
+      "GetDefaultCUDAGenerator is only supported on CUDA place"));
+#endif
+}
+
 const std::shared_ptr<Generator>& DefaultCPUGenerator() {
   static auto default_cpu_generator =
       std::make_shared<Generator>(GetRandomSeed());
@@ -103,6 +139,7 @@ uint64_t Generator::Seed() {
 void Generator::SetCurrentSeed(uint64_t seed) {
   std::lock_guard<std::mutex> lock(this->mu_);
   this->state_.current_seed = seed;
+  this->state_.thread_offset = 0;
   std::seed_seq seq({seed});
   this->engine_->seed(seq);
 }
@@ -123,6 +160,22 @@ uint64_t Generator::Random64() {
   return (*engine)();
 }
 
+std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
+    uint64_t increment_offset) {
+  uint64_t cur_offset = this->state_.thread_offset;
+#ifdef PADDLE_WITH_CUDA
+  std::lock_guard<std::mutex> lock(this->mu_);
+
+  this->state_.thread_offset += increment_offset;
+
+#else
+  PADDLE_THROW(platform::errors::PermissionDenied(
+      "IncrementOffset is only supported in CUDA place"));
+#endif
+  return std::make_pair(static_cast<uint64_t>(this->state_.current_seed),
+                        cur_offset);
+}
+
 void Generator::SetIsInitPy(bool is_init_py) {
   this->is_init_py_ = is_init_py;
   VLOG(4) << "SetIsInitPy:" << this->is_init_py_;
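The per-device bookkeeping added to generator.cc above can be hard to follow in diff form. Below is a small, illustrative Python model of the same logic: one lazily created generator per CUDA device, and an `IncrementOffset` that hands out the current `(seed, offset)` pair under a lock while advancing the offset. The names (`CudaGeneratorModel`, `get_default_cuda_generator`) are invented for this sketch and are not part of PaddlePaddle.

```python
import threading


class CudaGeneratorModel:
    """Toy model of framework::Generator's (seed, thread_offset) state."""

    def __init__(self, seed, device_id):
        self.seed = seed
        self.device_id = device_id
        self.thread_offset = 0
        self._mu = threading.Lock()

    def manual_seed(self, seed):
        # Mirrors Generator::SetCurrentSeed: re-seeding also resets the
        # Philox offset, so a replay starts from the beginning of the stream.
        with self._mu:
            self.seed = seed
            self.thread_offset = 0

    def increment_offset(self, increment):
        # Mirrors Generator::IncrementOffset: return the current
        # (seed, offset) pair and advance the offset for the next kernel.
        with self._mu:
            cur = self.thread_offset
            self.thread_offset += increment
            return self.seed, cur


# One generator per visible CUDA device, created on first use.
_generators = {}


def get_default_cuda_generator(device_id, default_seed=34342423252):
    if device_id < 0:
        raise ValueError("cuda device id should not be negative")
    return _generators.setdefault(device_id,
                                  CudaGeneratorModel(default_seed, device_id))
```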
diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h
index 82b35f7ad550e770e8d10457ddf6cdf8e6fbd709..a279c2e4e1458293b6579b7b7cb2111e440e5d5e 100644
--- a/paddle/fluid/framework/generator.h
+++ b/paddle/fluid/framework/generator.h
@@ -38,6 +38,7 @@ static uint64_t GetRandomSeed() {
 struct GeneratorState {
   int64_t device = -1;
   uint64_t current_seed = 34342423252;
+  uint64_t thread_offset = 0;
   std::mt19937_64 cpu_engine;
 };
 
@@ -49,6 +50,7 @@ struct Generator {
     this->state_.cpu_engine = *engine;
     this->state_.device = -1;
     this->state_.current_seed = seed;
+    this->state_.thread_offset = 0;
     this->engine_ = engine;
     VLOG(4) << "initial seed: " << this->state_.current_seed
             << ", cpu engine: " << &this->state_.cpu_engine;
@@ -59,11 +61,25 @@ struct Generator {
     this->state_.cpu_engine = *engine;
     this->state_.device = -1;
     this->state_.current_seed = seed;
+    this->state_.thread_offset = 0;
     this->engine_ = engine;
     VLOG(4) << "initial seed: " << this->state_.current_seed
             << ", cpu engine: " << &this->state_.cpu_engine;
     this->is_init_py_ = true;  // TODO(zhiqiu): remove it in future
   }
+  Generator(uint64_t seed, uint64_t device_id) {
+    std::seed_seq seq({seed});
+    auto engine = std::make_shared<std::mt19937_64>(seq);
+    this->state_.cpu_engine = *engine;
+    this->state_.device = device_id;
+    this->state_.current_seed = seed;
+    this->state_.thread_offset = 0;
+    this->engine_ = engine;
+    VLOG(4) << "initial seed: " << this->state_.current_seed
+            << ", cpu engine: " << &this->state_.cpu_engine;
+    this->is_init_py_ = false;  // TODO(zhiqiu): remove it in future
+  }
+
   Generator(const Generator& other) = delete;
 
   // get random state
@@ -83,8 +99,11 @@ struct Generator {
 
   uint64_t Random64();
 
+  std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t increment_offset);
+
   void SetIsInitPy(bool);
   bool GetIsInitPy() const;
+  uint64_t get_device_id() { return this->state_.device; }
 
  private:
   GeneratorState state_;
@@ -105,5 +124,8 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();
 
 std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
 
+const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
+    int64_t device_id = -1);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/operators/bernoulli_op.cu b/paddle/fluid/operators/bernoulli_op.cu
index f665d2dd0e991847de2ad35bf6b18741fb3a6e26..6565f5a9a2176972e9e5085c6646097e8349f259 100644
--- a/paddle/fluid/operators/bernoulli_op.cu
+++ b/paddle/fluid/operators/bernoulli_op.cu
@@ -16,7 +16,6 @@
 limitations under the License. */
 #include
 #include
-#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/bernoulli_op.h"
diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index 4d5e4c4f600314d307125f9b2031026b6aa94f10..49ad67bbca353acc4a79c9e8912d7ae5a70c0021 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -96,6 +96,42 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
   }
 }
 
+template <typename T, typename MaskType>
+__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
+                                             const float dropout_prob,
+                                             const T* src, MaskType* mask_data,
+                                             T* dst, bool is_upscale_in_train,
+                                             uint64_t increment) {
+  curandStatePhilox4_32_10_t state;
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int step_size = 0;
+
+  MaskType mask;
+  T dest;
+  for (; idx < n; idx += blockDim.x * gridDim.x) {
+    T s = src[idx];
+    if (step_size == 0) {
+      curand_init(seed, idx, increment, &state);
+      step_size = blockDim.x * gridDim.x;
+    } else {
+      curand_init(seed, idx, increment, &state);
+    }
+    if (curand_uniform(&state) < dropout_prob) {
+      mask = 0;
+      dest = 0;
+    } else {
+      mask = 1;
+      if (is_upscale_in_train) {
+        dest = s / static_cast<T>(1.0f - dropout_prob);
+      } else {
+        dest = s;
+      }
+    }
+    mask_data[idx] = mask;
+    dst[idx] = dest;
+  }
+}
+
 // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
@@ -150,6 +186,17 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
             context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
       }
 
+      int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
+                          .GetDeviceId();
+      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+      if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
+        auto seed_offset = gen_cuda->IncrementOffset(1);
+        RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
+            size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
+            upscale_in_train, seed_offset.second);
+        return;
+      }
+
       RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
           size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train);
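The dropout path above stops relying solely on the per-op `seed` attribute: each launch asks the device generator for a `(seed, offset)` pair via `IncrementOffset`, and `curand_init(seed, idx, offset, ...)` then deterministically selects a slice of a Philox counter-based stream. The idea can be illustrated with NumPy's Philox bit generator; this is a conceptual sketch only (NumPy stands in for cuRAND, and `philox_slice` is a made-up helper, not code from this patch).

```python
import numpy as np


def philox_slice(seed, offset, n):
    """Draw n uniforms from the Philox stream identified by (seed, offset)."""
    bit_gen = np.random.Philox(key=seed)
    bit_gen.advance(offset)  # skip ahead without generating numbers
    return np.random.Generator(bit_gen).random(n)


seed = 12343
a = philox_slice(seed, 0, 4)    # one "launch" at offset 0
a_replay = philox_slice(seed, 0, 4)  # same (seed, offset) -> identical values
b = philox_slice(seed, 1, 4)    # advanced offset -> a different slice

assert np.allclose(a, a_replay)
assert not np.allclose(a, b)
```

Because the state is just a counter, replaying a computation only requires restoring the seed and offset, which is what `get_cuda_rng_state`/`set_cuda_rng_state` below expose.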
diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu
index c144481f8dedc9317f7657a22ce82e56022d5b89..eca42ac581ab982d514dd5eb10ea297a0283478e 100644
--- a/paddle/fluid/operators/gaussian_random_op.cu
+++ b/paddle/fluid/operators/gaussian_random_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include
 #include
+#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/fill_constant_op.h"
@@ -24,15 +25,20 @@ template <typename T>
 struct GaussianGenerator {
   T mean_, std_;
   unsigned int seed_;
+  unsigned int offset_ = 0;
 
   __host__ __device__ GaussianGenerator(T mean, T std, int seed)
       : mean_(mean), std_(std), seed_(seed) {}
 
+  __host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
+      : mean_(mean), std_(std), seed_(seed), offset_(offset) {}
+
   __host__ __device__ T operator()(const unsigned int n) const {
     thrust::minstd_rand rng;
     rng.seed(seed_);
     thrust::normal_distribution<T> dist(mean_, std_);
-    rng.discard(n);
+    unsigned int new_n = n + offset_;
+    rng.discard(new_n);
     return dist(rng);
   }
 };
@@ -43,9 +49,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* tensor = context.Output<framework::Tensor>("Out");
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+    bool seed_flag = false;
     if (seed == 0) {
       std::random_device rd;
       seed = rd();
+      seed_flag = true;
     }
     T mean = static_cast<T>(context.Attr<float>("mean"));
     T std = static_cast<T>(context.Attr<float>("std"));
@@ -56,9 +64,27 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
     T* data = tensor->mutable_data<T>(context.GetPlace());
 
     int64_t size = tensor->numel();
-    thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                      thrust::device_ptr<T>(data),
-                      GaussianGenerator<T>(mean, std, seed));
+
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+
+    if (gen_cuda->GetIsInitPy() && seed_flag) {
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      int offset_step = 100;
+      // NOTE(xuefeng): Currently, we let offset step fixed to avoid
+      // unexpected results which may cause ut fail.
+      // we will fix this in future.
+      int gen_offset = offset_step * seed_offset.second;
+      thrust::transform(
+          index_sequence_begin, index_sequence_begin + size,
+          thrust::device_ptr<T>(data),
+          GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
+    } else {
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(data),
+                        GaussianGenerator<T>(mean, std, seed));
+    }
   }
 };
 
@@ -69,17 +95,37 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
     auto* tensor = context.Output<framework::Tensor>("Out");
     T* data = tensor->mutable_data<T>(context.GetPlace());
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+    bool seed_flag = false;
     if (seed == 0) {
       std::random_device rd;
       seed = rd();
+      seed_flag = true;
     }
     T mean = static_cast<T>(context.Attr<float>("mean"));
     T std = static_cast<T>(context.Attr<float>("std"));
     thrust::counting_iterator<unsigned int> index_sequence_begin(0);
     int64_t size = tensor->numel();
-    thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                      thrust::device_ptr<T>(data),
-                      GaussianGenerator<T>(mean, std, seed));
+
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+
+    if (gen_cuda->GetIsInitPy() && seed_flag) {
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      int offset_step = 100;
+      // NOTE(xuefeng): Currently, we let offset step fixed to avoid
+      // unexpected results which may cause ut fail.
+      // we will fix this in future.
+      int gen_offset = offset_step * seed_offset.second;
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(data),
+                        GaussianGenerator<T>(mean, std, seed_offset.first,
+                                             gen_offset));
+    } else {
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(data),
+                        GaussianGenerator<T>(mean, std, seed));
+    }
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/randint_op.cu b/paddle/fluid/operators/randint_op.cu
index a07a92621e6b3726be518df6abcec58257a91489..40e390b0b87246bbaa8474262df8ba5576297385 100644
--- a/paddle/fluid/operators/randint_op.cu
+++ b/paddle/fluid/operators/randint_op.cu
@@ -13,6 +13,7 @@ // limitations under the License.
 #include
 #include
+#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/uniform_random_op.h"
 
@@ -49,15 +50,23 @@ class GPURandintKernel : public framework::OpKernel<T> {
     int64_t size = out->numel();
 
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+
+    /*
     std::minstd_rand engine;
     if (seed == 0) {
       std::random_device rd;
       seed = rd();
     }
     engine.seed(seed);
+    */
+
     std::uniform_int_distribution<> dist(context.Attr<int>("low"),
                                          context.Attr<int>("high") - 1);
-    for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
+    auto engine = framework::GetCPURandomEngine(seed);
+
+    for (int64_t i = 0; i < size; ++i) {
+      data[i] = dist(*engine);
+    }
 
     if (platform::is_gpu_place(context.GetPlace())) {
       // Copy tensor to out
diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cu b/paddle/fluid/operators/truncated_gaussian_random_op.cu
index 5a3510babe4d57b9e80f0e7898df98033834ca15..ef1e40b46d0be14a2a98fbfd0167a4eaa816d577 100644
--- a/paddle/fluid/operators/truncated_gaussian_random_op.cu
+++ b/paddle/fluid/operators/truncated_gaussian_random_op.cu
@@ -15,6 +15,7 @@ limitations under the License. */
 #include
 #include
 #include
+#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 
@@ -46,6 +47,37 @@ struct TruncatedNormal {
   }
 };
 
+template <typename T>
+struct TruncatedNormalOffset {
+  T mean, std;
+  T a_normal_cdf;
+  T b_normal_cdf;
+  unsigned int seed;
+  T numeric_min;
+  int offset_;
+
+  __host__ __device__ TruncatedNormalOffset(T mean, T std, T numeric_min,
+                                            int seed, int offset)
+      : mean(mean),
+        std(std),
+        seed(seed),
+        numeric_min(numeric_min),
+        offset_(offset) {
+    a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
+    b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
+  }
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed);
+    thrust::uniform_real_distribution<T> dist(numeric_min, 1);
+    rng.discard(n + offset_);
+    T value = dist(rng);
+    auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
+    return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
+  }
+};
+
 template <typename T>
 class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
  public:
@@ -54,14 +86,35 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
     T* data = tensor->mutable_data<T>(context.GetPlace());
 
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
+    bool seed_flag = false;
     if (seed == 0) {
       std::random_device rd;
       seed = rd();
+      seed_flag = true;
     }
     T mean = static_cast<T>(context.Attr<float>("mean"));
     T std = static_cast<T>(context.Attr<float>("std"));
     thrust::counting_iterator<unsigned int> index_sequence_begin(0);
     int64_t size = tensor->numel();
+
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+
+    if (gen_cuda->GetIsInitPy() && seed_flag) {
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      int offset_step = 100;
+      // NOTE(xuefeng): Currently, we let offset step fixed to avoid
+      // unexpected results which may cause ut fail.
+      // we will fix this in future.
+      int gen_offset = offset_step * seed_offset.second;
+      thrust::transform(
+          index_sequence_begin, index_sequence_begin + size,
+          thrust::device_ptr<T>(data),
+          TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
+                                   seed_offset.first, gen_offset));
+      return;
+    }
+
     thrust::transform(
         index_sequence_begin, index_sequence_begin + size,
         thrust::device_ptr<T>(data),
diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu
index 4df1e0ffeb97564803f452114d52ab03d0464f8a..43a25a098b0c1d4fc9f6724d22bf8a98ee3fd745 100644
--- a/paddle/fluid/operators/uniform_random_op.cu
+++ b/paddle/fluid/operators/uniform_random_op.cu
@@ -51,6 +51,39 @@ struct UniformGenerator {
   }
 };
 
+template <typename T>
+struct UniformGeneratorOffset {
+  T min_, max_;
+  unsigned int seed_;
+  T diag_val_;
+  unsigned int diag_num_;
+  unsigned int diag_step_;
+  int offset_;
+  __host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
+                                             int diag_num, int diag_step,
+                                             T diag_val, int offset)
+      : min_(min),
+        max_(max),
+        seed_(seed),
+        diag_num_(diag_num),
+        diag_step_(diag_step),
+        diag_val_(diag_val),
+        offset_(offset) {}
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<T> dist(min_, max_);
+    rng.discard(n + offset_);
+    T out = dist(rng);
+    unsigned int remainder = n % (diag_step_ + 1);
+    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
+      out = diag_val_;
+    }
+    return out;
+  }
+};
+
 // It seems that Eigen::Tensor::random in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
@@ -89,10 +122,11 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
     }
     T* data = tensor->mutable_data<T>(context.GetPlace());
     unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
-
+    bool seed_flag = false;
     if (seed == 0) {
       std::random_device rd;
       seed = rd();
+      seed_flag = true;
     }
 
     T min = static_cast<T>(context.Attr<float>("min"));
@@ -104,10 +138,27 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
     T diag_val = static_cast<T>(context.Attr<float>("diag_val"));
     thrust::counting_iterator<unsigned int> index_sequence_begin(0);
     int64_t size = tensor->numel();
-    thrust::transform(
-        index_sequence_begin, index_sequence_begin + size,
-        thrust::device_ptr<T>(data),
-        UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+    if (gen_cuda->GetIsInitPy() && seed_flag) {
+      auto seed_offset = gen_cuda->IncrementOffset(1);
+      int offset_step = 100;
+      // NOTE(xuefeng): Currently, we let offset step fixed to avoid
+      // unexpected results which may cause ut fail.
+      // we will fix this in future.
+      int gen_offset = offset_step * seed_offset.second;
+      thrust::transform(
+          index_sequence_begin, index_sequence_begin + size,
+          thrust::device_ptr<T>(data),
+          UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
+                                    diag_step, diag_val, gen_offset));
+    } else {
+      thrust::transform(
+          index_sequence_begin, index_sequence_begin + size,
+          thrust::device_ptr<T>(data),
+          UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
+    }
   }
 };
diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc
index 90b7f501052530a306ba22ea6a244f0ef8fad563..67121e24089f7c6c5b8de985da89039eca85f094 100644
--- a/paddle/fluid/pybind/generator_py.cc
+++ b/paddle/fluid/pybind/generator_py.cc
@@ -59,6 +59,7 @@ void BindGenerator(py::module* m_ptr) {
       .def_property("_is_init_py", &framework::Generator::GetIsInitPy,
                     &framework::Generator::SetIsInitPy);
   m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
-}  // end Generator
-}  // end namespace pybind
+  m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator);
+}
+}  // namespace pybind
 }  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 46097316257307bd2211c9807594979682c8dfb6..d5793eb424ab794e3e8af8ef2312aac927c272e5 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -217,6 +217,8 @@ from .tensor.search import index_select  #DEFINE_ALIAS
 from .tensor.search import nonzero  #DEFINE_ALIAS
 from .tensor.search import sort  #DEFINE_ALIAS
 from .framework.random import manual_seed  #DEFINE_ALIAS
+from .framework.random import get_cuda_rng_state  #DEFINE_ALIAS
+from .framework.random import set_cuda_rng_state  #DEFINE_ALIAS
 from .framework import Variable  #DEFINE_ALIAS
 from .framework import ParamAttr  #DEFINE_ALIAS
 from .framework import create_global_var  #DEFINE_ALIAS
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py b/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c2520038a82a0b9427b2cbe1d4010a1bc8e040c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test cuda random seed."""
+
+from __future__ import print_function
+import os
+import unittest
+import paddle.fluid.generator as generator
+
+import time  # temp for debug
+import paddle.fluid as fluid
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+
+
+class TestGeneratorSeed(unittest.TestCase):
+    """
+    Test cases for cuda generator seed.
+ """ + + def test_gen_dropout_dygraph(self): + gen = paddle.manual_seed(12343) + + fluid.enable_dygraph() + + gen.manual_seed(111111111) + st = paddle.get_cuda_rng_state() + + x = fluid.layers.uniform_random( + [2, 10], dtype="float32", min=0.0, max=1.0) + x_again = fluid.layers.uniform_random( + [2, 10], dtype="float32", min=0.0, max=1.0) + x_third = fluid.layers.uniform_random( + [2, 10], dtype="float32", min=0.0, max=1.0) + print("x: {}".format(x.numpy())) + print("x_again: {}".format(x_again.numpy())) + x = x + x_again + x_third + y = fluid.layers.dropout(x, 0.5) + + paddle.set_cuda_rng_state(st) + + x1 = fluid.layers.uniform_random( + [2, 10], dtype="float32", min=0.0, max=1.0) + x1_again = fluid.layers.uniform_random( + [2, 10], dtype="float32", min=0.0, max=1.0) + x1_third = fluid.layers.uniform_random( + [2, 10], dtype="float32", min=0.0, max=1.0) + x1 = x1 + x1_again + x1_third + y1 = fluid.layers.dropout(x1, 0.5) + y_np = y.numpy() + y1_np = y1.numpy() + + if core.is_compiled_with_cuda(): + print(">>>>>>> dropout dygraph >>>>>>>") + self.assertTrue(np.allclose(y_np, y1_np)) + + def test_generator_gaussian_random_dygraph(self): + """Test Generator seed.""" + fluid.enable_dygraph() + + paddle.manual_seed(12312321111) + x = fluid.layers.gaussian_random([120], dtype="float32") + st1 = paddle.get_cuda_rng_state() + x1 = fluid.layers.gaussian_random([120], dtype="float32") + paddle.set_cuda_rng_state(st1) + x2 = fluid.layers.gaussian_random([120], dtype="float32") + paddle.manual_seed(12312321111) + x3 = fluid.layers.gaussian_random([120], dtype="float32") + x_np = x.numpy() + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + + if core.is_compiled_with_cuda(): + print(">>>>>>> gaussian random dygraph >>>>>>>") + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + def test_generator_randint_dygraph(self): + """Test Generator seed.""" + + fluid.enable_dygraph() + + gen = paddle.manual_seed(12312321111) + x = paddle.randint(low=10, shape=[10], dtype="int32") + st1 = gen.get_state() + x1 = paddle.randint(low=10, shape=[10], dtype="int32") + gen.set_state(st1) + x2 = paddle.randint(low=10, shape=[10], dtype="int32") + paddle.manual_seed(12312321111) + x3 = paddle.randint(low=10, shape=[10], dtype="int32") + x_np = x.numpy() + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + + if core.is_compiled_with_cuda(): + print(">>>>>>> randint dygraph >>>>>>>") + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + def test_gen_TruncatedNormal_initializer(self): + fluid.disable_dygraph() + + gen = paddle.manual_seed(123123143) + cur_state = paddle.get_cuda_rng_state() + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. 
+            x = fluid.layers.uniform_random(shape=[2, 10])
+            result_1 = fluid.layers.fc(
+                input=x,
+                size=10,
+                param_attr=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=2.0))
+            result_2 = fluid.layers.fc(
+                input=x,
+                size=10,
+                param_attr=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=2.0))
+
+        exe = fluid.Executor(fluid.CPUPlace())
+        exe.run(startup_program)
+        out1 = exe.run(train_program,
+                       feed={},
+                       fetch_list=[result_1, result_2])
+
+        paddle.manual_seed(123123143)
+        with fluid.program_guard(train_program, startup_program):
+            exe.run(startup_program)
+            out2 = exe.run(train_program,
+                           feed={},
+                           fetch_list=[result_1, result_2])
+
+        out1_res1 = np.array(out1[0])
+        out1_res2 = np.array(out1[1])
+        out2_res1 = np.array(out2[0])
+        out2_res2 = np.array(out2[1])
+
+        if core.is_compiled_with_cuda():
+            print(">>>>>>> truncated normal static >>>>>>>")
+            self.assertTrue(np.allclose(out1_res1, out2_res1))
+            self.assertTrue(np.allclose(out1_res2, out2_res2))
+            self.assertTrue(not np.allclose(out1_res2, out1_res1))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/framework/random.py b/python/paddle/framework/random.py
index 2555d24464112ed8446d863dc8e65cfa37680b36..ba2cf603d4a69f118320e40f1f953cb4c5fcfb39 100644
--- a/python/paddle/framework/random.py
+++ b/python/paddle/framework/random.py
@@ -16,7 +16,7 @@ import paddle.fluid as fluid
 from paddle.fluid import core
 
-__all__ = ['manual_seed']
+__all__ = ['manual_seed', 'get_cuda_rng_state', 'set_cuda_rng_state']
 
 
 def manual_seed(seed):
@@ -42,10 +42,69 @@ def manual_seed(seed):
 
     seed = int(seed)
 
+    if core.is_compiled_with_cuda():
+        for i in range(core.get_cuda_device_count()):
+            core.default_cuda_generator(i)._is_init_py = True
+            core.default_cuda_generator(i).manual_seed(seed)
+
     core.default_cpu_generator()._is_init_py = True
     return core.default_cpu_generator().manual_seed(seed)
 
 
+def get_cuda_rng_state():
+    """
+
+    Get random state of cuda generators.
+
+    Args:
+        None
+
+    Returns:
+        list: a list of GeneratorState objects, one per cuda device.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            sts = paddle.get_cuda_rng_state()
+
+    """
+    state_list = []
+    if core.is_compiled_with_cuda():
+        for i in range(core.get_cuda_device_count()):
+            state_list.append(core.default_cuda_generator(i).get_state())
+
+    return state_list
+
+
+def set_cuda_rng_state(state_list):
+    """
+
+    Sets generator state for all cuda generators.
+
+    Args:
+        state_list(list): The cuda states to set back to cuda generators. state_list is obtained from get_cuda_rng_state().
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            sts = paddle.get_cuda_rng_state()
+            paddle.set_cuda_rng_state(sts)
+
+    """
+    if core.is_compiled_with_cuda():
+        if not len(state_list) == core.get_cuda_device_count():
+            raise ValueError(
+                "Length of cuda state list should be equal to the cuda device count"
+            )
+        for i in range(core.get_cuda_device_count()):
+            core.default_cuda_generator(i).set_state(state_list[i])
+
+
 def _manual_program_seed(seed):
     """
     Sets global seed for generating random numbers.
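Putting the pieces together, the new Python entry points added in this patch (`paddle.manual_seed`, `paddle.get_cuda_rng_state`, `paddle.set_cuda_rng_state`) are meant to be used as in the unit test above. A minimal usage sketch, assuming a CUDA build with this patch applied:

```python
import numpy as np
import paddle
import paddle.fluid as fluid

fluid.enable_dygraph()

# Seed every per-device CUDA generator (and the CPU generator) in one call.
paddle.manual_seed(111111111)

# Snapshot the per-device CUDA RNG states ...
states = paddle.get_cuda_rng_state()
x = fluid.layers.gaussian_random([120], dtype="float32")

# ... and restore them to replay exactly the same random numbers.
paddle.set_cuda_rng_state(states)
x_replay = fluid.layers.gaussian_random([120], dtype="float32")

assert np.allclose(x.numpy(), x_replay.numpy())
```

Because restoring the state only rewinds the (seed, offset) pair, this works for any of the GPU kernels wired up above (dropout, gaussian, truncated gaussian, uniform) as long as the generator was initialized from Python via `manual_seed`.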