diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc
index a020bda8231670f4940caa62cc57fc2cf3abe757..df52d69724749c9e4a6b19d97932773023e187ce 100644
--- a/paddle/fluid/framework/generator.cc
+++ b/paddle/fluid/framework/generator.cc
@@ -138,13 +138,13 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
   }
 }
 
-GeneratorState Generator::GetState() {
+pten::Generator::GeneratorState Generator::GetState() {
   std::lock_guard<std::mutex> lock(this->mu_);
   state_.cpu_engine = *engine_;
   return this->state_;
 }
 
-void Generator::SetState(const GeneratorState& state) {
+void Generator::SetState(const pten::Generator::GeneratorState& state) {
   std::lock_guard<std::mutex> lock(this->mu_);
   this->state_ = state;
   this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h
index d0a5b4443e3f49648d1e137a581f407f3c06fc40..19ee1d0191605b72ab56826484d1d1b55b7a445b 100644
--- a/paddle/fluid/framework/generator.h
+++ b/paddle/fluid/framework/generator.h
@@ -25,6 +25,8 @@ limitations under the License. */
 #include <typeinfo>
 #include <utility>
 
+#include "paddle/pten/core/generator.h"
+
 namespace paddle {
 namespace framework {
 
@@ -34,14 +36,7 @@ static uint64_t GetRandomSeed() {
   return ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
 }
 
-struct GeneratorState {
-  int64_t device = -1;
-  uint64_t current_seed = 34342423252;
-  uint64_t thread_offset = 0;
-  std::mt19937_64 cpu_engine;
-};
-
-struct Generator {
+struct Generator : public pten::Generator {
   Generator() {
     auto seed = GetRandomSeed();
     std::seed_seq seq({seed});
@@ -82,9 +77,9 @@ struct Generator {
   Generator(const Generator& other) = delete;
 
   // get random state
-  GeneratorState GetState();
+  pten::Generator::GeneratorState GetState();
   // set random state
-  void SetState(const GeneratorState&);
+  void SetState(const pten::Generator::GeneratorState&);
   // get current seed
   uint64_t GetCurrentSeed();
   // random a seed and get
@@ -105,7 +100,7 @@ struct Generator {
   uint64_t get_device_id() { return this->state_.device; }
 
  private:
-  GeneratorState state_;
+  pten::Generator::GeneratorState state_;
   std::shared_ptr<std::mt19937_64> engine_;
   mutable std::mutex mu_;
 
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index 07d3e419582931333e46856df2ba7a61127217ca..be02bac1aa0ef7462e15f9471a84f79a6007cfb5 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -123,7 +123,7 @@ cc_library(init SRCS init.cc DEPS device_context custom_kernel)
 # avoiding cycle dependencies
 cc_library(device_context SRCS device_context.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
     place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
-    ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context)
+    ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context generator)
 if(WITH_XPU)
   target_link_libraries(device_context xpu_context)
 endif()
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 966dcf7770da8c30a0bb95b4cbae6c9e31836e90..a0a853a2f059745b281d3651d39baf061edf1053 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -28,6 +28,7 @@ limitations under the License. */
 #endif
 
 #include "glog/logging.h"
 #include "paddle/fluid/framework/expect.h"
+#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/memory/allocation/allocator_facade.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -160,11 +161,14 @@ inline void EmplaceDeviceContext(
                                 .GetAllocator(p, cuda_ctx->stream())
                                 .get());
       cuda_ctx->PartialInitWithAllocator();
+      dev_ctx->SetGenerator(
+          framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get());
 #endif
     } else {
       dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
                                 .GetAllocator(p)
                                 .get());
+      dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
     }
     dev_ctx->SetHostAllocator(
         memory::allocation::AllocatorFacade::Instance()
diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc
index fa924ce65812575f2f6dd04a28b1179b2ac47027..af643b683ab85fa86fabb5b2d91281a835ea6490 100644
--- a/paddle/fluid/pybind/generator_py.cc
+++ b/paddle/fluid/pybind/generator_py.cc
@@ -8,6 +8,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/pten/core/generator.h"
 #include <fcntl.h>
 
 #ifdef _POSIX_C_SOURCE
@@ -31,10 +32,11 @@ namespace paddle {
 namespace pybind {
 void BindGenerator(py::module* m_ptr) {
   auto& m = *m_ptr;
-  py::class_<framework::GeneratorState,
-             std::shared_ptr<framework::GeneratorState>>(m, "GeneratorState")
+  py::class_<pten::Generator::GeneratorState,
+             std::shared_ptr<pten::Generator::GeneratorState>>(m,
+                                                               "GeneratorState")
       .def("current_seed",
-           [](std::shared_ptr<framework::GeneratorState>& self) {
+           [](std::shared_ptr<pten::Generator::GeneratorState>& self) {
             return self->current_seed;
           });
   py::class_<std::mt19937_64>(m, "mt19937_64", "");
diff --git a/paddle/pten/core/device_context.cc b/paddle/pten/core/device_context.cc
index 2a11b1bef9dbcf952fcf9323067c1c28ba099a5d..70d71b5c767eae79ae9036a1c2b119ba0f053f62 100644
--- a/paddle/pten/core/device_context.cc
+++ b/paddle/pten/core/device_context.cc
@@ -114,10 +114,27 @@ struct DeviceContext::Impl {
     return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
   }
 
+  void SetGenerator(Generator* gen) {
+    PADDLE_ENFORCE_NOT_NULL(
+        gen,
+        pten::errors::InvalidArgument(
+            "Required generator shall not be nullptr, but received nullptr."));
+    generator_ = gen;
+  }
+
+  Generator* GetGenerator() const {
+    PADDLE_ENFORCE_NOT_NULL(
+        generator_,
+        pten::errors::InvalidArgument("Required generator_ shall not be "
+                                      "nullptr, but received nullptr."));
+    return generator_;
+  }
+
  private:
   const Allocator* device_allocator_{nullptr};
   const Allocator* host_allocator_{nullptr};
   const Allocator* zero_allocator_{nullptr};
+  Generator* generator_{nullptr};
 };
 
 DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
@@ -201,4 +218,8 @@ DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
 
 #undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION
 
+void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }
+
+Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }
+
 }  // namespace pten
diff --git a/paddle/pten/core/device_context.h b/paddle/pten/core/device_context.h
index 68c16dc3a196482e026382d7f89ec46d437da6e6..d627f19b55dbcb137b66d4e04b061129c6ba7d44 100644
--- a/paddle/pten/core/device_context.h
+++ b/paddle/pten/core/device_context.h
@@ -16,11 +16,10 @@ limitations under the License. */
 
 #include <memory>
 
-// TODO(wilber): Do we need to use place in pten kernel?
-#include "paddle/pten/common/place.h"
-
 #include "paddle/pten/common/data_type.h"
+#include "paddle/pten/common/place.h"
 #include "paddle/pten/core/allocator.h"
+#include "paddle/pten/core/generator.h"
 
 namespace pten {
 class TensorBase;
@@ -112,13 +111,24 @@ class DeviceContext {
   template <typename T>
   T* HostAlloc(TensorBase* tensor, size_t requested_size = 0) const;
 
-  // TODO(wilber): Just for the convenience of migrating the code, it will be
-  // modified or removed later.
   virtual const Place& GetPlace() const = 0;
   // TODO(wilber): The fluid framework uses wait() in many places, how to delete
   // this API interface.
   virtual void Wait() const {}
 
+  /**
+   * @brief Set the generator for special op.
+   *
+   * @param Generator
+   */
+  void SetGenerator(Generator*);
+  /**
+   * @brief Get the generator object.
+   *
+   * @return Generator
+   */
+  Generator* GetGenerator() const;
+
  private:
   struct Impl;
   std::unique_ptr<Impl> impl_;
diff --git a/paddle/pten/core/generator.h b/paddle/pten/core/generator.h
new file mode 100644
index 0000000000000000000000000000000000000000..eb06d69e36e3ae419c031196812dd9c3a048ff31
--- /dev/null
+++ b/paddle/pten/core/generator.h
@@ -0,0 +1,61 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <random>
+
+namespace pten {
+
+class Generator {
+ public:
+  struct GeneratorState {
+    int64_t device = -1;
+    uint64_t current_seed = 34342423252;
+    uint64_t thread_offset = 0;
+    std::mt19937_64 cpu_engine;
+  };
+
+  virtual ~Generator() = default;
+
+  // get random state
+  virtual GeneratorState GetState() = 0;
+  // set random state
+  virtual void SetState(const GeneratorState&) = 0;
+  // get current seed
+  virtual uint64_t GetCurrentSeed() = 0;
+  // random a seed and get
+  virtual uint64_t Seed() = 0;
+  // set seed
+  virtual void SetCurrentSeed(uint64_t seed) = 0;
+  // get cpu engine
+  virtual std::shared_ptr<std::mt19937_64> GetCPUEngine() = 0;
+  // set cpu engine
+  virtual void SetCPUEngine(std::shared_ptr<std::mt19937_64>) = 0;
+  virtual uint64_t Random64() = 0;
+  virtual std::pair<uint64_t, uint64_t> IncrementOffset(
+      uint64_t increament_offset) = 0;
+
+  // NOTE(zhiqiu): is_init_py_ is used to make generator be compatible with
+  // old seed, and it should be removed after all random-related operators
+  // and unittests upgrades to use generator.
+  virtual void SetIsInitPy(bool) = 0;
+  virtual bool GetIsInitPy() const = 0;
+
+  virtual uint64_t get_device_id() = 0;
+};
+
+}  // namespace pten
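For context, here is a minimal sketch (not part of this diff) of how a pten kernel could consume the generator that DeviceContext now carries. Only DeviceContext::GetGenerator() and Generator::GetCPUEngine() come from this change; the kernel name, its signature, and the uniform-sampling logic are hypothetical illustrations under the assumption of a plain CPU output buffer.

// Illustrative sketch only -- not part of the change above.
#include <cstddef>
#include <random>

#include "paddle/pten/core/device_context.h"

namespace pten {

// Fills `out` with `n` samples drawn uniformly from [low, high); T is assumed
// to be a floating-point type. The kernel no longer has to reach back into
// paddle::framework for the default generator: the device context supplies it.
template <typename T, typename Context>
void UniformRandomSketch(
    const Context& dev_ctx, size_t n, T low, T high, T* out) {
  // GetGenerator() enforces that a generator was attached (see Impl above).
  Generator* gen = dev_ctx.GetGenerator();
  // GetCPUEngine() returns std::shared_ptr<std::mt19937_64>.
  auto engine = gen->GetCPUEngine();
  std::uniform_real_distribution<T> dist(low, high);
  for (size_t i = 0; i < n; ++i) {
    out[i] = dist(*engine);
  }
}

}  // namespace pten

Because EmplaceDeviceContext() now attaches framework::DefaultCPUGenerator() (or the per-device CUDA generator) when each context is created, such a kernel sees the same engine and seed state that the Python-side GeneratorState bindings expose.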