diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc index 2bd8ed900f10298deca35891625799f48ee9e4e2..b621eca35b893e95b825b3a5ae228ac125c07a72 100644 --- a/paddle/fluid/framework/generator.cc +++ b/paddle/fluid/framework/generator.cc @@ -24,7 +24,7 @@ limitations under the License. */ namespace paddle { namespace framework { -const std::shared_ptr& GetDefaultCUDAGenerator(int64_t device_id) { +const std::shared_ptr& DefaultCUDAGenerator(int64_t device_id) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) static int64_t num_cuda_devices = -1; @@ -58,8 +58,6 @@ const std::shared_ptr& GetDefaultCUDAGenerator(int64_t device_id) { const std::shared_ptr& DefaultCPUGenerator() { static auto default_cpu_generator = std::make_shared(GetRandomSeed()); - VLOG(4) << "initial seed: " << default_cpu_generator->GetCurrentSeed() - << ", cpu engine: " << default_cpu_generator->GetCPUEngine().get(); return default_cpu_generator; } @@ -100,19 +98,13 @@ const std::shared_ptr& GetRandomSeedGenerator( return iter->second; } -std::shared_ptr OpDefaultCPUEngine() { - static auto op_default_cpu_engine = std::make_shared(); - return op_default_cpu_engine; -} - -// NOTE(zhiqiu): there are 3 conditions: -// (1) op seed is not set and DefaultCPUGenerator is inited, use -// DefaultCPUGenerator -// (2) op seed is not set and DefaultCPUGenerator is not inited, use se -// OpDefaultCPUEngine() and set a radnom seed -// (3) op seed is set, use OpDefaultCPUEngine() and set the seed +// There are 3 conditions: +// (1) op seed is set, use op seed. +// (2) op seed is not set, global seed is set, use global seed. +// (3) op seed is not set, global seed is not set too, use random seed from +// RandomGenerator. std::shared_ptr GetCPURandomEngine(uint64_t seed) { - if (DefaultCPUGenerator()->GetIsInitPy() && seed == 0) { + if (seed == 0) { VLOG(4) << "Use random engine from generator"; return DefaultCPUGenerator()->GetCPUEngine(); } else { @@ -123,12 +115,6 @@ std::shared_ptr GetCPURandomEngine(uint64_t seed) { // // And we need to measure the determinacy of Generator in PE. auto engine = std::make_shared(); - if (seed == 0) { - seed = GetRandomSeed(); - VLOG(4) << "Use default random engine with random seed = " << seed; - } else { - VLOG(4) << "Use default random engine with fixed random seed = " << seed; - } static std::mutex mu_; { std::lock_guard lock(mu_); @@ -204,11 +190,5 @@ std::pair Generator::IncrementOffset( #endif } -void Generator::SetIsInitPy(bool is_init_py) { - this->is_init_py_ = is_init_py; - VLOG(4) << "SetIsInitPy:" << this->is_init_py_; -} -bool Generator::GetIsInitPy() const { return this->is_init_py_; } - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h index 1c19234bf7d8061257c4e2c7aeab04080d6a2637..35efc1bee33d59b1b96d4d1fb895069326c9f124 100644 --- a/paddle/fluid/framework/generator.h +++ b/paddle/fluid/framework/generator.h @@ -59,7 +59,6 @@ struct Generator : public phi::Generator { this->engine_ = engine; VLOG(4) << "initial seed: " << this->state_.current_seed << ", cpu engine: " << &this->state_.cpu_engine; - this->is_init_py_ = true; // TODO(zhiqiu): remove it in future } Generator(uint64_t seed, uint64_t device_id) { std::seed_seq seq({seed}); @@ -71,7 +70,6 @@ struct Generator : public phi::Generator { this->engine_ = engine; VLOG(4) << "initial seed: " << this->state_.current_seed << ", cpu engine: " << &this->state_.cpu_engine; - this->is_init_py_ = false; // TODO(zhiqiu): remove it in future } Generator(const Generator& other) = delete; @@ -95,32 +93,21 @@ struct Generator : public phi::Generator { std::pair IncrementOffset(uint64_t increament_offset); - void SetIsInitPy(bool); - bool GetIsInitPy() const; uint64_t get_device_id() { return this->state_.device; } private: phi::Generator::GeneratorState state_; std::shared_ptr engine_; mutable std::mutex mu_; - - // NOTE(zhiqiu): is_init_py_ is used to make generator be compatible with - // old seed, and it should be removed after all random-related operators - // and unittests upgrades to use generator. - bool is_init_py_ = false; }; // The DefaultCPUGenerator is used in manual_seed() const std::shared_ptr& DefaultCPUGenerator(); -// If op seed is set or global is not set, the OpDefaultCPUEngine is used. -std::shared_ptr OpDefaultCPUEngine(); +const std::shared_ptr& DefaultCUDAGenerator(int64_t device_id = -1); std::shared_ptr GetCPURandomEngine(uint64_t); -const std::shared_ptr& GetDefaultCUDAGenerator( - int64_t device_id = -1); - const std::shared_ptr& SetRandomSeedGenerator( const std::string& name, uint64_t seed); diff --git a/paddle/fluid/operators/class_center_sample_op.cu b/paddle/fluid/operators/class_center_sample_op.cu index 1064c77cc00410b484140f25a6e325732f36842b..a23cf2815d8fe84ba30d6957c9e31aeb2a3e9040 100644 --- a/paddle/fluid/operators/class_center_sample_op.cu +++ b/paddle/fluid/operators/class_center_sample_op.cu @@ -416,14 +416,13 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel { 1) * vec_size; int device_id = ctx.GetPlace().GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - if (gen_cuda->GetIsInitPy() && (!fix_seed)) { + auto gen_cuda = framework::DefaultCUDAGenerator(device_id); + if (!fix_seed) { auto seed_offset = gen_cuda->IncrementOffset(offset); seed_data = seed_offset.first; increment = seed_offset.second; } else { - std::random_device rnd; - seed_data = fix_seed ? seed + rank : rnd(); + seed_data = seed + rank; increment = offset; } RandomSampleClassCenter<< { int seed = ctx.Attr("seed"); if (!is_test) { - int device_id = ctx.GetPlace().GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - if (gen_cuda->GetIsInitPy() && seed == 0) { - // If perform `manual_seed` in python and inner seed is not specified - // (equals 0), use global generator generated seed. + if (seed == 0) { + // If not specify seed, use global Generator to generate seed. + int device_id = ctx.GetPlace().GetDeviceId(); + auto gen_cuda = paddle::framework::DefaultCUDAGenerator(device_id); seed = static_cast(gen_cuda->Random64()); - } else if (seed == 0) { - // use random generated seed - std::random_device rd; - seed = rd(); - } // else use `ctx.Attr("seed")` specified seed + } + // else use `ctx.Attr("seed")` specified seed } bool has_seq_length = ctx.HasInput("SequenceLength"); diff --git a/paddle/fluid/operators/dirichlet_op.cu b/paddle/fluid/operators/dirichlet_op.cu index 63f9c7339bfc5bd46ae0b12a05395c010ac81244..ac6480a8fa1c629339e15851e3aa2a7691c8d6f9 100644 --- a/paddle/fluid/operators/dirichlet_op.cu +++ b/paddle/fluid/operators/dirichlet_op.cu @@ -77,7 +77,7 @@ struct DirichletSampler { // init state, seed & offset for all threads int device_id = ctx.GetPlace().GetDeviceId(); - auto p_gen = framework::GetDefaultCUDAGenerator(device_id); + auto p_gen = framework::DefaultCUDAGenerator(device_id); auto seed_and_offset = p_gen->IncrementOffset(10); // hard-coded offset auto seed = seed_and_offset.first; auto offset = seed_and_offset.second; diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/fluid/operators/dropout_impl_util.h index c62d45570ba291dc60120c393d21842cc6548c61..571a1c97c52e8c600963a09a16a47d0c23b5bd97 100644 --- a/paddle/fluid/operators/dropout_impl_util.h +++ b/paddle/fluid/operators/dropout_impl_util.h @@ -26,7 +26,7 @@ inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx, const int offset, uint64_t* seed_data, uint64_t* increment) { int device_id = dev_ctx.GetPlace().GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + auto gen_cuda = framework::DefaultCUDAGenerator(device_id); if (seed) { framework::Tensor seed_cpu_tensor; @@ -34,13 +34,12 @@ inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx, &seed_cpu_tensor); *seed_data = static_cast(seed_cpu_tensor.data()[0]); *increment = offset; - } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) { + } else if (!is_fix_seed) { auto seed_offset = gen_cuda->IncrementOffset(offset); *seed_data = seed_offset.first; *increment = seed_offset.second; } else { - std::random_device rnd; - *seed_data = is_fix_seed ? seed_val : rnd(); + *seed_data = seed_val; *increment = offset; } } diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 552649279e9118372faa56b931fe8196c31c03d3..deac932d59b800594cf6915bd67cb0aef2ffaca1 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -54,26 +54,21 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel { auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.Attr("seed")); - bool seed_flag = false; - if (seed == 0) { - std::random_device rd; - seed = rd(); - seed_flag = true; - } T mean = static_cast(context.Attr("mean")); T std = static_cast(context.Attr("std")); int64_t size = tensor->numel(); int device_id = context.GetPlace().GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + auto gen_cuda = framework::DefaultCUDAGenerator(device_id); auto& dev_cxt = context.template device_context(); - if (gen_cuda->GetIsInitPy() && seed_flag) { + if (seed == 0) { + // use global Generator seed auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - auto func = GaussianGenerator(mean, std, seed_offset.first, - seed_offset.second); + uint64_t seed = seed_offset.first; + uint64_t offset = seed_offset.second; + auto func = GaussianGenerator(mean, std, seed, size * offset); phi::IndexKernel>(dev_cxt, tensor, func); } else { auto func = GaussianGenerator(mean, std, seed); diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h index ae846f4cae6fba7314b2c046e01bfc69220349af..3e27402c86947974a247f89a23c2ea43a3eefd61 100644 --- a/paddle/fluid/operators/uniform_random_op.h +++ b/paddle/fluid/operators/uniform_random_op.h @@ -151,12 +151,6 @@ void UniformRandom(const framework::ExecutionContext& context, T* data = tensor->mutable_data(dev_cxt.GetPlace()); if (size <= 0) return; unsigned int seed = static_cast(context.Attr("seed")); - bool seed_flag = false; - if (seed == 0) { - std::random_device rd; - seed = rd(); - seed_flag = true; - } T min = static_cast(context.Attr("min")); T max = static_cast(context.Attr("max")); @@ -165,14 +159,15 @@ void UniformRandom(const framework::ExecutionContext& context, unsigned int diag_step = static_cast(context.Attr("diag_step")); T diag_val = static_cast(context.Attr("diag_val")); - int device_id = context.GetPlace().GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - if (gen_cuda->GetIsInitPy() && seed_flag) { + + if (seed == 0) { + // Use global Generator seed using MT = typename details::MPTypeTrait::Type; phi::funcs::uniform_distribution dist; phi::funcs::uniform_real_transform trans(min, max); phi::funcs::distribution_and_transform(dev_cxt, tensor, dist, trans); } else { + // Use OP seed auto func = UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); phi::IndexKernel>(dev_cxt, tensor, func); diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 904e4854ba6b45f55f7367490270b366b56caf62..0bf5ca7f8f52572943c711eb6b960aedb7b13acd 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -169,7 +169,7 @@ inline void EmplaceDeviceContext( cuda_ctx->PartialInitWithAllocator(); dev_ctx->SetGenerator( - framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get()); + framework::DefaultCUDAGenerator(p.GetDeviceId()).get()); #endif } else { dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance() diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc index 53379373d25186f0ce730171fd173e468acb15a8..6bb85da8c466fdc657a295d1c5cd66b7b0739812 100644 --- a/paddle/fluid/pybind/generator_py.cc +++ b/paddle/fluid/pybind/generator_py.cc @@ -55,13 +55,9 @@ void BindGenerator(py::module* m_ptr) { }) .def("seed", &framework::Generator::Seed) .def("initial_seed", &framework::Generator::GetCurrentSeed) - .def("random", &framework::Generator::Random64) - // .def("get_cpu_engine", &framework::Generator::GetCPUEngine) - // .def("set_cpu_engine", &framework::Generator::SetCPUEngine) - .def_property("_is_init_py", &framework::Generator::GetIsInitPy, - &framework::Generator::SetIsInitPy); + .def("random", &framework::Generator::Random64); m.def("default_cpu_generator", &framework::DefaultCPUGenerator); - m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator); + m.def("default_cuda_generator", &framework::DefaultCUDAGenerator); m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator); m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator); } diff --git a/paddle/phi/core/generator.h b/paddle/phi/core/generator.h index 29ea92cbe6d94234b34af492fb6d828305c1d8f3..3263b2a52573271357d2e5f8751a654e06629a87 100644 --- a/paddle/phi/core/generator.h +++ b/paddle/phi/core/generator.h @@ -49,12 +49,6 @@ class Generator { virtual std::pair IncrementOffset( uint64_t increament_offset) = 0; - // NOTE(zhiqiu): is_init_py_ is used to make generator be compatible with - // old seed, and it should be removed after all random-related operators - // and unittests upgrades to use generator. - virtual void SetIsInitPy(bool) = 0; - virtual bool GetIsInitPy() const = 0; - virtual uint64_t get_device_id() = 0; }; diff --git a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu index 96ebc0353ef2453308d6b9e371b6b640e8ab7b28..b80634357d62f0bbefd4744acd1dcc755e2231d4 100644 --- a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu +++ b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu @@ -59,34 +59,20 @@ void GaussianRandomKernel(const Context& dev_ctx, int seed, DataType dtype, DenseTensor* out) { - auto tensor = out; - - bool seed_flag = false; + out->Resize(phi::make_ddim(shape.GetData())); + dev_ctx.template Alloc(out); if (seed == 0) { - std::random_device rd; - seed = rd(); - seed_flag = true; - } - - tensor->Resize(phi::make_ddim(shape.GetData())); - - T* data = dev_ctx.template Alloc(tensor); - - int64_t size = tensor->numel(); - - int device_id = dev_ctx.GetPlace().GetDeviceId(); - auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id); - - if (gen_cuda->GetIsInitPy() && seed_flag) { + // use global Generator seed using MT = typename phi::dtype::MPTypeTrait::Type; funcs::normal_distribution dist; funcs::normal_transform trans(static_cast(mean), static_cast(std)); - funcs::distribution_and_transform(dev_ctx, tensor, dist, trans); + funcs::distribution_and_transform(dev_ctx, out, dist, trans); } else { + // use OP seed auto func = GaussianGenerator(static_cast(mean), static_cast(std), seed); - IndexKernel>(dev_ctx, tensor, func); + IndexKernel>(dev_ctx, out, func); } } diff --git a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu index 6b1e58981baa0a4768057b5a1c072d4182dfc1fd..c0e557f09bcc9bceb0d9f46383428454f8441b38 100644 --- a/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu +++ b/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu @@ -27,12 +27,9 @@ namespace cub = hipcub; #endif -#include -#include -#include -#include #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -144,27 +141,21 @@ struct GumbleNoiseGenerator { DenseTensor random_tensor; int64_t size = size_to_axis * size_from_axis; random_tensor.Resize(make_ddim({size})); - auto* random_data = ctx.template Alloc(&random_tensor); - thrust::counting_iterator index_sequence_begin(0); + T* random_data = ctx.template Alloc(&random_tensor); // generate gumbel noise int device_id = ctx.GetPlace().GetDeviceId(); - auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id); - if (gen_cuda->GetIsInitPy()) { - auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - thrust::transform( - index_sequence_begin, - index_sequence_begin + size, - thrust::device_ptr(random_data), - UniformCUDAGenerator(0.00001, 1, seed_offset.first, gen_offset)); - } else { - const unsigned int seed = std::random_device()(); - thrust::transform(index_sequence_begin, - index_sequence_begin + size, - thrust::device_ptr(random_data), - UniformCUDAGenerator(0.00001, 1, seed)); - } + auto gen_cuda = ctx.GetGenerator(); + + auto seed_offset = gen_cuda->IncrementOffset(1); + uint64_t seed = seed_offset.first; + uint64_t offset = seed_offset.second; + + thrust::counting_iterator index_sequence_begin(0); + thrust::transform(index_sequence_begin, + index_sequence_begin + size, + thrust::device_ptr(random_data), + UniformCUDAGenerator(0.00001, 1, seed, size * offset)); // add gumbel noise to X const int thread_size = 512; diff --git a/paddle/phi/kernels/gpu/rnn_kernel.cu.cc b/paddle/phi/kernels/gpu/rnn_kernel.cu.cc index d30b7ec34d43ccf0fe945362cda67f5ed7fa025d..f2ffe3c9d4fba50fed02249d44fe8b94dba49ff1 100644 --- a/paddle/phi/kernels/gpu/rnn_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/rnn_kernel.cu.cc @@ -175,17 +175,13 @@ void RnnKernel(const Context &dev_ctx, mode)); if (!is_test) { - int device_id = dev_ctx.GetPlace().GetDeviceId(); - auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id); - if (gen_cuda->GetIsInitPy() && seed == 0) { - // If perform `manual_seed` in python and inner seed is not specified - // (equals 0), use global generator generated seed. + if (seed == 0) { + // If not specify seed, use global Generator to generate seed. + int device_id = dev_ctx.GetPlace().GetDeviceId(); + auto gen_cuda = paddle::framework::DefaultCUDAGenerator(device_id); seed = static_cast(gen_cuda->Random64()); - } else if (seed == 0) { - // use random generated seed - std::random_device rd; - seed = rd(); - } // else use `ctx.Attr("seed")` specified seed + } + // else use `ctx.Attr("seed")` specified seed } const T *x_data = x.data(); diff --git a/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu index 5b6ae9d09bff207fc56baf958fe15a5d4e9c52d2..33ecb4d6eb544c68f07cfec14701c6a38de987b0 100644 --- a/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu +++ b/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu @@ -90,34 +90,25 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx, int seed, DataType dtype, DenseTensor* out) { - auto tensor = out; - - T* data = dev_ctx.template Alloc(tensor); - - bool seed_flag = false; - if (seed == 0) { - std::random_device rd; - seed = rd(); - seed_flag = true; - } + T* data = dev_ctx.template Alloc(out); thrust::counting_iterator index_sequence_begin(0); - int64_t size = tensor->numel(); + int64_t size = out->numel(); auto gen_cuda = dev_ctx.GetGenerator(); - - if (gen_cuda->GetIsInitPy() && seed_flag) { + if (seed == 0) { + // use global Generator seed auto seed_offset = gen_cuda->IncrementOffset(1); - int64_t gen_offset = size * seed_offset.second; - thrust::transform(index_sequence_begin, - index_sequence_begin + size, - thrust::device_ptr(data), - TruncatedNormalOffset(mean, - std, - std::numeric_limits::min(), - seed_offset.first, - gen_offset)); + uint64_t seed = seed_offset.first; + uint64_t offset = seed_offset.second; + thrust::transform( + index_sequence_begin, + index_sequence_begin + size, + thrust::device_ptr(data), + TruncatedNormalOffset( + mean, std, std::numeric_limits::min(), seed, size * offset)); } else { + // use OP seed thrust::transform( index_sequence_begin, index_sequence_begin + size, diff --git a/paddle/phi/kernels/gpu/uniform_random_kernel.cu b/paddle/phi/kernels/gpu/uniform_random_kernel.cu index a4aea10cfe762f203f326d69888becbf1ee3094e..68e61b7328971633ec3d933d0ab56e2491c9bb5e 100644 --- a/paddle/phi/kernels/gpu/uniform_random_kernel.cu +++ b/paddle/phi/kernels/gpu/uniform_random_kernel.cu @@ -65,22 +65,15 @@ void UniformRandomRawKernel(const Context& dev_ctx, float diag_val, DenseTensor* out) { out->Resize(phi::make_ddim(shape.GetData())); - T* data = dev_ctx.template Alloc(out); - auto size = out->numel(); - bool seed_flag = false; + dev_ctx.template Alloc(out); if (seed == 0) { - std::random_device rd; - seed = rd(); - seed_flag = true; - } - - auto generator = dev_ctx.GetGenerator(); - if (generator->GetIsInitPy() && seed_flag) { + // Use global Generator seed using MT = typename kps::details::MPTypeTrait::Type; funcs::uniform_distribution dist; funcs::uniform_real_transform trans(min, max); funcs::distribution_and_transform(dev_ctx, out, dist, trans); } else { + // Use OP seed auto func = UniformGenerator(min, max, seed, diag_num, diag_step, diag_val); IndexKernel>(dev_ctx, out, func); diff --git a/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py b/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py index 6033b809f218d8b5ecac070cbef350d3e52606aa..14a91b0c2c5fe970da2f86202518c92cfa593ead 100644 --- a/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py +++ b/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py @@ -25,6 +25,8 @@ import paddle import paddle.fluid.core as core +@unittest.skipIf(not core.is_compiled_with_cuda(), + "Only test cuda Random Generator") class TestGeneratorSeed(unittest.TestCase): """ Test cases for cpu generator seed. @@ -70,15 +72,13 @@ class TestGeneratorSeed(unittest.TestCase): """Test Generator seed.""" fluid.enable_dygraph() - paddle.seed(12312321111) - x = fluid.layers.gaussian_random([120], dtype="float32") - st1 = paddle.get_cuda_rng_state() - x1 = fluid.layers.gaussian_random([120], dtype="float32") - paddle.set_cuda_rng_state(st1) - x2 = fluid.layers.gaussian_random([120], dtype="float32") - paddle.seed(12312321111) - x3 = fluid.layers.gaussian_random([120], dtype="float32") - x_np = x.numpy() + st = paddle.get_cuda_rng_state() + x1 = paddle.randn([120], dtype="float32") + paddle.set_cuda_rng_state(st) + x2 = paddle.randn([120], dtype="float32") + paddle.set_cuda_rng_state(st) + x3 = paddle.randn([120], dtype="float32") + x1_np = x1.numpy() x2_np = x2.numpy() x3_np = x3.numpy() @@ -86,7 +86,7 @@ class TestGeneratorSeed(unittest.TestCase): if core.is_compiled_with_cuda(): print(">>>>>>> gaussian random dygraph >>>>>>>") self.assertTrue(np.allclose(x1_np, x2_np)) - self.assertTrue(np.allclose(x_np, x3_np)) + self.assertTrue(np.allclose(x2_np, x3_np)) def test_generator_randint_dygraph(self): """Test Generator seed.""" diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py index dacb7a5b599579951f663dcf8f84d8354108132b..3621fd1b9d4457fc9159a4715d03f13c10e05117 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -629,7 +629,6 @@ class ModuleApiTest(unittest.TestCase): else: fluid.disable_dygraph() gen = paddle.seed(self._random_seed) - gen._is_init_py = False paddle.framework.random._manual_program_seed(self._random_seed) scope = fluid.core.Scope() with fluid.scope_guard(scope): diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_bf16_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_bf16_op.py index 2ba808a341e5eb02c7759f04cf44fff1e4365ece..5f4989f6c5dbdd30f60c14a370ee82933f378c96 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_bf16_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_bf16_op.py @@ -178,7 +178,6 @@ class TestUniformRandomOpAPISeed(unittest.TestCase): def test_attr_tensor_API(self): _seed = 10 gen = paddle.seed(_seed) - gen._is_init_py = False startup_program = fluid.Program() train_program = fluid.Program() with fluid.program_guard(train_program, startup_program): diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 0b27c616230898a7cad98859d6d316c17b7799d5..0bca3c08f3d78f4c04e348833ec8e88a4a7fd7ab 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -370,7 +370,6 @@ class TestUniformRandomOp_API_seed(unittest.TestCase): def test_attr_tensor_API(self): _seed = 10 gen = paddle.seed(_seed) - gen._is_init_py = False startup_program = fluid.Program() train_program = fluid.Program() with fluid.program_guard(train_program, startup_program): diff --git a/python/paddle/framework/random.py b/python/paddle/framework/random.py index 147f6be39c5e01f52809e7259bc7d57aa7fab0eb..b58d36b8e7d502f7980727dd2c955b058c85ffec 100644 --- a/python/paddle/framework/random.py +++ b/python/paddle/framework/random.py @@ -44,10 +44,8 @@ def seed(seed): if core.is_compiled_with_cuda(): for i in range(core.get_cuda_device_count()): - core.default_cuda_generator(i)._is_init_py = True core.default_cuda_generator(i).manual_seed(seed) - core.default_cpu_generator()._is_init_py = True return core.default_cpu_generator().manual_seed(seed)