未验证 提交 463e31f4 编写于 作者: W Wilber 提交者: GitHub

context add generator (#39475)

* context add generator

* update
上级 0790f949
......@@ -138,13 +138,13 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
}
}
GeneratorState Generator::GetState() {
pten::Generator::GeneratorState Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mu_);
state_.cpu_engine = *engine_;
return this->state_;
}
void Generator::SetState(const GeneratorState& state) {
void Generator::SetState(const pten::Generator::GeneratorState& state) {
std::lock_guard<std::mutex> lock(this->mu_);
this->state_ = state;
this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
......
......@@ -25,6 +25,8 @@ limitations under the License. */
#include <typeinfo>
#include <utility>
#include "paddle/pten/core/generator.h"
namespace paddle {
namespace framework {
......@@ -34,14 +36,7 @@ static uint64_t GetRandomSeed() {
return ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
}
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};
struct Generator {
struct Generator : public pten::Generator {
Generator() {
auto seed = GetRandomSeed();
std::seed_seq seq({seed});
......@@ -82,9 +77,9 @@ struct Generator {
Generator(const Generator& other) = delete;
// get random state
GeneratorState GetState();
pten::Generator::GeneratorState GetState();
// set random state
void SetState(const GeneratorState&);
void SetState(const pten::Generator::GeneratorState&);
// get current seed
uint64_t GetCurrentSeed();
// random a seed and get
......@@ -105,7 +100,7 @@ struct Generator {
uint64_t get_device_id() { return this->state_.device; }
private:
GeneratorState state_;
pten::Generator::GeneratorState state_;
std::shared_ptr<std::mt19937_64> engine_;
mutable std::mutex mu_;
......
......@@ -123,7 +123,7 @@ cc_library(init SRCS init.cc DEPS device_context custom_kernel)
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context)
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context generator)
if(WITH_XPU)
target_link_libraries(device_context xpu_context)
endif()
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#endif
#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -160,11 +161,14 @@ inline void EmplaceDeviceContext(
.GetAllocator(p, cuda_ctx->stream())
.get());
cuda_ctx->PartialInitWithAllocator();
dev_ctx->SetGenerator(
framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get());
#endif
} else {
dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
.GetAllocator(p)
.get());
dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
}
dev_ctx->SetHostAllocator(
memory::allocation::AllocatorFacade::Instance()
......
......@@ -8,6 +8,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/generator.h"
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
......@@ -31,10 +32,11 @@ namespace paddle {
namespace pybind {
void BindGenerator(py::module* m_ptr) {
auto& m = *m_ptr;
py::class_<framework::GeneratorState,
std::shared_ptr<framework::GeneratorState>>(m, "GeneratorState")
py::class_<pten::Generator::GeneratorState,
std::shared_ptr<pten::Generator::GeneratorState>>(m,
"GeneratorState")
.def("current_seed",
[](std::shared_ptr<framework::GeneratorState>& self) {
[](std::shared_ptr<pten::Generator::GeneratorState>& self) {
return self->current_seed;
});
py::class_<std::mt19937_64>(m, "mt19937_64", "");
......
......@@ -114,10 +114,27 @@ struct DeviceContext::Impl {
return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
}
void SetGenerator(Generator* gen) {
PADDLE_ENFORCE_NOT_NULL(
gen,
pten::errors::InvalidArgument(
"Required generator shall not be nullptr, but received nullptr."));
generator_ = gen;
}
Generator* GetGenerator() const {
PADDLE_ENFORCE_NOT_NULL(
generator_,
pten::errors::InvalidArgument("Required generator_ shall not be "
"nullptr, but received nullptr."));
return generator_;
}
private:
const Allocator* device_allocator_{nullptr};
const Allocator* host_allocator_{nullptr};
const Allocator* zero_allocator_{nullptr};
Generator* generator_{nullptr};
};
DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
......@@ -201,4 +218,8 @@ DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION
void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }
Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }
} // namespace pten
......@@ -16,11 +16,10 @@ limitations under the License. */
#include <memory>
// TODO(wilber): Do we need to use place in pten kernel?
#include "paddle/pten/common/place.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/allocator.h"
#include "paddle/pten/core/generator.h"
namespace pten {
class TensorBase;
......@@ -112,13 +111,24 @@ class DeviceContext {
template <typename T>
T* HostAlloc(TensorBase* tensor, size_t requested_size = 0) const;
// TODO(wilber): Just for the convenience of migrating the code, it will be
// modified or removed later.
virtual const Place& GetPlace() const = 0;
// TODO(wilber): The fluid framework uses wait() in many places, how to delete
// this API interface.
virtual void Wait() const {}
/**
* @brief Set the generator for special op.
*
* @param Generator
*/
void SetGenerator(Generator*);
/**
* @brief Get the generator object.
*
* @return Generator
*/
Generator* GetGenerator() const;
private:
struct Impl;
std::unique_ptr<Impl> impl_;
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <memory>
#include <random>
namespace pten {
class Generator {
public:
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};
virtual ~Generator() = default;
// get random state
virtual GeneratorState GetState() = 0;
// set random state
virtual void SetState(const GeneratorState&) = 0;
// get current seed
virtual uint64_t GetCurrentSeed() = 0;
// random a seed and get
virtual uint64_t Seed() = 0;
// set seed
virtual void SetCurrentSeed(uint64_t seed) = 0;
// get cpu engine
virtual std::shared_ptr<std::mt19937_64> GetCPUEngine() = 0;
// set cpu engine
virtual void SetCPUEngine(std::shared_ptr<std::mt19937_64>) = 0;
virtual uint64_t Random64() = 0;
virtual std::pair<uint64_t, uint64_t> IncrementOffset(
uint64_t increament_offset) = 0;
// NOTE(zhiqiu): is_init_py_ is used to make generator be compatible with
// old seed, and it should be removed after all random-related operators
// and unittests upgrades to use generator.
virtual void SetIsInitPy(bool) = 0;
virtual bool GetIsInitPy() const = 0;
virtual uint64_t get_device_id() = 0;
};
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册