未验证 提交 463e31f4 编写于 作者: W Wilber 提交者: GitHub

context add generator (#39475)

* context add generator

* update
上级 0790f949
...@@ -138,13 +138,13 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) { ...@@ -138,13 +138,13 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
} }
} }
GeneratorState Generator::GetState() { pten::Generator::GeneratorState Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mu_); std::lock_guard<std::mutex> lock(this->mu_);
state_.cpu_engine = *engine_; state_.cpu_engine = *engine_;
return this->state_; return this->state_;
} }
void Generator::SetState(const GeneratorState& state) { void Generator::SetState(const pten::Generator::GeneratorState& state) {
std::lock_guard<std::mutex> lock(this->mu_); std::lock_guard<std::mutex> lock(this->mu_);
this->state_ = state; this->state_ = state;
this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine); this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
......
...@@ -25,6 +25,8 @@ limitations under the License. */ ...@@ -25,6 +25,8 @@ limitations under the License. */
#include <typeinfo> #include <typeinfo>
#include <utility> #include <utility>
#include "paddle/pten/core/generator.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -34,14 +36,7 @@ static uint64_t GetRandomSeed() { ...@@ -34,14 +36,7 @@ static uint64_t GetRandomSeed() {
return ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF; return ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
} }
struct GeneratorState { struct Generator : public pten::Generator {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};
struct Generator {
Generator() { Generator() {
auto seed = GetRandomSeed(); auto seed = GetRandomSeed();
std::seed_seq seq({seed}); std::seed_seq seq({seed});
...@@ -82,9 +77,9 @@ struct Generator { ...@@ -82,9 +77,9 @@ struct Generator {
Generator(const Generator& other) = delete; Generator(const Generator& other) = delete;
// get random state // get random state
GeneratorState GetState(); pten::Generator::GeneratorState GetState();
// set random state // set random state
void SetState(const GeneratorState&); void SetState(const pten::Generator::GeneratorState&);
// get current seed // get current seed
uint64_t GetCurrentSeed(); uint64_t GetCurrentSeed();
// random a seed and get // random a seed and get
...@@ -105,7 +100,7 @@ struct Generator { ...@@ -105,7 +100,7 @@ struct Generator {
uint64_t get_device_id() { return this->state_.device; } uint64_t get_device_id() { return this->state_.device; }
private: private:
GeneratorState state_; pten::Generator::GeneratorState state_;
std::shared_ptr<std::mt19937_64> engine_; std::shared_ptr<std::mt19937_64> engine_;
mutable std::mutex mu_; mutable std::mutex mu_;
......
...@@ -123,7 +123,7 @@ cc_library(init SRCS init.cc DEPS device_context custom_kernel) ...@@ -123,7 +123,7 @@ cc_library(init SRCS init.cc DEPS device_context custom_kernel)
# avoiding cycle dependencies # avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS} cc_library(device_context SRCS device_context.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context) ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context generator)
if(WITH_XPU) if(WITH_XPU)
target_link_libraries(device_context xpu_context) target_link_libraries(device_context xpu_context)
endif() endif()
......
...@@ -28,6 +28,7 @@ limitations under the License. */ ...@@ -28,6 +28,7 @@ limitations under the License. */
#endif #endif
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/expect.h" #include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -160,11 +161,14 @@ inline void EmplaceDeviceContext( ...@@ -160,11 +161,14 @@ inline void EmplaceDeviceContext(
.GetAllocator(p, cuda_ctx->stream()) .GetAllocator(p, cuda_ctx->stream())
.get()); .get());
cuda_ctx->PartialInitWithAllocator(); cuda_ctx->PartialInitWithAllocator();
dev_ctx->SetGenerator(
framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get());
#endif #endif
} else { } else {
dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance() dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
.GetAllocator(p) .GetAllocator(p)
.get()); .get());
dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
} }
dev_ctx->SetHostAllocator( dev_ctx->SetHostAllocator(
memory::allocation::AllocatorFacade::Instance() memory::allocation::AllocatorFacade::Instance()
......
...@@ -8,6 +8,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -8,6 +8,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/pten/core/generator.h"
#include <fcntl.h> #include <fcntl.h>
#ifdef _POSIX_C_SOURCE #ifdef _POSIX_C_SOURCE
...@@ -31,10 +32,11 @@ namespace paddle { ...@@ -31,10 +32,11 @@ namespace paddle {
namespace pybind { namespace pybind {
void BindGenerator(py::module* m_ptr) { void BindGenerator(py::module* m_ptr) {
auto& m = *m_ptr; auto& m = *m_ptr;
py::class_<framework::GeneratorState, py::class_<pten::Generator::GeneratorState,
std::shared_ptr<framework::GeneratorState>>(m, "GeneratorState") std::shared_ptr<pten::Generator::GeneratorState>>(m,
"GeneratorState")
.def("current_seed", .def("current_seed",
[](std::shared_ptr<framework::GeneratorState>& self) { [](std::shared_ptr<pten::Generator::GeneratorState>& self) {
return self->current_seed; return self->current_seed;
}); });
py::class_<std::mt19937_64>(m, "mt19937_64", ""); py::class_<std::mt19937_64>(m, "mt19937_64", "");
......
...@@ -114,10 +114,27 @@ struct DeviceContext::Impl { ...@@ -114,10 +114,27 @@ struct DeviceContext::Impl {
return static_cast<T*>(HostAlloc(tensor, dtype, requested_size)); return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
} }
void SetGenerator(Generator* gen) {
PADDLE_ENFORCE_NOT_NULL(
gen,
pten::errors::InvalidArgument(
"Required generator shall not be nullptr, but received nullptr."));
generator_ = gen;
}
Generator* GetGenerator() const {
PADDLE_ENFORCE_NOT_NULL(
generator_,
pten::errors::InvalidArgument("Required generator_ shall not be "
"nullptr, but received nullptr."));
return generator_;
}
private: private:
const Allocator* device_allocator_{nullptr}; const Allocator* device_allocator_{nullptr};
const Allocator* host_allocator_{nullptr}; const Allocator* host_allocator_{nullptr};
const Allocator* zero_allocator_{nullptr}; const Allocator* zero_allocator_{nullptr};
Generator* generator_{nullptr};
}; };
DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); } DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
...@@ -201,4 +218,8 @@ DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128) ...@@ -201,4 +218,8 @@ DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION #undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION
void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }
Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }
} // namespace pten } // namespace pten
...@@ -16,11 +16,10 @@ limitations under the License. */ ...@@ -16,11 +16,10 @@ limitations under the License. */
#include <memory> #include <memory>
// TODO(wilber): Do we need to use place in pten kernel?
#include "paddle/pten/common/place.h"
#include "paddle/pten/common/data_type.h" #include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/allocator.h" #include "paddle/pten/core/allocator.h"
#include "paddle/pten/core/generator.h"
namespace pten { namespace pten {
class TensorBase; class TensorBase;
...@@ -112,13 +111,24 @@ class DeviceContext { ...@@ -112,13 +111,24 @@ class DeviceContext {
template <typename T> template <typename T>
T* HostAlloc(TensorBase* tensor, size_t requested_size = 0) const; T* HostAlloc(TensorBase* tensor, size_t requested_size = 0) const;
// TODO(wilber): Just for the convenience of migrating the code, it will be
// modified or removed later.
virtual const Place& GetPlace() const = 0; virtual const Place& GetPlace() const = 0;
// TODO(wilber): The fluid framework uses wait() in many places, how to delete // TODO(wilber): The fluid framework uses wait() in many places, how to delete
// this API interface. // this API interface.
virtual void Wait() const {} virtual void Wait() const {}
/**
* @brief Set the generator for special op.
*
* @param Generator
*/
void SetGenerator(Generator*);
/**
* @brief Get the generator object.
*
* @return Generator
*/
Generator* GetGenerator() const;
private: private:
struct Impl; struct Impl;
std::unique_ptr<Impl> impl_; std::unique_ptr<Impl> impl_;
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <memory>
#include <random>
namespace pten {
class Generator {
public:
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};
virtual ~Generator() = default;
// get random state
virtual GeneratorState GetState() = 0;
// set random state
virtual void SetState(const GeneratorState&) = 0;
// get current seed
virtual uint64_t GetCurrentSeed() = 0;
// random a seed and get
virtual uint64_t Seed() = 0;
// set seed
virtual void SetCurrentSeed(uint64_t seed) = 0;
// get cpu engine
virtual std::shared_ptr<std::mt19937_64> GetCPUEngine() = 0;
// set cpu engine
virtual void SetCPUEngine(std::shared_ptr<std::mt19937_64>) = 0;
virtual uint64_t Random64() = 0;
virtual std::pair<uint64_t, uint64_t> IncrementOffset(
uint64_t increament_offset) = 0;
// NOTE(zhiqiu): is_init_py_ is used to make generator be compatible with
// old seed, and it should be removed after all random-related operators
// and unittests upgrades to use generator.
virtual void SetIsInitPy(bool) = 0;
virtual bool GetIsInitPy() const = 0;
virtual uint64_t get_device_id() = 0;
};
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册