未验证 提交 23261ff4 编写于 作者: Y yaoxuefeng 提交者: GitHub

add cpu random Generator (#26013)

上级 69742bd9
......@@ -268,6 +268,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib
cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc)
# Get the current working branch
execute_process(
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/generator.h"
namespace paddle {
namespace framework {
std::shared_ptr<Generator> Generator::gen_instance_ = NULL;
GeneratorState* Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_.get();
}
void Generator::SetState(GeneratorState* state_in) {
std::lock_guard<std::mutex> lock(this->mutex);
*this->state_ = *state_in;
}
uint64_t Generator::GetCurrentSeed() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->current_seed;
}
uint64_t Generator::Seed() {
std::lock_guard<std::mutex> lock(this->mutex);
uint64_t seed;
std::random_device de;
seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
this->state_->current_seed = seed;
std::seed_seq seq({seed});
this->state_->cpu_engine.seed(seq);
return this->state_->current_seed;
}
void Generator::SetCurrentSeed(uint64_t seed) {
std::lock_guard<std::mutex> lock(this->mutex);
this->state_->current_seed = uint64_t(seed);
std::seed_seq seq({seed});
this->state_->cpu_engine.seed(seq);
}
std::mt19937_64& Generator::GetCPUEngine() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->cpu_engine;
}
void Generator::SetCPUEngine(std::mt19937_64 engine) {
std::lock_guard<std::mutex> lock(this->mutex);
this->state_->cpu_engine = std::mt19937_64(engine);
}
uint64_t Generator::Random64() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->cpu_engine();
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <atomic>
#include <deque>
#include <iostream> // temp for debug
#include <memory>
#include <mutex> // NOLINT
#include <random>
#include <typeinfo>
#include <utility>
namespace paddle {
namespace framework {
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
std::mt19937_64 cpu_engine;
};
struct Generator {
Generator() {
GeneratorState default_gen_state_cpu;
default_gen_state_cpu.device = -1;
default_gen_state_cpu.current_seed = 34342423252;
std::seed_seq seq({34342423252});
default_gen_state_cpu.cpu_engine = std::mt19937_64(seq);
this->state_ = std::make_shared<GeneratorState>(default_gen_state_cpu);
}
explicit Generator(GeneratorState state_in)
: state_{std::make_shared<GeneratorState>(state_in)} {}
Generator(const Generator& other)
: Generator(other, std::lock_guard<std::mutex>(other.mutex)) {}
// get random state
GeneratorState* GetState();
// set random state
void SetState(GeneratorState* state_in);
// get current seed
uint64_t GetCurrentSeed();
// random a seed and get
uint64_t Seed();
// set seed
void SetCurrentSeed(uint64_t seed);
// get cpu engine
std::mt19937_64& GetCPUEngine();
// set cpu engine
void SetCPUEngine(std::mt19937_64 engine);
uint64_t Random64();
bool is_init_py = false;
// CPU Generator singleton
static std::shared_ptr<Generator> GetInstance() {
if (NULL == gen_instance_) {
gen_instance_.reset(new paddle::framework::Generator());
}
return gen_instance_;
}
static std::shared_ptr<Generator> GetInstanceX() {
if (NULL == gen_instance_) {
gen_instance_.reset(new paddle::framework::Generator());
}
gen_instance_->is_init_py = true;
return gen_instance_;
}
private:
static std::shared_ptr<Generator> gen_instance_;
std::shared_ptr<GeneratorState> state_;
mutable std::mutex mutex;
Generator(const Generator& other, const std::lock_guard<std::mutex>&)
: state_(std::make_shared<GeneratorState>(*(other.state_))) {}
};
} // namespace framework
} // namespace paddle
......@@ -88,7 +88,9 @@ endif()
cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEPS operator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows
lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling executor device_memory_aligment generator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#ifdef PADDLE_WITH_MKLDNN
......@@ -30,13 +31,13 @@ class CPUGaussianRandomKernel : public framework::OpKernel<T> {
float mean = context.Attr<float>("mean");
float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::normal_distribution<T> dist(mean, std);
const std::string op_type = "gaussian_random";
......
......@@ -14,6 +14,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
......@@ -37,11 +38,11 @@ class CPURandintKernel : public framework::OpKernel<T> {
new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
}
}
auto* out = ctx.Output<framework::LoDTensor>("Out");
if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
T* data = out->mutable_data<T>(ctx.GetPlace());
int64_t size = out->numel();
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
......
......@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/uniform_random_op.h"
#include <string>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
......@@ -55,19 +57,40 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
"supports SelectedRows and LoDTensor");
}
T *data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
int64_t size = tensor->numel();
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
auto gen_ptr = framework::Generator::GetInstance();
if (gen_ptr->is_init_py) {
std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine();
// auto gen_engine = gen_ptr_->GetCPUEngine();
// std::uniform_real_distribution<T> dist(
// static_cast<T>(ctx.Attr<float>("min")),
// static_cast<T>(ctx.Attr<float>("max")));
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(gen_engine);
}
} else {
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
// std::uniform_real_distribution<T> dist(
// static_cast<T>(ctx.Attr<float>("min")),
// static_cast<T>(ctx.Attr<float>("max")));
// int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
}
// std::mt19937_64 &engine = gen_ptr->GetCPUEngine();
// auto engine = gen_ptr_->GetCPUEngine();
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
unsigned int diag_step =
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
......@@ -87,9 +88,14 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
}
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
if (seed == 0) {
std::random_device rd;
seed = rd();
if (framework::Generator::GetInstance()->is_init_py) {
seed = static_cast<unsigned int>(
framework::Generator::GetInstance()->GetCurrentSeed());
} else {
if (seed == 0) {
std::random_device rd;
seed = rd();
}
}
T min = static_cast<T>(context.Attr<float>("min"));
T max = static_cast<T>(context.Attr<float>("max"));
......
......@@ -17,6 +17,7 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
......
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper)
gloo_wrapper infer_io_utils heter_wrapper generator)
if (WITH_NCCL)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper)
......@@ -37,7 +37,8 @@ set(PYBIND_SRCS
data_set_py.cc
imperative.cc
ir.cc
inference_api.cc)
inference_api.cc
generator_py.cc)
if (WITH_CRYPTO)
set(PYBIND_DEPS ${PYBIND_DEPS} paddle_crypto)
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/pybind/generator_py.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindGenerator(py::module* m) {
py::class_<framework::GeneratorState>(*m, "GeneratorState", "");
py::class_<std::mt19937_64>(*m, "mt19937_64", "");
py::class_<framework::Generator, std::shared_ptr<framework::Generator>>(
*m, "Generator")
.def(py::init([]() { return framework::Generator::GetInstanceX(); }),
py::return_value_policy::reference)
.def("get_state", &framework::Generator::GetState,
py::return_value_policy::move)
.def("set_state", &framework::Generator::SetState)
.def("manual_seed", &framework::Generator::SetCurrentSeed)
.def("seed", &framework::Generator::Seed)
.def("initial_seed", &framework::Generator::GetCurrentSeed)
.def("random", &framework::Generator::Random64)
.def("get_cpu_engine", &framework::Generator::GetCPUEngine,
py::return_value_policy::move)
.def("set_cpu_engine", &framework::Generator::SetCPUEngine);
} // end Generator
} // end namespace pybind
} // end namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindGenerator(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -64,6 +64,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/generator_py.h"
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/heter_wrapper_py.h"
......@@ -2503,6 +2504,7 @@ All parameter, weight, gradient are variables in Paddle.
BindNode(&m);
BindInferenceApi(&m);
BindDataset(&m);
BindGenerator(&m);
#ifdef PADDLE_WITH_CRYPTO
BindCrypto(&m);
#endif
......
......@@ -89,6 +89,7 @@ from .dygraph.base import enable_dygraph, disable_dygraph
from .io import save, load, load_program_state, set_program_state
from .dygraph.checkpoint import save_dygraph, load_dygraph
from .dygraph.varbase_patch_methods import monkey_patch_varbase
from . import generator
Tensor = LoDTensor
enable_imperative = enable_dygraph
disable_imperative = disable_dygraph
......@@ -96,7 +97,7 @@ disable_imperative = disable_dygraph
__all__ = framework.__all__ + executor.__all__ + \
trainer_desc.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \
data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [
data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + generator.__all__ + [
'io',
'initializer',
'embedding',
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is definition of generator class, which is for managing the state of the algorithm that produces pseudo random numbers."""
from . import core
__all__ = ['Generator']
default_rng_seed_val = 34342423252
class Generator(object):
"""Generator class"""
def __init__(self, device="CPU"):
"""init"""
self.device = device
seed_in = default_rng_seed_val
if self.device == "CPU":
self.generator = core.Generator()
self.generator.manual_seed(seed_in)
else:
raise ValueError(
"generator class with device %s does not exist, currently only support generator with device 'CPU' "
% device)
def get_state(self):
return self.generator.get_state()
def set_state(self, state):
self.generator.set_state(state)
def manual_seed(self, seed):
self.generator.manual_seed(seed)
def seed(self):
return self.generator.seed()
def initial_seed(self):
return self.generator.initial_seed()
def random(self):
return self.generator.random()
def get_cpu_engine(self):
return self.generator.get_cpu_engine()
def set_cpu_engine(self, cpu_engine):
self.generator.set_cpu_engine(cpu_engine)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cloud role maker."""
from __future__ import print_function
import os
import unittest
import paddle.fluid.generator as generator
import time # temp for debug
class TestGenerator(unittest.TestCase):
"""
Test cases for cpu generator.
"""
def test_basic_generator(self):
"""Test basic generator."""
gen = generator.Generator()
gen.manual_seed(123123143)
s = gen.initial_seed()
s = gen.seed()
st = gen.get_state()
gen.set_state(st)
gen.random()
gen.set_cpu_engine(gen.get_cpu_engine())
def test_basic_generator_error(self):
self.assertRaises(ValueError, generator.Generator, device="CUDA")
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cloud role maker."""
from __future__ import print_function
import os
import unittest
import paddle.fluid.generator as generator
import time # temp for debug
import paddle.fluid as fluid
import numpy as np
import paddle
import paddle.fluid.core as core
class TestGeneratorSeed(unittest.TestCase):
"""
Test cases for cpu generator seed.
"""
def test_generator_uniform_random_dygraph(self):
"""Test Generator seed."""
gen = generator.Generator()
fluid.enable_dygraph()
gen.manual_seed(12312321111)
x = fluid.layers.uniform_random([10], dtype="float32", min=0.0, max=1.0)
st1 = gen.get_state()
x1 = fluid.layers.uniform_random(
[10], dtype="float32", min=0.0, max=1.0)
gen.set_state(st1)
x2 = fluid.layers.uniform_random(
[10], dtype="float32", min=0.0, max=1.0)
gen.manual_seed(12312321111)
x3 = fluid.layers.uniform_random(
[10], dtype="float32", min=0.0, max=1.0)
x_np = x.numpy()
x1_np = x1.numpy()
x2_np = x2.numpy()
x3_np = x3.numpy()
if not core.is_compiled_with_cuda():
self.assertTrue(np.allclose(x1_np, x2_np))
self.assertTrue(np.allclose(x_np, x3_np))
def test_generator_uniform_random_static(self):
fluid.disable_dygraph()
gen = generator.Generator()
gen.manual_seed(123123143)
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
result_1 = fluid.layers.uniform_random(shape=[3, 4])
result_2 = fluid.layers.uniform_random(shape=[3, 4])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
out1 = exe.run(train_program,
feed={},
fetch_list=[result_1, result_2])
#gen.set_state(cur_state)
gen.manual_seed(123123143)
out2 = exe.run(train_program,
feed={},
fetch_list=[result_1, result_2])
out1_res1 = np.array(out1[0])
out1_res2 = np.array(out1[1])
out2_res1 = np.array(out2[0])
out2_res2 = np.array(out2[1])
if not core.is_compiled_with_cuda():
self.assertTrue(np.allclose(out1_res1, out2_res1))
self.assertTrue(np.allclose(out1_res2, out2_res2))
self.assertTrue(not np.allclose(out1_res2, out1_res1))
def test_generator_randint_dygraph(self):
"""Test Generator seed."""
gen = generator.Generator()
fluid.enable_dygraph()
gen.manual_seed(12312321111)
x = paddle.randint(low=1)
st1 = gen.get_state()
x1 = paddle.randint(low=1)
gen.set_state(st1)
x2 = paddle.randint(low=1)
gen.manual_seed(12312321111)
x3 = paddle.randint(low=1)
x_np = x.numpy()
x1_np = x1.numpy()
x2_np = x2.numpy()
x3_np = x3.numpy()
if not core.is_compiled_with_cuda():
self.assertTrue(np.allclose(x1_np, x2_np))
self.assertTrue(np.allclose(x_np, x3_np))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册