Commit fc8b501b authored by: Megvii Engine Team

refactor(mgb/core): refactor cpu compnode so that default cpu has no ability to record

GitOrigin-RevId: 7de4771476e87d3ed11cffbad4c02741591ef9a2
Parent b7176069
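In brief, this commit splits the old monolithic CpuCompNode::CompNodeImpl into a shared base class plus two concrete implementations, so the default CPU comp node no longer supports sequence recording. A minimal skeleton of the resulting hierarchy (illustrative only; bases and methods are stubbed, the real signatures are in the diff below):

// Illustrative skeleton, not part of the diff.
struct CpuDispatchableBase {};  // stands in for CpuCompNode::CpuDispatchableBase

struct CompNodeBaseImpl : CpuDispatchableBase {
    // shared state and behavior: locators, aligned alloc/free,
    // copy_to_host / copy_to_device, dispatch, mem_node, ...
};

// backs CompNode::default_cpu(): inline dispatch, no recording ability;
// create_seq_recorder() asserts and cur_recorder() returns nullptr
struct CompNodeNoRecorderImpl final : CompNodeBaseImpl {};

// backs CompNode::load("cpu:N" / "cpu:default" / "multithread:N"):
// worker queue, optional thread pool, and a thread-local seq recorder
struct CompNodeRecorderImpl final : CompNodeBaseImpl {};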
@@ -11,18 +11,18 @@
#include "./comp_node.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/system.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/utils/thread.h"
#include "megbrain/utils/timer.h"
#include "megbrain/utils/thread_pool.h"
#include "megbrain/common.h"
#include "megbrain/utils/timer.h"
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <atomic>
#include <stdlib.h>
#ifndef __APPLE__
@@ -44,8 +44,6 @@ struct TaskElem {
};
} // anonymous namespace
using CpuCompNodeImpl = CpuCompNode::CompNodeImpl;
void CpuCompNode::CpuDispatchableBase::add_callback(Task&& task) {
dispatch(std::move(task));
}
@@ -110,7 +108,15 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
* \brief used to check that all recorded tasks belong to this CompNode,
* to avoid hooking other CompNodes' tasks into the recorder.
*/
void check_the_same_comp_node(const CompNode& comp_node) const;
void check_the_same_comp_node(const CompNode& comp_node) const {
if (mgb_unlikely(comp_node.valid())) {
mgb_assert(m_record_compnode == comp_node,
"CompNode %s can't be hooked into CompNode %s during recording\n",
comp_node.locator().to_string().c_str(),
m_record_compnode.locator().to_string().c_str());
}
}
public:
SeqRecorderImpl(SeqRecorderImpl** self_pointer, ThreadPool* thread_pool,
const CompNode& comp_node)
@@ -127,13 +133,13 @@ public:
}
}
void enter_fake_exec(const CompNode& comp_node) override {
void enter_fake_exec(const CompNode& comp_node) override {
check_the_same_comp_node(comp_node);
mgb_assert(!m_stopped && !m_fake_exec);
m_fake_exec = true;
}
void exit_fake_exec(const CompNode& comp_node) override {
void exit_fake_exec(const CompNode& comp_node) override {
check_the_same_comp_node(comp_node);
mgb_assert(!m_stopped && m_fake_exec);
mgb_assert(m_tasks.empty());
@@ -165,9 +171,9 @@ public:
m_thread_pool->add_task(i);
}
m_thread_pool->deactive();
}else{
} else {
for (auto&& task : m_tasks) {
for(size_t i=0; i<task.nr_parallelism;i++){
for (size_t i = 0; i < task.nr_parallelism; i++) {
task.task(i, 0);
}
}
@@ -236,273 +242,113 @@ public:
ThreadPool* get_thread_pool() { return m_thread_pool; }
};
class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
//! used during comp node seq rec
class CompSeqRecEventImpl;
class CpuEventImpl;
//! TODO: due to an Xcode bug (see
//! https://github.com/tensorflow/tensorflow/issues/18356),
//! thread_local is not supported on iOS;
//! delete this workaround once Xcode is updated
#if !defined(IOS) && MGB_HAVE_THREAD
static thread_local SeqRecorderImpl* sm_cur_recorder;
#else
SeqRecorderImpl* sm_cur_recorder = nullptr;
#endif
using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl;
using CompNodeNoRecorderImpl = CpuCompNode::CompNodeNoRecorderImpl;
using CompNodeRecorderImpl = CpuCompNode::CompNodeRecorderImpl;
std::shared_ptr<WorkerQueue> m_worker_queue;
//! ==================== CompNodeBaseImpl ======================
class CpuCompNode::CompNodeBaseImpl : public CpuDispatchableBase {
protected:
Locator m_locator, m_locator_logical;
std::unique_ptr<ThreadPool> m_thread_pool;
//! ptr to default cpu, only used by check_global_finalized
static CpuCompNodeImpl *sm_default_cpu_comp_node_ptr;
//! return whether global finalize has happened, and print a warning in that case
inline bool check_global_finalized(const char* reason);
static void static_free_device(ImplBase* self, void* ptr) {
static_cast<CompNodeImpl*>(self)->free_device(ptr);
}
static void static_free_host(ImplBase* self, void* ptr) {
static_cast<CompNodeImpl*>(self)->free_host(ptr);
}
public:
CompNodeImpl(const Locator& locator, const Locator& locator_logical,
const std::shared_ptr<WorkerQueue>& worker_queue);
~CompNodeImpl() {
if (sm_cur_recorder) {
sm_cur_recorder->stop();
}
if (m_worker_queue) {
// synchronize before fini
m_worker_queue->wait_all_task_finish();
}
m_env.fini();
if (m_worker_queue) {
// wait for new kernels dispatched in fini() (like free_device())
m_worker_queue->wait_all_task_finish();
}
if (this == sm_default_cpu_comp_node_ptr) {
// This should only happen in global library .fini. We clear
// sm_default_cpu_comp_node_ptr so check_global_finalized() can
// work correctly
sm_default_cpu_comp_node_ptr = nullptr;
}
}
public:
CompNodeBaseImpl(const Locator& locator, const Locator& locator_logical,
free_func_t fd, free_func_t fh)
: CpuDispatchableBase(fd, fh),
m_locator(locator),
m_locator_logical(locator_logical) {}
ThreadPool* get_thread_pool() const { return m_thread_pool.get(); }
virtual ~CompNodeBaseImpl() {}
void* mgb_aligned_alloc(size_t size) {
auto alignment = get_mem_addr_alignment();
void* mgb_aligned_alloc(size_t size) {
auto alignment = get_mem_addr_alignment();
#ifdef WIN32
return _aligned_malloc(size, alignment);
return _aligned_malloc(size, alignment);
#elif defined(__ANDROID__) || defined(ANDROID)
return memalign(alignment, size);
return memalign(alignment, size);
#else
void *ptr = nullptr;
auto err = posix_memalign(&ptr, alignment, size);
mgb_assert(!err, "failed to malloc %zu bytes with align %zu",
size, alignment);
return ptr;
void* ptr = nullptr;
auto err = posix_memalign(&ptr, alignment, size);
mgb_assert(!err, "failed to malloc %zu bytes with align %zu", size,
alignment);
return ptr;
#endif
}
}
static void mgb_aligned_free(void* ptr) {
static void mgb_aligned_free(void* ptr) {
#ifdef WIN32
_aligned_free(ptr);
_aligned_free(ptr);
#else
::free(ptr);
::free(ptr);
#endif
}
void* alloc_device(size_t size) override {
if (sm_cur_recorder) {
sm_cur_recorder->on_alloc(this);
}
return mgb_aligned_alloc(size);
}
void free_device(void *ptr) {
if (sm_cur_recorder || check_global_finalized("free_device()")) {
mgb_aligned_free(ptr);
if (sm_cur_recorder) {
sm_cur_recorder->on_free(this);
}
return;
} else {
auto do_free = [ptr]() {
mgb_aligned_free(ptr);
};
m_env.cpu_env().dispatch(do_free);
}
}
void *alloc_host(size_t size) override {
if (m_worker_queue) {
m_worker_queue->check_exception();
}
return mgb_aligned_alloc(size);
}
void free_host(void *ptr) {
if (check_global_finalized("free_host()")) {
mgb_aligned_free(ptr);
return;
}
if (m_worker_queue) {
m_worker_queue->check_exception();
}
return mgb_aligned_free(ptr);
}
void copy_to_host(void *host_ptr,
const void *device_ptr, size_t size) override {
if (m_worker_queue) {
m_worker_queue->check_exception();
}
// use lambda capture to avoid memory allocation in std::bind
auto do_copy = [host_ptr, device_ptr, size]() {
std::memcpy(host_ptr, device_ptr, size);
};
m_env.cpu_env().dispatch(do_copy);
}
void copy_to_device(void *device_ptr,
const void *host_ptr, size_t size) override {
if (m_worker_queue) {
m_worker_queue->check_exception();
}
// use lambda capture to avoid memory allocation in std::bind
auto do_copy = [device_ptr, host_ptr, size]() {
std::memcpy(device_ptr, host_ptr, size);
};
m_env.cpu_env().dispatch(do_copy);
}
}
void peer_copy_to(
Impl *dest_impl, void *dest,
const void *src, size_t size) override {
if (!dest_impl->same_type<CpuCompNode::CompNodeImpl>()) {
if (dest_impl->env().property().type == DeviceType::ATLAS) {
#if MGB_ATLAS
dest_impl->copy_to_device(dest, src, size);
return;
#else
mgb_throw(MegBrainError,
"Atlas comp_node used but "
"MGB_ATLAS not enabled");
#endif
} else if (dest_impl->env().property().type ==
DeviceType::CAMBRICON) {
#if MGB_CAMBRICON
dest_impl->copy_to_device(dest, src, size);
return;
#else
mgb_throw(MegBrainError,
"Cambricon comp_node used but "
"MGB_CAMBRICON not enabled");
#endif
void* alloc_device(size_t size) override { return mgb_aligned_alloc(size); }
} else {
mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT,
"currently only peer copy from default cpu comp "
"nodes "
"is implemented");
}
}
dest_impl->copy_to_device(dest, src, size);
}
void* alloc_host(size_t size) override { return mgb_aligned_alloc(size); }
size_t get_mem_addr_alignment() override {
return m_env.property().mem_alignment;
}
void copy_to_host(void* host_ptr, const void* device_ptr,
size_t size) override {
// use lambda capture to avoid memory allocation in std::bind
auto do_copy = [host_ptr, device_ptr, size]() {
std::memcpy(host_ptr, device_ptr, size);
};
m_env.cpu_env().dispatch(do_copy);
}
std::unique_ptr<Event> create_event(size_t flags) override;
void copy_to_device(void* device_ptr, const void* host_ptr,
size_t size) override {
// use lambda capture to avoid memory allocation in std::bind
auto do_copy = [device_ptr, host_ptr, size]() {
std::memcpy(device_ptr, host_ptr, size);
};
m_env.cpu_env().dispatch(do_copy);
}
void sync() override {
if (sm_cur_recorder) {
sm_cur_recorder->on_sync(this);
} else if (m_worker_queue) {
m_worker_queue->wait_all_task_finish();
}
if (m_thread_pool) {
m_thread_pool->deactive();
}
}
void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
size_t size) override {
dest_impl->copy_to_device(dest, src, size);
}
void dispatch(Task &&task) override {
m_env.cpu_env().dispatch(std::move(task));
}
size_t get_mem_addr_alignment() override {
return m_env.property().mem_alignment;
}
MemNode mem_node() override {
// TODO: numa nodes
return get_host_cpu_mem_node();
}
void dispatch(Task&& task) override {
m_env.cpu_env().dispatch(std::move(task));
}
std::pair<size_t, size_t> get_mem_status_bytes() override {
return sys::get_ram_status_bytes();
}
MemNode mem_node() override {
// TODO: numa nodes
return get_host_cpu_mem_node();
}
Locator locator() override {
return m_locator;
}
std::pair<size_t, size_t> get_mem_status_bytes() override {
return sys::get_ram_status_bytes();
}
Locator locator_logical() override {
return m_locator_logical;
}
Locator locator() override { return m_locator; }
std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
cg::ComputingGraph*) override {
return std::make_unique<SeqRecorderImpl>(&sm_cur_recorder,
m_thread_pool.get(), this);
}
Locator locator_logical() override { return m_locator_logical; }
//! current sequence recorder of this thread
#if !defined(IOS) && MGB_HAVE_THREAD
static SeqRecorderImpl* cur_recorder() { return sm_cur_recorder; }
#else
SeqRecorderImpl* cur_recorder() { return sm_cur_recorder; }
#endif
void add_callback(Task&& task) override {
CpuDispatchableBase::add_callback(std::move(task));
}
void add_callback(Task &&task) override {
if (!check_global_finalized("add_callback()")) {
CpuDispatchableBase::add_callback(std::move(task));
} else {
task();
}
}
virtual SeqRecorderImpl* cur_recorder() const = 0;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CpuCompNodeImpl);
CpuCompNodeImpl* CpuCompNodeImpl::sm_default_cpu_comp_node_ptr;
#if !defined(IOS) && MGB_HAVE_THREAD
thread_local CpuCompNode::SeqRecorderImpl* CpuCompNodeImpl::sm_cur_recorder =
nullptr;
#endif
void CpuCompNode::SeqRecorderImpl::check_the_same_comp_node(
const CompNode& comp_node) const {
if (mgb_unlikely(comp_node.valid())) {
mgb_assert(m_record_compnode == comp_node,
"CompNode %s can't be hooked into CompNode %s during recording\n",
comp_node.locator().to_string().c_str(),
m_record_compnode.locator().to_string().c_str());
}
}
//! implementation of CPUDispatcher that is passed to megdnn via megcore
class CpuCompNode::WorkerQueue::DispatcherImpl final: public CPUDispatcher {
class CpuCompNode::WorkerQueue::DispatcherImpl final : public CPUDispatcher {
std::atomic_size_t m_nr_task{0};
std::shared_ptr<WorkerQueue> m_queue;
CpuCompNode::CompNodeImpl* const m_comp_node;
//! DispatcherImpl is only used by CompNodeRecorderImpl, but we still use
//! CompNodeBaseImpl* to avoid an incomplete-type error
CompNodeBaseImpl* const m_comp_node;
public:
DispatcherImpl(const std::shared_ptr<WorkerQueue>& queue,
CpuCompNode::CompNodeImpl* comp_node)
CompNodeBaseImpl* comp_node)
: m_queue{queue}, m_comp_node{comp_node} {}
void dispatch(Task&& task) override {
@@ -559,10 +405,12 @@ public:
class InplaceCPUDispatcher final : public CPUDispatcher {
std::atomic_size_t m_nr_task{0};
ThreadPool* m_thread_pool = nullptr;
CpuCompNode::CompNodeImpl* const m_comp_node;
//! InplaceCPUDispatcher may be used by both types of comp nodes, so
//! m_comp_node is typed as the base class.
CompNodeBaseImpl* const m_comp_node;
public:
InplaceCPUDispatcher(CpuCompNode::CompNodeImpl* comp_node,
InplaceCPUDispatcher(CompNodeBaseImpl* comp_node,
ThreadPool* thread_pool = nullptr)
: m_thread_pool(thread_pool), m_comp_node(comp_node) {}
@@ -585,9 +433,9 @@ public:
} else if (m_thread_pool) {
m_nr_task.fetch_add(1, std::memory_order_relaxed);
m_thread_pool->add_task({task, parallelism});
}else{
} else {
m_nr_task.fetch_add(1, std::memory_order_relaxed);
for(size_t i=0; i<parallelism;i++){
for (size_t i = 0; i < parallelism; i++) {
task(i, 0);
}
}
@@ -612,143 +460,417 @@ public:
recorder->get_thread_pool()->set_affinity(affinity_cb);
} else if (m_thread_pool) {
m_thread_pool->set_affinity(affinity_cb);
}else{
} else {
affinity_cb(0);
}
}
};
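For orientation, the two dispatcher flavors above implement different execution strategies: InplaceCPUDispatcher runs tasks on the calling thread (or hands them to a thread pool), while WorkerQueue::DispatcherImpl queues them for a dedicated worker thread. A self-contained sketch of the distinction, with hypothetical names and standard C++ only:

#include <cstdio>
#include <functional>
#include <queue>

using Task = std::function<void()>;

// cf. InplaceCPUDispatcher: execute on the calling thread immediately
struct InlineDispatcher {
    void dispatch(Task&& t) { t(); }
};

// cf. WorkerQueue::DispatcherImpl: enqueue work; in the real code a worker
// thread drains the queue asynchronously, and sync() waits for completion
// (modeled single-threaded here)
struct QueueDispatcher {
    std::queue<Task> tasks;
    void dispatch(Task&& t) { tasks.push(std::move(t)); }
    void sync() {
        while (!tasks.empty()) {
            tasks.front()();
            tasks.pop();
        }
    }
};

int main() {
    InlineDispatcher inline_disp;
    inline_disp.dispatch([] { std::printf("runs immediately\n"); });
    QueueDispatcher queued_disp;
    queued_disp.dispatch([] { std::printf("runs at sync()\n"); });
    queued_disp.sync();
    return 0;
}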
CpuCompNode::CompNodeImpl::CompNodeImpl(
const Locator& locator, const Locator& locator_logical,
const std::shared_ptr<WorkerQueue>& worker_queue)
: CpuDispatchableBase(static_free_device, static_free_host),
m_worker_queue{worker_queue},
m_locator(locator),
m_locator_logical(locator_logical) {
auto cn = make_comp_node_from_impl(this);
if (locator.type == DeviceType::MULTITHREAD) {
m_thread_pool = std::unique_ptr<ThreadPool>(
new ThreadPool(static_cast<size_t>(locator.nr_threads)));
mgb_assert(m_thread_pool, "ThreadPool creation failed");
//! ==================== CompNodeNoRecorderImpl ======================
/**
* \note CompNodeNoRecorderImpl reuses most implementations from the base
* class, including: alloc_device, alloc_host, copy_to_host, copy_to_device,
* peer_copy_to, add_callback ...
*/
class CpuCompNode::CompNodeNoRecorderImpl final : public CompNodeBaseImpl {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
//! ptr to default cpu, only used by check_global_finalized
static CompNodeNoRecorderImpl* sm_default_cpu_comp_node_ptr;
static void static_free_device(ImplBase* self, void* ptr) {
static_cast<CompNodeNoRecorderImpl*>(self)->free_device(ptr);
}
if (locator.type == DeviceType::CPU) {
if (locator.device == Locator::DEVICE_CPU_DEFAULT) {
sm_default_cpu_comp_node_ptr = this;
m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(this)}, cn);
} else {
m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
m_worker_queue, this)},
cn);
}
} else if (locator.type == DeviceType::MULTITHREAD) {
if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
this, m_thread_pool.get())},
cn);
} else {
m_worker_queue->attach_thread_pool(m_thread_pool.get());
m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
m_worker_queue, this)},
cn);
static void static_free_host(ImplBase* self, void* ptr) {
static_cast<CompNodeNoRecorderImpl*>(self)->free_host(ptr);
}
using CpuEventImpl = CpuDispatchableBase::EventImpl;
CompNodeNoRecorderImpl(const Locator& locator,
const Locator& locator_logical)
: CompNodeBaseImpl(locator, locator_logical, static_free_device,
static_free_host) {
mgb_assert(
locator.type == DeviceType::CPU &&
locator.device == Locator::DEVICE_CPU_DEFAULT,
"CompNodeNoRecorder is only constructed on DEVICE_CPU_DEFAULT");
auto cn = make_comp_node_from_impl(this);
m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(this)}, cn);
sm_default_cpu_comp_node_ptr = this;
}
~CompNodeNoRecorderImpl() {
m_env.fini();
sm_default_cpu_comp_node_ptr = nullptr;
}
//! return whether global finalize has happened, and print a warning in that case
bool check_global_finalized(const char* reason) {
MGB_MARK_USED_VAR(reason);
if (!sm_default_cpu_comp_node_ptr) {
static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT;
if (!warn_printed.test_and_set()) {
mgb_log_debug(
"cpu comp node method called after global finalize: "
"reason=%s",
reason);
}
return true;
}
return false;
}
}
class CpuCompNodeImpl::CompSeqRecEventImpl final
: public CpuDispatchableBase::EventImpl {
void do_record() override {
auto impl = static_cast<CpuCompNodeImpl*>(m_comp_node_impl);
if (auto rec = impl->cur_recorder()) {
auto callback = [this]() {
incr_nr_req();
on_finish();
};
rec->dispatch_allow_after_sync(callback, m_comp_node_impl);
void free_device(void* ptr) {
if (check_global_finalized("free_device()")) {
CompNodeBaseImpl::mgb_aligned_free(ptr);
return;
} else {
EventImpl::do_record();
auto do_free = [ptr]() { CompNodeBaseImpl::mgb_aligned_free(ptr); };
m_env.cpu_env().dispatch(do_free);
}
}
void do_device_wait_by(Impl*) override {
mgb_throw(MegBrainError,
"device_wait() should not be called on events created during "
"comp node seq recording");
void free_host(void* ptr) {
check_global_finalized("free_host()");
return CompNodeBaseImpl::mgb_aligned_free(ptr);
}
public:
using EventImpl::EventImpl;
std::unique_ptr<Event> create_event(size_t flags) override {
return std::make_unique<CpuEventImpl>(this, flags);
}
void sync() override {}
std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
cg::ComputingGraph*) override {
mgb_assert(false, "default_cpu has no ability to record");
return nullptr;
}
SeqRecorderImpl* cur_recorder() const override { return nullptr; }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeNoRecorderImpl);
CompNodeNoRecorderImpl* CompNodeNoRecorderImpl::sm_default_cpu_comp_node_ptr =
nullptr;
//! ==================== CompNodeRecorderImpl ======================
class CpuCompNode::CompNodeRecorderImpl final : public CompNodeBaseImpl {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
std::unique_ptr<ThreadPool> m_thread_pool;
std::shared_ptr<WorkerQueue> m_worker_queue;
//! used during comp node seq rec
class CompSeqRecEventImpl final : public CpuDispatchableBase::EventImpl {
void do_record() override {
auto impl = static_cast<CompNodeRecorderImpl*>(m_comp_node_impl);
if (auto rec = impl->cur_recorder()) {
auto callback = [this]() {
incr_nr_req();
on_finish();
};
rec->dispatch_allow_after_sync(callback, m_comp_node_impl);
} else {
EventImpl::do_record();
}
}
void do_device_wait_by(Impl*) override {
mgb_throw(MegBrainError,
"device_wait() should not be called on events created "
"during "
"comp node seq recording");
}
class CpuCompNodeImpl::CpuEventImpl final
: public CpuDispatchableBase::EventImpl {
public:
using EventImpl::EventImpl;
};
class CpuEventImpl final : public CpuDispatchableBase::EventImpl {
#if MGB_HAVE_THREAD
void host_wait_cv() override {
CpuDispatchableBase::EventImpl::host_wait_cv();
auto thread_pool = static_cast<CpuCompNodeImpl*>(m_comp_node_impl)
->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
void host_wait_cv() override {
CpuDispatchableBase::EventImpl::host_wait_cv();
auto thread_pool =
static_cast<CompNodeRecorderImpl*>(m_comp_node_impl)
->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
}
}
#endif
public:
using EventImpl::EventImpl;
};
//! TODO: due to an Xcode bug (see
//! https://github.com/tensorflow/tensorflow/issues/18356),
//! thread_local is not supported on iOS;
//! delete this workaround once Xcode is updated
#if !defined(IOS) && MGB_HAVE_THREAD
static thread_local SeqRecorderImpl* sm_cur_recorder;
#else
SeqRecorderImpl* sm_cur_recorder = nullptr;
#endif
public:
using EventImpl::EventImpl;
};
static void static_free_device(ImplBase* self, void* ptr) {
static_cast<CompNodeRecorderImpl*>(self)->free_device(ptr);
}
std::unique_ptr<CompNode::Event> CpuCompNodeImpl::create_event(size_t flags) {
if (m_worker_queue) {
m_worker_queue->check_exception();
static void static_free_host(ImplBase* self, void* ptr) {
static_cast<CompNodeRecorderImpl*>(self)->free_host(ptr);
}
if (sm_cur_recorder) {
return std::make_unique<CompSeqRecEventImpl>(this, flags);
} else {
return std::make_unique<CpuEventImpl>(this, flags);
CompNodeRecorderImpl(const Locator& locator, const Locator& locator_logical,
const std::shared_ptr<WorkerQueue>& worker_queue)
: CompNodeBaseImpl(locator, locator_logical, static_free_device,
static_free_host),
m_worker_queue(worker_queue) {
auto cn = make_comp_node_from_impl(this);
if (locator.type == DeviceType::MULTITHREAD) {
m_thread_pool = std::unique_ptr<ThreadPool>(
new ThreadPool(static_cast<size_t>(locator.nr_threads)));
mgb_assert(m_thread_pool, "ThreadPool creation failed");
}
if (locator.type == DeviceType::CPU) {
if (locator.device == Locator::DEVICE_CPU_DEFAULT) {
m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(this)},
cn);
} else {
m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
m_worker_queue, this)},
cn);
}
} else if (locator.type == DeviceType::MULTITHREAD) {
if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
this, m_thread_pool.get())},
cn);
} else {
m_worker_queue->attach_thread_pool(m_thread_pool.get());
m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
m_worker_queue, this)},
cn);
}
}
}
}
~CompNodeRecorderImpl() {
if (sm_cur_recorder) {
sm_cur_recorder->stop();
}
if (m_worker_queue) {
// synchronize before fini
m_worker_queue->wait_all_task_finish();
}
m_env.fini();
if (m_worker_queue) {
// wait for new kernels dispatched in fini() (like free_device())
m_worker_queue->wait_all_task_finish();
}
}
ThreadPool* get_thread_pool() const { return m_thread_pool.get(); }
//! return whether global finalize has happened, and print a warning in that case
bool check_global_finalized(const char* reason) {
MGB_MARK_USED_VAR(reason);
if (!sm_pool) {
static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT;
if (!warn_printed.test_and_set()) {
mgb_log_debug(
"cpu comp node method called after global finalize: "
"reason=%s",
reason);
}
return true;
}
return false;
}
void* alloc_device(size_t size) override {
if (sm_cur_recorder) {
sm_cur_recorder->on_alloc(this);
}
return CompNodeBaseImpl::alloc_device(size);
}
void free_device(void* ptr) {
if (sm_cur_recorder || check_global_finalized("free_device()")) {
CompNodeBaseImpl::mgb_aligned_free(ptr);
if (sm_cur_recorder) {
sm_cur_recorder->on_free(this);
}
return;
} else {
auto do_free = [ptr]() { CompNodeBaseImpl::mgb_aligned_free(ptr); };
m_env.cpu_env().dispatch(do_free);
}
}
void* alloc_host(size_t size) override {
if (m_worker_queue) {
m_worker_queue->check_exception();
}
return CompNodeBaseImpl::alloc_host(size);
}
void free_host(void* ptr) {
if (check_global_finalized("free_host()")) {
CompNodeBaseImpl::mgb_aligned_free(ptr);
return;
}
if (m_worker_queue) {
m_worker_queue->check_exception();
}
CompNodeBaseImpl::mgb_aligned_free(ptr);
}
void copy_to_host(void* host_ptr, const void* device_ptr,
size_t size) override {
if (m_worker_queue) {
m_worker_queue->check_exception();
}
CompNodeBaseImpl::copy_to_host(host_ptr, device_ptr, size);
}
void copy_to_device(void* device_ptr, const void* host_ptr,
size_t size) override {
if (m_worker_queue) {
m_worker_queue->check_exception();
}
CompNodeBaseImpl::copy_to_device(device_ptr, host_ptr, size);
}
void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
size_t size) override {
//! copy to default_cpu
if (dest_impl->same_type<CpuCompNode::CompNodeNoRecorderImpl>()) {
CompNodeBaseImpl::peer_copy_to(dest_impl, dest, src, size);
return;
}
if (!dest_impl->same_type<CpuCompNode::CompNodeRecorderImpl>()) {
if (dest_impl->env().property().type == DeviceType::ATLAS) {
#if MGB_ATLAS
dest_impl->copy_to_device(dest, src, size);
return;
#else
mgb_throw(MegBrainError,
"Atlas comp_node used but "
"MGB_ATLAS not enabled");
#endif
} else if (dest_impl->env().property().type ==
DeviceType::CAMBRICON) {
#if MGB_CAMBRICON
dest_impl->copy_to_device(dest, src, size);
return;
#else
mgb_throw(MegBrainError,
"Cambricon comp_node used but "
"MGB_CAMBRICON not enabled");
#endif
}
else {
mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT,
"currently only peer copy from default cpu comp "
"nodes "
"is implemented");
}
}
dest_impl->copy_to_device(dest, src, size);
}
std::unique_ptr<Event> create_event(size_t flags) override {
if (m_worker_queue) {
m_worker_queue->check_exception();
}
if (sm_cur_recorder) {
return std::make_unique<CompSeqRecEventImpl>(this, flags);
} else {
return std::make_unique<CpuEventImpl>(this, flags);
}
}
void sync() override {
if (sm_cur_recorder) {
sm_cur_recorder->on_sync(this);
} else if (m_worker_queue) {
m_worker_queue->wait_all_task_finish();
}
if (m_thread_pool) {
m_thread_pool->deactive();
}
}
std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
cg::ComputingGraph*) override {
return std::make_unique<SeqRecorderImpl>(&sm_cur_recorder,
m_thread_pool.get(), this);
}
SeqRecorderImpl* cur_recorder() const override { return sm_cur_recorder; }
void add_callback(Task&& task) override {
if (!check_global_finalized("add_callback()")) {
CompNodeBaseImpl::add_callback(std::move(task));
} else {
task();
}
}
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeRecorderImpl);
#if !defined(IOS) && MGB_HAVE_THREAD
thread_local CpuCompNode::SeqRecorderImpl*
CompNodeRecorderImpl::sm_cur_recorder = nullptr;
#endif
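The sm_cur_recorder member above is the heart of the recording mechanism: a per-thread pointer that, while set, causes dispatched tasks to be captured rather than executed. A self-contained sketch of the thread_local recorder pattern (hypothetical names, standard C++ only):

#include <cstdio>
#include <functional>
#include <vector>

using Task = std::function<void()>;

struct Recorder {
    std::vector<Task> tasks;  // captured tasks, replayed later
};

// per-thread "current recorder", analogous to sm_cur_recorder above
static thread_local Recorder* g_cur_recorder = nullptr;

void dispatch(Task&& task) {
    if (g_cur_recorder) {
        g_cur_recorder->tasks.push_back(std::move(task));  // record
    } else {
        task();  // normal immediate execution
    }
}

int main() {
    Recorder rec;
    g_cur_recorder = &rec;  // start recording on this thread
    dispatch([] { std::printf("kernel A\n"); });
    dispatch([] { std::printf("kernel B\n"); });
    g_cur_recorder = nullptr;  // stop recording
    for (auto&& t : rec.tasks) t();  // replay the recorded sequence
    return 0;
}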
/* ======================== CpuCompNode ======================== */
struct CpuCompNode::Pool {
static constexpr int MAX_NR_COMP_NODE = 1024;
struct CpuCompNodeImplDeleter {
void operator ()(CpuCompNodeImpl *p) {
p->~CpuCompNodeImpl();
}
struct CompNodeRecorderImplDeleter {
void operator()(CompNodeRecorderImpl* p) { p->~CompNodeRecorderImpl(); }
};
std::recursive_mutex mtx;
// use global memory pool to ensure object memory stays accessible even
// after global finalize
std::aligned_storage_t<sizeof(CpuCompNodeImpl), alignof(CpuCompNodeImpl)>
impl_storage[MAX_NR_COMP_NODE];
std::aligned_storage_t<sizeof(CompNodeRecorderImpl),
alignof(CompNodeRecorderImpl)>
impl_storage[MAX_NR_COMP_NODE];
size_t nr_used_impl_storage = 0;
std::unordered_map<CompNode::LocatorPairHashKey,
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>,
CompNode::LocatorPairHashKey::Hash> locator2impl;
std::unordered_map<
CompNode::LocatorPairHashKey,
std::unique_ptr<CompNodeRecorderImpl, CompNodeRecorderImplDeleter>,
CompNode::LocatorPairHashKey::Hash>
locator2impl;
ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> physical2queue;
std::unordered_map<CompNode::LocatorPairHashKey,
std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>,
CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread;
std::unordered_map<
CompNode::LocatorPairHashKey,
std::unique_ptr<CompNodeRecorderImpl, CompNodeRecorderImplDeleter>,
CompNode::LocatorPairHashKey::Hash>
locator2impl_multi_thread;
ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>>
physical2queue_multithead;
};
CpuCompNode::Pool* CpuCompNode::sm_pool;
Spinlock CpuCompNode::sm_pool_mtx;
void CpuCompNode::foreach(thin_function<void(CompNode)> callback) {
void CpuCompNode::foreach (thin_function<void(CompNode)> callback) {
if (!sm_pool)
return;
for (size_t i = 0; ; ++ i) {
for (size_t i = 0;; ++i) {
CompNode cur;
{
MGB_LOCK_GUARD(sm_pool->mtx);
if (i >= sm_pool->nr_used_impl_storage)
return;
cur = make_comp_node_from_impl(
reinterpret_cast<CpuCompNodeImpl*>(
&sm_pool->impl_storage[i]));
reinterpret_cast<CompNodeRecorderImpl*>(
&sm_pool->impl_storage[i]));
}
callback(cur);
}
@@ -781,7 +903,7 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
// use static storage so object can be safely accessed even after
// global finalize
static std::aligned_storage_t<sizeof(Pool), alignof(Pool)> storage;
sm_pool = new(&storage) Pool;
sm_pool = new (&storage) Pool;
}
}
mgb_assert(locator.device >= 0 ||
@@ -800,23 +922,22 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
locator_logical.type == CompNode::DeviceType::MULTITHREAD);
}
if (locator.type == DeviceType::CPU) {
auto &&pqueue_weak =
sm_pool->physical2queue[{locator.device, locator.stream}];
auto&& pqueue_weak =
sm_pool->physical2queue[{locator.device, locator.stream}];
auto pqueue = pqueue_weak.lock();
if (!pqueue) {
pqueue = std::make_shared<WorkerQueue>(locator);
pqueue_weak = pqueue;
}
auto&& pimpl = sm_pool->locator2impl[{locator,
locator_logical}];
auto&& pimpl = sm_pool->locator2impl[{locator, locator_logical}];
if (!pimpl) {
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
"too many cpu comp nodes; max %d allowed",
Pool::MAX_NR_COMP_NODE);
pimpl.reset(new (
&sm_pool->impl_storage[sm_pool->nr_used_impl_storage++])
CpuCompNodeImpl{locator, locator_logical,
pqueue});
CompNodeRecorderImpl{locator, locator_logical,
pqueue});
}
log_comp_node_created(locator, locator_logical);
return pimpl.get();
@@ -829,16 +950,16 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
pqueue = std::make_shared<WorkerQueue>(locator);
pqueue_weak = pqueue;
}
auto&& pimpl = sm_pool->locator2impl_multi_thread[{
locator, locator_logical}];
auto&& pimpl =
sm_pool->locator2impl_multi_thread[{locator, locator_logical}];
if (!pimpl) {
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
"too many cpu multithread comp nodes; max %d allowed",
Pool::MAX_NR_COMP_NODE);
pimpl.reset(new (
&sm_pool->impl_storage[sm_pool->nr_used_impl_storage++])
CpuCompNodeImpl{locator, locator_logical,
pqueue});
CompNodeRecorderImpl{locator, locator_logical,
pqueue});
}
log_comp_node_created(locator, locator_logical);
return pimpl.get();
@@ -850,25 +971,12 @@ void CpuCompNode::sync_all() {
return;
MGB_LOCK_GUARD(sm_pool->mtx);
for (auto &&i: sm_pool->locator2impl)
for (auto&& i : sm_pool->locator2impl)
i.second->sync();
for (auto&& i : sm_pool->locator2impl_multi_thread)
i.second->sync();
}
bool CpuCompNode::CompNodeImpl::check_global_finalized(const char* reason) {
MGB_MARK_USED_VAR(reason);
if (this != sm_default_cpu_comp_node_ptr && !sm_pool) {
static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT;
if (!warn_printed.test_and_set()) {
mgb_log_debug("cpu comp node method called after global finalize: "
"reason=%s", reason);
}
return true;
}
return false;
}
/* ======================== CompNode methods ======================== */
// The CompNode returned by default_cpu() is different from the one
// produced by CompNode::load("cpu:default")
@@ -878,9 +986,7 @@ bool CpuCompNode::CompNodeImpl::check_global_finalized(const char* reason) {
// CpuCompNode::Pool
CompNode CompNode::default_cpu() {
static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}};
static auto empty_queue =
std::make_shared<CpuCompNode::WorkerQueue>(locator);
static CpuCompNodeImpl impl{locator, locator, empty_queue};
static CompNodeNoRecorderImpl impl{locator, locator};
return &impl;
}
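After this change, only comp nodes obtained through CompNode::load() can create a sequence recorder; a short sketch of the user-visible difference (assuming the public CompNode API as used elsewhere in the tree; the nullptr graph argument is illustrative):

#include "megbrain/comp_node.h"
using namespace mgb;

void seq_record_demo() {
    CompNode cn_default = CompNode::default_cpu();        // CompNodeNoRecorderImpl
    CompNode cn_loaded  = CompNode::load("cpu:default");  // CompNodeRecorderImpl

    // OK: recorder-capable implementation
    auto rec = cn_loaded.create_seq_recorder(nullptr);

    // would trigger mgb_assert(false, "default_cpu has no ability to record"):
    // cn_default.create_seq_recorder(nullptr);
}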
@@ -890,22 +996,20 @@ bool CompNode::enable_affinity_for_cpu(bool flag) {
return old;
}
/* ======================== EventImpl ======================== */
double CpuCompNode::CpuDispatchableBase::EventImpl::do_elapsed_time_until(
EventImplHelper &end) {
auto &&f1 = static_cast<EventImpl&>(end).m_prev_finish_time;
EventImplHelper& end) {
auto&& f1 = static_cast<EventImpl&>(end).m_prev_finish_time;
return m_prev_finish_time.time_until_secs(f1);
}
#if MGB_HAVE_THREAD
void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(
Impl *cn_impl) {
Impl* cn_impl) {
{
auto locator = m_comp_node_impl->locator();
if (locator.device == Locator::DEVICE_CPU_DEFAULT &&
!static_cast<CpuCompNode::CompNodeImpl*>(m_comp_node_impl)
!static_cast<CpuCompNode::CompNodeRecorderImpl*>(m_comp_node_impl)
->cur_recorder()) {
auto v0 = m_record_nr_req.load(std::memory_order_relaxed),
v1 = m_record_nr_finish.load(std::memory_order_relaxed);
@@ -934,14 +1038,14 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(
mgb_throw(MegBrainError,
"Atlas comp_node used but MGB_ATLAS not enabled");
#endif
} else if (cn_impl->env().property().type == CompNode::DeviceType::CAMBRICON) {
} else if (cn_impl->env().property().type ==
CompNode::DeviceType::CAMBRICON) {
#if MGB_CAMBRICON
return m_comp_node_impl->sync();
#else
mgb_throw(MegBrainError,
"Cambricon comp_node used but MGB_CAMBRICON not enabled");
#endif
}
auto version = m_record_nr_req.load(std::memory_order_relaxed);
@@ -991,14 +1095,15 @@ bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() {
}
void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
for (size_t i = 0, it = SCQueueSynchronizer::get_default_max_spin() / 20; i < it; ++i) {
for (size_t i = 0, it = SCQueueSynchronizer::get_default_max_spin() / 20;
i < it; ++i) {
if (finished()) {
return;
}
}
m_dev_wait_nr_waiter.fetch_add(1, std::memory_order_release);
for (; ; ) {
for (;;) {
std::unique_lock<std::mutex> lock{m_dev_wait_mtx};
if (finished()) {
break;
@@ -1011,23 +1116,23 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept {
auto check_all_finished = [this]() {
return do_finished() &&
!m_dev_wait_nr_waiter.load(std::memory_order_acquire);
!m_dev_wait_nr_waiter.load(std::memory_order_acquire);
};
if (!check_all_finished()) {
mgb_log_debug("event %p has unfinished callbacks when destructed; "
"waiting ...", this);
mgb_log_debug(
"event %p has unfinished callbacks when destructed; "
"waiting ...",
this);
while (!check_all_finished()) {
std::this_thread::yield();
}
}
}
#else // MGB_HAVE_THREAD
#else // MGB_HAVE_THREAD
void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
}
void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {}
void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl*) {
}
void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl*) {}
void CpuCompNode::CpuDispatchableBase::EventImpl::do_record() {
if (m_create_flags & Flags::NEED_TIMER) {
@@ -1035,8 +1140,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_record() {
}
}
void CpuCompNode::CpuDispatchableBase::EventImpl::on_finish() {
}
void CpuCompNode::CpuDispatchableBase::EventImpl::on_finish() {}
bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() {
return true;
@@ -1046,5 +1150,4 @@ CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept = default;
#endif // MGB_HAVE_THREAD
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -54,7 +54,9 @@ namespace mgb {
void add_callback(Task&& task) override;
};
class CompNodeImpl;
class CompNodeBaseImpl;
class CompNodeNoRecorderImpl;
class CompNodeRecorderImpl;
static void foreach(thin_function<void(CompNode)> callback);
static void finalize();
......
@@ -100,6 +100,26 @@ void run_comp_seq_rec_basic_level2(CompNode cn) {
MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3) << "iter " << iter;
}
ASSERT_EQ(executed.size(), 2u);
//! test default_cpu with record level 2
{
HostTensorND hz;
graph = ComputingGraph::make();
x = opr::Host2DeviceCopy::make(*graph, host_x);
y = opr::Host2DeviceCopy::make(*graph, host_y);
z = opr::ConvBias::make(x, y, param);
z = opr::GetVarShape::make(z);
graph->options().comp_node_seq_record_level = 2;
graph->options().var_sanity_check_first_run = false;
auto func = graph->compile({make_callback_copy(z, hz, true)});
ComputingGraph::assert_destroy(graph);
func->execute();
ASSERT_TRUE(hz.comp_node() == cn);
ASSERT_EQ(hz.ptr<int>()[0], 3);
ASSERT_EQ(hz.ptr<int>()[1], 6);
ASSERT_EQ(hz.ptr<int>()[2], 8);
ASSERT_EQ(hz.ptr<int>()[3], 6);
}
}
void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) {
......