From 23261ff44ba46a507e72e2da7c83f7fede3486f7 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Tue, 18 Aug 2020 12:29:11 +0800 Subject: [PATCH] add cpu random Generator (#26013) --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/generator.cc | 78 ++++++++++++ paddle/fluid/framework/generator.h | 96 ++++++++++++++ paddle/fluid/operators/CMakeLists.txt | 4 +- paddle/fluid/operators/gaussian_random_op.cc | 3 +- paddle/fluid/operators/randint_op.cc | 3 +- paddle/fluid/operators/uniform_random_op.cc | 41 ++++-- paddle/fluid/operators/uniform_random_op.cu | 12 +- paddle/fluid/operators/uniform_random_op.h | 1 + paddle/fluid/pybind/CMakeLists.txt | 5 +- paddle/fluid/pybind/generator_py.cc | 51 ++++++++ paddle/fluid/pybind/generator_py.h | 28 +++++ paddle/fluid/pybind/pybind.cc | 2 + python/paddle/fluid/__init__.py | 3 +- python/paddle/fluid/generator.py | 60 +++++++++ .../fluid/tests/unittests/test_generator.py | 44 +++++++ .../fluid/tests/unittests/test_random_seed.py | 119 ++++++++++++++++++ 17 files changed, 533 insertions(+), 18 deletions(-) create mode 100644 paddle/fluid/framework/generator.cc create mode 100644 paddle/fluid/framework/generator.h create mode 100644 paddle/fluid/pybind/generator_py.cc create mode 100644 paddle/fluid/pybind/generator_py.h create mode 100644 python/paddle/fluid/generator.py create mode 100644 python/paddle/fluid/tests/unittests/test_generator.py create mode 100644 python/paddle/fluid/tests/unittests/test_random_seed.py diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index d725bdffa01..8c49a5f6de7 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -268,6 +268,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) +cc_library(generator SRCS generator.cc) # Get the current working branch execute_process( diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc new file mode 100644 index 00000000000..d00e38784c2 --- /dev/null +++ b/paddle/fluid/framework/generator.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" + +namespace paddle { +namespace framework { + +std::shared_ptr Generator::gen_instance_ = NULL; + +GeneratorState* Generator::GetState() { + std::lock_guard lock(this->mutex); + return this->state_.get(); +} + +void Generator::SetState(GeneratorState* state_in) { + std::lock_guard lock(this->mutex); + *this->state_ = *state_in; +} + +uint64_t Generator::GetCurrentSeed() { + std::lock_guard lock(this->mutex); + return this->state_->current_seed; +} + +uint64_t Generator::Seed() { + std::lock_guard lock(this->mutex); + uint64_t seed; + std::random_device de; + seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF; + this->state_->current_seed = seed; + std::seed_seq seq({seed}); + this->state_->cpu_engine.seed(seq); + + return this->state_->current_seed; +} + +void Generator::SetCurrentSeed(uint64_t seed) { + std::lock_guard lock(this->mutex); + this->state_->current_seed = uint64_t(seed); + std::seed_seq seq({seed}); + this->state_->cpu_engine.seed(seq); +} + +std::mt19937_64& Generator::GetCPUEngine() { + std::lock_guard lock(this->mutex); + return this->state_->cpu_engine; +} + +void Generator::SetCPUEngine(std::mt19937_64 engine) { + std::lock_guard lock(this->mutex); + this->state_->cpu_engine = std::mt19937_64(engine); +} + +uint64_t Generator::Random64() { + std::lock_guard lock(this->mutex); + return this->state_->cpu_engine(); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h new file mode 100644 index 00000000000..17870782ba7 --- /dev/null +++ b/paddle/fluid/framework/generator.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include // temp for debug +#include +#include // NOLINT +#include +#include +#include + +namespace paddle { +namespace framework { + +struct GeneratorState { + int64_t device = -1; + uint64_t current_seed = 34342423252; + std::mt19937_64 cpu_engine; +}; + +struct Generator { + Generator() { + GeneratorState default_gen_state_cpu; + default_gen_state_cpu.device = -1; + default_gen_state_cpu.current_seed = 34342423252; + std::seed_seq seq({34342423252}); + default_gen_state_cpu.cpu_engine = std::mt19937_64(seq); + this->state_ = std::make_shared(default_gen_state_cpu); + } + explicit Generator(GeneratorState state_in) + : state_{std::make_shared(state_in)} {} + Generator(const Generator& other) + : Generator(other, std::lock_guard(other.mutex)) {} + + // get random state + GeneratorState* GetState(); + // set random state + void SetState(GeneratorState* state_in); + // get current seed + uint64_t GetCurrentSeed(); + // random a seed and get + uint64_t Seed(); + + // set seed + void SetCurrentSeed(uint64_t seed); + // get cpu engine + std::mt19937_64& GetCPUEngine(); + // set cpu engine + void SetCPUEngine(std::mt19937_64 engine); + + uint64_t Random64(); + + bool is_init_py = false; + + // CPU Generator singleton + static std::shared_ptr GetInstance() { + if (NULL == gen_instance_) { + gen_instance_.reset(new paddle::framework::Generator()); + } + return gen_instance_; + } + + static std::shared_ptr GetInstanceX() { + if (NULL == gen_instance_) { + gen_instance_.reset(new paddle::framework::Generator()); + } + gen_instance_->is_init_py = true; + return gen_instance_; + } + + private: + static std::shared_ptr gen_instance_; + std::shared_ptr state_; + mutable std::mutex mutex; + + Generator(const Generator& other, const std::lock_guard&) + : state_(std::make_shared(*(other.state_))) {} +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e74f363d886..48d1ec9461a 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -88,7 +88,9 @@ endif() cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEPS operator) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows +lod_tensor maxouting unpooling pooling lod_rank_table context_project +sequence_pooling executor device_memory_aligment generator) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 253078751ce..898c063afdd 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include + #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/fill_constant_op.h" #ifdef PADDLE_WITH_MKLDNN @@ -30,13 +31,13 @@ class CPUGaussianRandomKernel : public framework::OpKernel { float mean = context.Attr("mean"); float std = context.Attr("std"); auto* tensor = context.Output("Out"); - unsigned int seed = static_cast(context.Attr("seed")); std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); } engine.seed(seed); + std::normal_distribution dist(mean, std); const std::string op_type = "gaussian_random"; diff --git a/paddle/fluid/operators/randint_op.cc b/paddle/fluid/operators/randint_op.cc index 9f6df3f32b7..11ce738e001 100644 --- a/paddle/fluid/operators/randint_op.cc +++ b/paddle/fluid/operators/randint_op.cc @@ -14,6 +14,7 @@ #include #include + #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/uniform_random_op.h" @@ -37,11 +38,11 @@ class CPURandintKernel : public framework::OpKernel { new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor); } } - auto* out = ctx.Output("Out"); if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape)); T* data = out->mutable_data(ctx.GetPlace()); int64_t size = out->numel(); + unsigned int seed = static_cast(ctx.Attr("seed")); std::minstd_rand engine; if (seed == 0) { diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index e0c56307639..a4487cde277 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/uniform_random_op.h" #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" + namespace paddle { namespace operators { @@ -55,19 +57,40 @@ class CPUUniformRandomKernel : public framework::OpKernel { "supports SelectedRows and LoDTensor"); } T *data = tensor->mutable_data(ctx.GetPlace()); - unsigned int seed = static_cast(ctx.Attr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); + + int64_t size = tensor->numel(); std::uniform_real_distribution dist( static_cast(ctx.Attr("min")), static_cast(ctx.Attr("max"))); - int64_t size = tensor->numel(); - for (int64_t i = 0; i < size; ++i) { - data[i] = dist(engine); + auto gen_ptr = framework::Generator::GetInstance(); + if (gen_ptr->is_init_py) { + std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine(); + // auto gen_engine = gen_ptr_->GetCPUEngine(); + // std::uniform_real_distribution dist( + // static_cast(ctx.Attr("min")), + // static_cast(ctx.Attr("max"))); + + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(gen_engine); + } + } else { + unsigned int seed = static_cast(ctx.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + // std::uniform_real_distribution dist( + // static_cast(ctx.Attr("min")), + // static_cast(ctx.Attr("max"))); + // int64_t size = tensor->numel(); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } } + // std::mt19937_64 &engine = gen_ptr->GetCPUEngine(); + // auto engine = gen_ptr_->GetCPUEngine(); + unsigned int diag_num = static_cast(ctx.Attr("diag_num")); unsigned int diag_step = diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index 53c79cf672e..c024bb87b09 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/uniform_random_op.h" @@ -87,9 +88,14 @@ class GPUUniformRandomKernel : public framework::OpKernel { } T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.Attr("seed")); - if (seed == 0) { - std::random_device rd; - seed = rd(); + if (framework::Generator::GetInstance()->is_init_py) { + seed = static_cast( + framework::Generator::GetInstance()->GetCurrentSeed()); + } else { + if (seed == 0) { + std::random_device rd; + seed = rd(); + } } T min = static_cast(context.Attr("min")); T max = static_cast(context.Attr("max")); diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h index 867b1044164..d263dd03dd0 100644 --- a/paddle/fluid/operators/uniform_random_op.h +++ b/paddle/fluid/operators/uniform_random_op.h @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 318a45919af..ef19fcc5475 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context - gloo_wrapper infer_io_utils heter_wrapper) + gloo_wrapper infer_io_utils heter_wrapper generator) if (WITH_NCCL) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper) @@ -37,7 +37,8 @@ set(PYBIND_SRCS data_set_py.cc imperative.cc ir.cc - inference_api.cc) + inference_api.cc + generator_py.cc) if (WITH_CRYPTO) set(PYBIND_DEPS ${PYBIND_DEPS} paddle_crypto) diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc new file mode 100644 index 00000000000..3bccd5fb2dd --- /dev/null +++ b/paddle/fluid/pybind/generator_py.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include + +#ifdef _POSIX_C_SOURCE +#undef _POSIX_C_SOURCE +#endif + +#ifdef _XOPEN_SOURCE +#undef _XOPEN_SOURCE +#endif + +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/pybind/generator_py.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { +void BindGenerator(py::module* m) { + py::class_(*m, "GeneratorState", ""); + py::class_(*m, "mt19937_64", ""); + py::class_>( + *m, "Generator") + .def(py::init([]() { return framework::Generator::GetInstanceX(); }), + py::return_value_policy::reference) + .def("get_state", &framework::Generator::GetState, + py::return_value_policy::move) + .def("set_state", &framework::Generator::SetState) + .def("manual_seed", &framework::Generator::SetCurrentSeed) + .def("seed", &framework::Generator::Seed) + .def("initial_seed", &framework::Generator::GetCurrentSeed) + .def("random", &framework::Generator::Random64) + .def("get_cpu_engine", &framework::Generator::GetCPUEngine, + py::return_value_policy::move) + .def("set_cpu_engine", &framework::Generator::SetCPUEngine); +} // end Generator +} // end namespace pybind +} // end namespace paddle diff --git a/paddle/fluid/pybind/generator_py.h b/paddle/fluid/pybind/generator_py.h new file mode 100644 index 00000000000..d37654c1ba2 --- /dev/null +++ b/paddle/fluid/pybind/generator_py.h @@ -0,0 +1,28 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +void BindGenerator(py::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f426ca82966..635a81dff0d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -64,6 +64,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/fleet_wrapper_py.h" +#include "paddle/fluid/pybind/generator_py.h" #include "paddle/fluid/pybind/global_value_getter_setter.h" #include "paddle/fluid/pybind/gloo_wrapper_py.h" #include "paddle/fluid/pybind/heter_wrapper_py.h" @@ -2503,6 +2504,7 @@ All parameter, weight, gradient are variables in Paddle. BindNode(&m); BindInferenceApi(&m); BindDataset(&m); + BindGenerator(&m); #ifdef PADDLE_WITH_CRYPTO BindCrypto(&m); #endif diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 88dd815d937..7e0d8c0de5b 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -89,6 +89,7 @@ from .dygraph.base import enable_dygraph, disable_dygraph from .io import save, load, load_program_state, set_program_state from .dygraph.checkpoint import save_dygraph, load_dygraph from .dygraph.varbase_patch_methods import monkey_patch_varbase +from . import generator Tensor = LoDTensor enable_imperative = enable_dygraph disable_imperative = disable_dygraph @@ -96,7 +97,7 @@ disable_imperative = disable_dygraph __all__ = framework.__all__ + executor.__all__ + \ trainer_desc.__all__ + transpiler.__all__ + \ parallel_executor.__all__ + lod_tensor.__all__ + \ - data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [ + data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + generator.__all__ + [ 'io', 'initializer', 'embedding', diff --git a/python/paddle/fluid/generator.py b/python/paddle/fluid/generator.py new file mode 100644 index 00000000000..24262e3f566 --- /dev/null +++ b/python/paddle/fluid/generator.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This is definition of generator class, which is for managing the state of the algorithm that produces pseudo random numbers.""" + +from . import core + +__all__ = ['Generator'] + +default_rng_seed_val = 34342423252 + + +class Generator(object): + """Generator class""" + + def __init__(self, device="CPU"): + """init""" + self.device = device + seed_in = default_rng_seed_val + if self.device == "CPU": + self.generator = core.Generator() + self.generator.manual_seed(seed_in) + else: + raise ValueError( + "generator class with device %s does not exist, currently only support generator with device 'CPU' " + % device) + + def get_state(self): + return self.generator.get_state() + + def set_state(self, state): + self.generator.set_state(state) + + def manual_seed(self, seed): + self.generator.manual_seed(seed) + + def seed(self): + return self.generator.seed() + + def initial_seed(self): + return self.generator.initial_seed() + + def random(self): + return self.generator.random() + + def get_cpu_engine(self): + return self.generator.get_cpu_engine() + + def set_cpu_engine(self, cpu_engine): + self.generator.set_cpu_engine(cpu_engine) diff --git a/python/paddle/fluid/tests/unittests/test_generator.py b/python/paddle/fluid/tests/unittests/test_generator.py new file mode 100644 index 00000000000..6cc43d3d549 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_generator.py @@ -0,0 +1,44 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test cloud role maker.""" + +from __future__ import print_function +import os +import unittest +import paddle.fluid.generator as generator +import time # temp for debug + + +class TestGenerator(unittest.TestCase): + """ + Test cases for cpu generator. + """ + + def test_basic_generator(self): + """Test basic generator.""" + gen = generator.Generator() + gen.manual_seed(123123143) + s = gen.initial_seed() + s = gen.seed() + st = gen.get_state() + gen.set_state(st) + gen.random() + gen.set_cpu_engine(gen.get_cpu_engine()) + + def test_basic_generator_error(self): + self.assertRaises(ValueError, generator.Generator, device="CUDA") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_random_seed.py b/python/paddle/fluid/tests/unittests/test_random_seed.py new file mode 100644 index 00000000000..31120a73042 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_random_seed.py @@ -0,0 +1,119 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test cloud role maker.""" + +from __future__ import print_function +import os +import unittest +import paddle.fluid.generator as generator + +import time # temp for debug +import paddle.fluid as fluid +import numpy as np +import paddle +import paddle.fluid.core as core + + +class TestGeneratorSeed(unittest.TestCase): + """ + Test cases for cpu generator seed. + """ + + def test_generator_uniform_random_dygraph(self): + """Test Generator seed.""" + gen = generator.Generator() + + fluid.enable_dygraph() + + gen.manual_seed(12312321111) + x = fluid.layers.uniform_random([10], dtype="float32", min=0.0, max=1.0) + st1 = gen.get_state() + x1 = fluid.layers.uniform_random( + [10], dtype="float32", min=0.0, max=1.0) + gen.set_state(st1) + x2 = fluid.layers.uniform_random( + [10], dtype="float32", min=0.0, max=1.0) + gen.manual_seed(12312321111) + x3 = fluid.layers.uniform_random( + [10], dtype="float32", min=0.0, max=1.0) + x_np = x.numpy() + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + + if not core.is_compiled_with_cuda(): + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + def test_generator_uniform_random_static(self): + + fluid.disable_dygraph() + + gen = generator.Generator() + gen.manual_seed(123123143) + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. + result_1 = fluid.layers.uniform_random(shape=[3, 4]) + result_2 = fluid.layers.uniform_random(shape=[3, 4]) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + out1 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + #gen.set_state(cur_state) + gen.manual_seed(123123143) + out2 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + + out1_res1 = np.array(out1[0]) + out1_res2 = np.array(out1[1]) + out2_res1 = np.array(out2[0]) + out2_res2 = np.array(out2[1]) + + if not core.is_compiled_with_cuda(): + self.assertTrue(np.allclose(out1_res1, out2_res1)) + self.assertTrue(np.allclose(out1_res2, out2_res2)) + self.assertTrue(not np.allclose(out1_res2, out1_res1)) + + def test_generator_randint_dygraph(self): + """Test Generator seed.""" + gen = generator.Generator() + + fluid.enable_dygraph() + + gen.manual_seed(12312321111) + x = paddle.randint(low=1) + st1 = gen.get_state() + x1 = paddle.randint(low=1) + gen.set_state(st1) + x2 = paddle.randint(low=1) + gen.manual_seed(12312321111) + x3 = paddle.randint(low=1) + x_np = x.numpy() + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + if not core.is_compiled_with_cuda(): + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + +if __name__ == "__main__": + unittest.main() -- GitLab