Commit fc8b501b authored by Megvii Engine Team

refactor(mgb/core): refactor cpu compnode so that default cpu has no ability to record

GitOrigin-RevId: 7de4771476e87d3ed11cffbad4c02741591ef9a2
Parent b7176069
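The change below splits the old `CpuCompNode::CompNodeImpl` into a `CompNodeBaseImpl` base class and two final subclasses: `CompNodeNoRecorderImpl`, which backs `CompNode::default_cpu()` and refuses to record, and `CompNodeRecorderImpl`, which backs comp nodes created through `CompNode::load()` and keeps the thread-local recorder machinery. A minimal standalone sketch of the resulting shape (simplified, with hypothetical stand-in types — not the actual MegBrain code):

#include <cassert>
#include <memory>

struct SeqRecorder {};  // stand-in for CpuCompNode::SeqRecorderImpl

class CompNodeBase {
public:
    virtual ~CompNodeBase() = default;
    // Every comp node must say which recorder (if any) is active.
    virtual SeqRecorder* cur_recorder() const = 0;
    virtual std::unique_ptr<SeqRecorder> create_seq_recorder() = 0;
};

// Backs default_cpu(): recording is simply unavailable.
class NoRecorderImpl final : public CompNodeBase {
public:
    SeqRecorder* cur_recorder() const override { return nullptr; }
    std::unique_ptr<SeqRecorder> create_seq_recorder() override {
        assert(!"default_cpu has no ability to record");
        return nullptr;
    }
};

// Backs CompNode::load("cpuN") / multithread nodes: keeps a thread-local
// recorder so a kernel sequence can be captured and replayed.
class RecorderImpl final : public CompNodeBase {
    static thread_local SeqRecorder* tl_recorder;

public:
    SeqRecorder* cur_recorder() const override { return tl_recorder; }
    std::unique_ptr<SeqRecorder> create_seq_recorder() override {
        auto rec = std::make_unique<SeqRecorder>();
        tl_recorder = rec.get();
        return rec;
    }
};
thread_local SeqRecorder* RecorderImpl::tl_recorder = nullptr;

int main() {
    RecorderImpl rec_node;
    auto rec = rec_node.create_seq_recorder();
    assert(rec_node.cur_recorder() == rec.get());

    NoRecorderImpl plain;
    assert(plain.cur_recorder() == nullptr);
    // plain.create_seq_recorder() would assert, mirroring the new behavior.
}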
@@ -11,18 +11,18 @@
 #include "./comp_node.h"
+#include "megbrain/common.h"
 #include "megbrain/comp_node_env.h"
 #include "megbrain/system.h"
 #include "megbrain/utils/arith_helper.h"
 #include "megbrain/utils/thread.h"
-#include "megbrain/utils/timer.h"
 #include "megbrain/utils/thread_pool.h"
-#include "megbrain/common.h"
+#include "megbrain/utils/timer.h"
 
+#include <atomic>
 #include <condition_variable>
 #include <cstdint>
 #include <cstring>
-#include <atomic>
 #include <stdlib.h>
 
 #ifndef __APPLE__
@@ -44,8 +44,6 @@ struct TaskElem {
 };
 }  // anonymous namespace
 
-using CpuCompNodeImpl = CpuCompNode::CompNodeImpl;
-
 void CpuCompNode::CpuDispatchableBase::add_callback(Task&& task) {
     dispatch(std::move(task));
 }
@@ -110,7 +108,15 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
      * \brief use to check the all ther recording tasks are its self CompNode
      * related task, void hook other CompNode related task to the recorder.
      */
-    void check_the_same_comp_node(const CompNode& comp_node) const;
+    void check_the_same_comp_node(const CompNode& comp_node) const {
+        if (mgb_unlikely(comp_node.valid())) {
+            mgb_assert(m_record_compnode == comp_node,
+                       "CompNode %s can't hook in CompNode %s when recording\n",
+                       comp_node.locator().to_string().c_str(),
+                       m_record_compnode.locator().to_string().c_str());
+        }
+    }
 
 public:
     SeqRecorderImpl(SeqRecorderImpl** self_pointer, ThreadPool* thread_pool,
                     const CompNode& comp_node)
@@ -127,13 +133,13 @@ public:
         }
     }
 
     void enter_fake_exec(const CompNode& comp_node) override {
         check_the_same_comp_node(comp_node);
         mgb_assert(!m_stopped && !m_fake_exec);
         m_fake_exec = true;
     }
 
     void exit_fake_exec(const CompNode& comp_node) override {
         check_the_same_comp_node(comp_node);
         mgb_assert(!m_stopped && m_fake_exec);
         mgb_assert(m_tasks.empty());
@@ -165,9 +171,9 @@ public:
                 m_thread_pool->add_task(i);
             }
             m_thread_pool->deactive();
-        }else{
+        } else {
             for (auto&& task : m_tasks) {
-                for(size_t i=0; i<task.nr_parallelism;i++){
+                for (size_t i = 0; i < task.nr_parallelism; i++) {
                     task.task(i, 0);
                 }
             }
@@ -236,273 +242,113 @@ public:
     ThreadPool* get_thread_pool() { return m_thread_pool; }
 };
 
-class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-
-    //! used during comp node seq rec
-    class CompSeqRecEventImpl;
-    class CpuEventImpl;
-
-//! TODO: because the x-code bug, see
-//! https://github.com/tensorflow/tensorflow/issues/18356
-//! thread local is no support on IOS,
-//! When update x-xode, this code should be deleted
-#if !defined(IOS) && MGB_HAVE_THREAD
-    static thread_local SeqRecorderImpl* sm_cur_recorder;
-#else
-    SeqRecorderImpl* sm_cur_recorder = nullptr;
-#endif
-
-    std::shared_ptr<WorkerQueue> m_worker_queue;
+using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl;
+using CompNodeNoRecorderImpl = CpuCompNode::CompNodeNoRecorderImpl;
+using CompNodeRecorderImpl = CpuCompNode::CompNodeRecorderImpl;
+
+//! ==================== CompNodeBaseImpl ======================
+class CpuCompNode::CompNodeBaseImpl : public CpuDispatchableBase {
+protected:
     Locator m_locator, m_locator_logical;
-    std::unique_ptr<ThreadPool> m_thread_pool;
-
-    //! ptr to default cpu, only used by check_global_finalized
-    static CpuCompNodeImpl *sm_default_cpu_comp_node_ptr;
-
-    //! return whether global finalized, and print warning in such case
-    inline bool check_global_finalized(const char* reason);
-
-    static void static_free_device(ImplBase* self, void* ptr) {
-        static_cast<CompNodeImpl*>(self)->free_device(ptr);
-    }
-
-    static void static_free_host(ImplBase* self, void* ptr) {
-        static_cast<CompNodeImpl*>(self)->free_host(ptr);
-    }
 
 public:
-    CompNodeImpl(const Locator& locator, const Locator& locator_logical,
-            const std::shared_ptr<WorkerQueue>& worker_queue);
-
-    ~CompNodeImpl() {
-        if (sm_cur_recorder) {
-            sm_cur_recorder->stop();
-        }
-        if (m_worker_queue) {
-            // synchronize before fini
-            m_worker_queue->wait_all_task_finish();
-        }
-        m_env.fini();
-        if (m_worker_queue) {
-            // wait for new kernels dispatched in fini() (like free_device())
-            m_worker_queue->wait_all_task_finish();
-        }
-        if (this == sm_default_cpu_comp_node_ptr) {
-            // This should only happen in global library .fini. We clear
-            // sm_default_cpu_comp_node_ptr so check_global_finalized() can
-            // work correctly
-            sm_default_cpu_comp_node_ptr = nullptr;
-        }
-    }
-
-    ThreadPool* get_thread_pool() const { return m_thread_pool.get(); }
+    CompNodeBaseImpl(const Locator& locator, const Locator& locator_logical,
+                     free_func_t fd, free_func_t fh)
+            : CpuDispatchableBase(fd, fh),
+              m_locator(locator),
+              m_locator_logical(locator_logical) {}
+
+    virtual ~CompNodeBaseImpl() {}
 
     void* mgb_aligned_alloc(size_t size) {
         auto alignment = get_mem_addr_alignment();
 #ifdef WIN32
         return _aligned_malloc(size, alignment);
 #elif defined(__ANDROID__) || defined(ANDROID)
         return memalign(alignment, size);
 #else
-        void *ptr = nullptr;
+        void* ptr = nullptr;
         auto err = posix_memalign(&ptr, alignment, size);
-        mgb_assert(!err, "failed to malloc %zubytes with align %zu",
-                size, alignment);
+        mgb_assert(!err, "failed to malloc %zubytes with align %zu", size,
+                   alignment);
         return ptr;
 #endif
     }
 
     static void mgb_aligned_free(void* ptr) {
 #ifdef WIN32
         _aligned_free(ptr);
 #else
         ::free(ptr);
 #endif
     }
 
-    void* alloc_device(size_t size) override {
-        if (sm_cur_recorder) {
-            sm_cur_recorder->on_alloc(this);
-        }
-        return mgb_aligned_alloc(size);
-    }
-
-    void free_device(void *ptr) {
-        if (sm_cur_recorder || check_global_finalized("free_device()")) {
-            mgb_aligned_free(ptr);
-            if (sm_cur_recorder) {
-                sm_cur_recorder->on_free(this);
-            }
-            return;
-        } else {
-            auto do_free = [ptr]() {
-                mgb_aligned_free(ptr);
-            };
-            m_env.cpu_env().dispatch(do_free);
-        }
-    }
-
-    void *alloc_host(size_t size) override {
-        if (m_worker_queue) {
-            m_worker_queue->check_exception();
-        }
-        return mgb_aligned_alloc(size);
-    }
-
-    void free_host(void *ptr) {
-        if (check_global_finalized("free_host()")) {
-            mgb_aligned_free(ptr);
-            return;
-        }
-        if (m_worker_queue) {
-            m_worker_queue->check_exception();
-        }
-        return mgb_aligned_free(ptr);
-    }
-
-    void copy_to_host(void *host_ptr,
-            const void *device_ptr, size_t size) override {
-        if (m_worker_queue) {
-            m_worker_queue->check_exception();
-        }
-        // use lambda capture to avoid memory allocation in std::bind
-        auto do_copy = [host_ptr, device_ptr, size]() {
-            std::memcpy(host_ptr, device_ptr, size);
-        };
-        m_env.cpu_env().dispatch(do_copy);
-    }
-
-    void copy_to_device(void *device_ptr,
-            const void *host_ptr, size_t size) override {
-        if (m_worker_queue) {
-            m_worker_queue->check_exception();
-        }
-        // use lambda capture to avoid memory allocation in std::bind
-        auto do_copy = [device_ptr, host_ptr, size]() {
-            std::memcpy(device_ptr, host_ptr, size);
-        };
-        m_env.cpu_env().dispatch(do_copy);
-    }
-
-    void peer_copy_to(
-            Impl *dest_impl, void *dest,
-            const void *src, size_t size) override {
-        if (!dest_impl->same_type<CpuCompNode::CompNodeImpl>()) {
-            if (dest_impl->env().property().type == DeviceType::ATLAS) {
-#if MGB_ATLAS
-                dest_impl->copy_to_device(dest, src, size);
-                return;
-#else
-                mgb_throw(MegBrainError,
-                          "Atlas comp_node used but "
-                          "MGB_ATLAS not enabled");
-#endif
-            } else if (dest_impl->env().property().type ==
-                       DeviceType::CAMBRICON) {
-#if MGB_CAMBRICON
-                dest_impl->copy_to_device(dest, src, size);
-                return;
-#else
-                mgb_throw(MegBrainError,
-                          "Cambricon comp_node used but "
-                          "MGB_CAMBRICON not enabled");
-#endif
-            } else {
-                mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT,
-                           "currently only peer copy from default cpu comp "
-                           "nodes "
-                           "is implemented");
-            }
-        }
-        dest_impl->copy_to_device(dest, src, size);
-    }
-
-    size_t get_mem_addr_alignment() override {
-        return m_env.property().mem_alignment;
-    }
-
-    std::unique_ptr<Event> create_event(size_t flags) override;
-
-    void sync() override {
-        if (sm_cur_recorder) {
-            sm_cur_recorder->on_sync(this);
-        } else if (m_worker_queue) {
-            m_worker_queue->wait_all_task_finish();
-        }
-        if (m_thread_pool) {
-            m_thread_pool->deactive();
-        }
-    }
-
-    void dispatch(Task &&task) override {
-        m_env.cpu_env().dispatch(std::move(task));
-    }
-
-    MemNode mem_node() override {
-        // TODO: numa nodes
-        return get_host_cpu_mem_node();
-    }
-
-    std::pair<size_t, size_t> get_mem_status_bytes() override {
-        return sys::get_ram_status_bytes();
-    }
-
-    Locator locator() override {
-        return m_locator;
-    }
-
-    Locator locator_logical() override {
-        return m_locator_logical;
-    }
-
-    std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
-            cg::ComputingGraph*) override {
-        return std::make_unique<SeqRecorderImpl>(&sm_cur_recorder,
-                m_thread_pool.get(), this);
-    }
-
-    //! current sequence recorder of this thread
-#if !defined(IOS) && MGB_HAVE_THREAD
-    static SeqRecorderImpl* cur_recorder() { return sm_cur_recorder; }
-#else
-    SeqRecorderImpl* cur_recorder() { return sm_cur_recorder; }
-#endif
-
-    void add_callback(Task &&task) override {
-        if (!check_global_finalized("add_callback()")) {
-            CpuDispatchableBase::add_callback(std::move(task));
-        } else {
-            task();
-        }
-    }
+    void* alloc_device(size_t size) override { return mgb_aligned_alloc(size); }
+
+    void* alloc_host(size_t size) override { return mgb_aligned_alloc(size); }
+
+    void copy_to_host(void* host_ptr, const void* device_ptr,
+                      size_t size) override {
+        // use lambda capture to avoid memory allocation in std::bind
+        auto do_copy = [host_ptr, device_ptr, size]() {
+            std::memcpy(host_ptr, device_ptr, size);
+        };
+        m_env.cpu_env().dispatch(do_copy);
+    }
+
+    void copy_to_device(void* device_ptr, const void* host_ptr,
+                        size_t size) override {
+        // use lambda capture to avoid memory allocation in std::bind
+        auto do_copy = [device_ptr, host_ptr, size]() {
+            std::memcpy(device_ptr, host_ptr, size);
+        };
+        m_env.cpu_env().dispatch(do_copy);
+    }
+
+    void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
+                      size_t size) override {
+        dest_impl->copy_to_device(dest, src, size);
+    }
+
+    size_t get_mem_addr_alignment() override {
+        return m_env.property().mem_alignment;
+    }
+
+    void dispatch(Task&& task) override {
+        m_env.cpu_env().dispatch(std::move(task));
+    }
+
+    MemNode mem_node() override {
+        // TODO: numa nodes
+        return get_host_cpu_mem_node();
+    }
+
+    std::pair<size_t, size_t> get_mem_status_bytes() override {
+        return sys::get_ram_status_bytes();
+    }
+
+    Locator locator() override { return m_locator; }
+
+    Locator locator_logical() override { return m_locator_logical; }
+
+    void add_callback(Task&& task) override {
+        CpuDispatchableBase::add_callback(std::move(task));
+    }
+
+    virtual SeqRecorderImpl* cur_recorder() const = 0;
 };
-MGB_DYN_TYPE_OBJ_FINAL_IMPL(CpuCompNodeImpl);
-
-CpuCompNodeImpl* CpuCompNodeImpl::sm_default_cpu_comp_node_ptr;
-#if !defined(IOS) && MGB_HAVE_THREAD
-thread_local CpuCompNode::SeqRecorderImpl* CpuCompNodeImpl::sm_cur_recorder =
-        nullptr;
-#endif
-
-void CpuCompNode::SeqRecorderImpl::check_the_same_comp_node(
-        const CompNode& comp_node) const {
-    if (mgb_unlikely(comp_node.valid())) {
-        mgb_assert(m_record_compnode == comp_node,
-                   "CompNode %s can't hook in CompNode %s when recording\n",
-                   comp_node.locator().to_string().c_str(),
-                   m_record_compnode.locator().to_string().c_str());
-    }
-}
 
 //! implementation of CPUDispatcher that is passed to megdnn via megcore
-class CpuCompNode::WorkerQueue::DispatcherImpl final: public CPUDispatcher {
+class CpuCompNode::WorkerQueue::DispatcherImpl final : public CPUDispatcher {
     std::atomic_size_t m_nr_task{0};
     std::shared_ptr<WorkerQueue> m_queue;
-    CpuCompNode::CompNodeImpl* const m_comp_node;
+    //! DispatcherImpl only used by CompNodeRecorderImpl, but we still use
+    //! CompNodeBaseImpl* because of incomplete type error
+    CompNodeBaseImpl* const m_comp_node;
 
 public:
     DispatcherImpl(const std::shared_ptr<WorkerQueue>& queue,
-                   CpuCompNode::CompNodeImpl* comp_node)
+                   CompNodeBaseImpl* comp_node)
             : m_queue{queue}, m_comp_node{comp_node} {}
 
     void dispatch(Task&& task) override {
@@ -559,10 +405,12 @@ public:
 class InplaceCPUDispatcher final : public CPUDispatcher {
     std::atomic_size_t m_nr_task{0};
     ThreadPool* m_thread_pool = nullptr;
-    CpuCompNode::CompNodeImpl* const m_comp_node;
+    //! InplaceCPUDispatcher may used by both type of compnodes, so
+    //! m_comp_node's type should be base class.
+    CompNodeBaseImpl* const m_comp_node;
 
 public:
-    InplaceCPUDispatcher(CpuCompNode::CompNodeImpl* comp_node,
+    InplaceCPUDispatcher(CompNodeBaseImpl* comp_node,
                          ThreadPool* thread_pool = nullptr)
             : m_thread_pool(thread_pool), m_comp_node(comp_node) {}
@@ -585,9 +433,9 @@ public:
         } else if (m_thread_pool) {
             m_nr_task.fetch_add(1, std::memory_order_relaxed);
             m_thread_pool->add_task({task, parallelism});
-        }else{
+        } else {
             m_nr_task.fetch_add(1, std::memory_order_relaxed);
-            for(size_t i=0; i<parallelism;i++){
+            for (size_t i = 0; i < parallelism; i++) {
                 task(i, 0);
             }
         }
@@ -612,143 +460,417 @@ public:
             recorder->get_thread_pool()->set_affinity(affinity_cb);
         } else if (m_thread_pool) {
             m_thread_pool->set_affinity(affinity_cb);
-        }else{
+        } else {
             affinity_cb(0);
         }
     }
 };
 
-CpuCompNode::CompNodeImpl::CompNodeImpl(
-        const Locator& locator, const Locator& locator_logical,
-        const std::shared_ptr<WorkerQueue>& worker_queue)
-        : CpuDispatchableBase(static_free_device, static_free_host),
-          m_worker_queue{worker_queue},
-          m_locator(locator),
-          m_locator_logical(locator_logical) {
-    auto cn = make_comp_node_from_impl(this);
-    if (locator.type == DeviceType::MULTITHREAD) {
-        m_thread_pool = std::unique_ptr<ThreadPool>(
-                new ThreadPool(static_cast<size_t>(locator.nr_threads)));
-        mgb_assert(m_thread_pool, "ThradPool create failed");
-    }
-    if (locator.type == DeviceType::CPU) {
-        if (locator.device == Locator::DEVICE_CPU_DEFAULT) {
-            sm_default_cpu_comp_node_ptr = this;
-            m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(this)}, cn);
-        } else {
-            m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
-                                   m_worker_queue, this)},
-                           cn);
-        }
-    } else if (locator.type == DeviceType::MULTITHREAD) {
-        if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
-            m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
-                                   this, m_thread_pool.get())},
-                           cn);
-        } else {
-            m_worker_queue->attach_thread_pool(m_thread_pool.get());
-            m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
-                                   m_worker_queue, this)},
-                           cn);
-        }
-    }
-}
-
-class CpuCompNodeImpl::CompSeqRecEventImpl final
-        : public CpuDispatchableBase::EventImpl {
-    void do_record() override {
-        auto impl = static_cast<CpuCompNodeImpl*>(m_comp_node_impl);
-        if (auto rec = impl->cur_recorder()) {
-            auto callback = [this]() {
-                incr_nr_req();
-                on_finish();
-            };
-            rec->dispatch_allow_after_sync(callback, m_comp_node_impl);
-        } else {
-            EventImpl::do_record();
-        }
-    }
-
-    void do_device_wait_by(Impl*) override {
-        mgb_throw(MegBrainError,
-                  "device_wait() should not be called on events created during "
-                  "comp node seq recording");
-    }
-
-public:
-    using EventImpl::EventImpl;
-};
-
-class CpuCompNodeImpl::CpuEventImpl final
-        : public CpuDispatchableBase::EventImpl {
-#if MGB_HAVE_THREAD
-    void host_wait_cv() override {
-        CpuDispatchableBase::EventImpl::host_wait_cv();
-        auto thread_pool = static_cast<CpuCompNodeImpl*>(m_comp_node_impl)
-                                   ->get_thread_pool();
-        if (thread_pool) {
-            thread_pool->deactive();
-        }
-    }
-#endif
-
-public:
-    using EventImpl::EventImpl;
-};
-
-std::unique_ptr<CompNode::Event> CpuCompNodeImpl::create_event(size_t flags) {
-    if (m_worker_queue) {
-        m_worker_queue->check_exception();
-    }
-    if (sm_cur_recorder) {
-        return std::make_unique<CompSeqRecEventImpl>(this, flags);
-    } else {
-        return std::make_unique<CpuEventImpl>(this, flags);
-    }
-}
+//! ==================== CompNodeNoRecorderImpl ======================
+/**
+ * \note: CompNodeNoRecorderImpl will use most implements in base including:
+ * alloc_device, alloc_host, copy_to_host, copy_to_device, peer_copy_to,
+ * add_callback ...
+ */
+class CpuCompNode::CompNodeNoRecorderImpl final : public CompNodeBaseImpl {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+
+public:
+    //! ptr to default cpu, only used by check_global_finalized
+    static CompNodeNoRecorderImpl* sm_default_cpu_comp_node_ptr;
+
+    static void static_free_device(ImplBase* self, void* ptr) {
+        static_cast<CompNodeNoRecorderImpl*>(self)->free_device(ptr);
+    }
+
+    static void static_free_host(ImplBase* self, void* ptr) {
+        static_cast<CompNodeNoRecorderImpl*>(self)->free_host(ptr);
+    }
+
+    using CpuEventImpl = CpuDispatchableBase::EventImpl;
+
+    CompNodeNoRecorderImpl(const Locator& locator,
+                           const Locator& locator_logical)
+            : CompNodeBaseImpl(locator, locator_logical, static_free_device,
+                               static_free_host) {
+        mgb_assert(
+                locator.type == DeviceType::CPU &&
+                        locator.device == Locator::DEVICE_CPU_DEFAULT,
+                "CompNodeNoRecorder is only constructed On DEVICE_CPU_DEFAULT");
+        auto cn = make_comp_node_from_impl(this);
+        m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(this)}, cn);
+        sm_default_cpu_comp_node_ptr = this;
+    }
+
+    ~CompNodeNoRecorderImpl() {
+        m_env.fini();
+        sm_default_cpu_comp_node_ptr = nullptr;
+    }
+
+    //! return whether global finalized, and print warning in such case
+    bool check_global_finalized(const char* reason) {
+        MGB_MARK_USED_VAR(reason);
+        if (!sm_default_cpu_comp_node_ptr) {
+            static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT;
+            if (!warn_printed.test_and_set()) {
+                mgb_log_debug(
+                        "cpu comp node method called after global finalize: "
+                        "reason=%s",
+                        reason);
+            }
+            return true;
+        }
+        return false;
+    }
+
+    void free_device(void* ptr) {
+        if (check_global_finalized("free_device()")) {
+            CompNodeBaseImpl::mgb_aligned_free(ptr);
+            return;
+        } else {
+            auto do_free = [ptr]() { CompNodeBaseImpl::mgb_aligned_free(ptr); };
+            m_env.cpu_env().dispatch(do_free);
+        }
+    }
+
+    void free_host(void* ptr) {
+        check_global_finalized("free_host()");
+        return CompNodeBaseImpl::mgb_aligned_free(ptr);
+    }
+
+    std::unique_ptr<Event> create_event(size_t flags) override {
+        return std::make_unique<CpuEventImpl>(this, flags);
+    }
+
+    void sync() override {}
+
+    std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
+            cg::ComputingGraph*) override {
+        mgb_assert(false, "default_cpu has no ability to record");
+        return nullptr;
+    }
+
+    SeqRecorderImpl* cur_recorder() const override { return nullptr; }
+};
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeNoRecorderImpl);
+CompNodeNoRecorderImpl* CompNodeNoRecorderImpl::sm_default_cpu_comp_node_ptr =
+        nullptr;
+
+//! ==================== CompNodeRecorderImpl ======================
+class CpuCompNode::CompNodeRecorderImpl final : public CompNodeBaseImpl {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+    std::unique_ptr<ThreadPool> m_thread_pool;
+    std::shared_ptr<WorkerQueue> m_worker_queue;
+
+    //! used during comp node seq rec
+    class CompSeqRecEventImpl final : public CpuDispatchableBase::EventImpl {
+        void do_record() override {
+            auto impl = static_cast<CompNodeRecorderImpl*>(m_comp_node_impl);
+            if (auto rec = impl->cur_recorder()) {
+                auto callback = [this]() {
+                    incr_nr_req();
+                    on_finish();
+                };
+                rec->dispatch_allow_after_sync(callback, m_comp_node_impl);
+            } else {
+                EventImpl::do_record();
+            }
+        }
+
+        void do_device_wait_by(Impl*) override {
+            mgb_throw(MegBrainError,
+                      "device_wait() should not be called on events created "
+                      "during "
+                      "comp node seq recording");
+        }
+
+    public:
+        using EventImpl::EventImpl;
+    };
+
+    class CpuEventImpl final : public CpuDispatchableBase::EventImpl {
+#if MGB_HAVE_THREAD
+        void host_wait_cv() override {
+            CpuDispatchableBase::EventImpl::host_wait_cv();
+            auto thread_pool =
+                    static_cast<CompNodeRecorderImpl*>(m_comp_node_impl)
+                            ->get_thread_pool();
+            if (thread_pool) {
+                thread_pool->deactive();
+            }
+        }
+#endif
+
+    public:
+        using EventImpl::EventImpl;
+    };
+
+//! TODO: because the x-code bug, see
+//! https://github.com/tensorflow/tensorflow/issues/18356
+//! thread local is no support on IOS,
+//! When update x-xode, this code should be deleted
+#if !defined(IOS) && MGB_HAVE_THREAD
+    static thread_local SeqRecorderImpl* sm_cur_recorder;
+#else
+    SeqRecorderImpl* sm_cur_recorder = nullptr;
+#endif
+
+public:
+    static void static_free_device(ImplBase* self, void* ptr) {
+        static_cast<CompNodeRecorderImpl*>(self)->free_device(ptr);
+    }
+
+    static void static_free_host(ImplBase* self, void* ptr) {
+        static_cast<CompNodeRecorderImpl*>(self)->free_host(ptr);
+    }
+
+    CompNodeRecorderImpl(const Locator& locator, const Locator& locator_logical,
+                         const std::shared_ptr<WorkerQueue>& worker_queue)
+            : CompNodeBaseImpl(locator, locator_logical, static_free_device,
+                               static_free_host),
+              m_worker_queue(worker_queue) {
+        auto cn = make_comp_node_from_impl(this);
+        if (locator.type == DeviceType::MULTITHREAD) {
+            m_thread_pool = std::unique_ptr<ThreadPool>(
+                    new ThreadPool(static_cast<size_t>(locator.nr_threads)));
+            mgb_assert(m_thread_pool, "ThradPool create failed");
+        }
+        if (locator.type == DeviceType::CPU) {
+            if (locator.device == Locator::DEVICE_CPU_DEFAULT) {
+                m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(this)},
+                               cn);
+            } else {
+                m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
+                                       m_worker_queue, this)},
+                               cn);
+            }
+        } else if (locator.type == DeviceType::MULTITHREAD) {
+            if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
+                m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
+                                       this, m_thread_pool.get())},
+                               cn);
+            } else {
+                m_worker_queue->attach_thread_pool(m_thread_pool.get());
+                m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
+                                       m_worker_queue, this)},
+                               cn);
+            }
+        }
+    }
+
+    ~CompNodeRecorderImpl() {
+        if (sm_cur_recorder) {
+            sm_cur_recorder->stop();
+        }
+        if (m_worker_queue) {
+            // synchronize before fini
+            m_worker_queue->wait_all_task_finish();
+        }
+        m_env.fini();
+        if (m_worker_queue) {
+            // wait for new kernels dispatched in fini() (like free_device())
+            m_worker_queue->wait_all_task_finish();
+        }
+    }
+
+    ThreadPool* get_thread_pool() const { return m_thread_pool.get(); }
+
+    //! return whether global finalized, and print warning in such case
+    bool check_global_finalized(const char* reason) {
+        MGB_MARK_USED_VAR(reason);
+        if (!sm_pool) {
+            static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT;
+            if (!warn_printed.test_and_set()) {
+                mgb_log_debug(
+                        "cpu comp node method called after global finalize: "
+                        "reason=%s",
+                        reason);
+            }
+            return true;
+        }
+        return false;
+    }
+
+    void* alloc_device(size_t size) override {
+        if (sm_cur_recorder) {
+            sm_cur_recorder->on_alloc(this);
+        }
+        return CompNodeBaseImpl::alloc_device(size);
+    }
+
+    void free_device(void* ptr) {
+        if (sm_cur_recorder || check_global_finalized("free_device()")) {
+            CompNodeBaseImpl::mgb_aligned_free(ptr);
+            if (sm_cur_recorder) {
+                sm_cur_recorder->on_free(this);
+            }
+            return;
+        } else {
+            auto do_free = [ptr]() { CompNodeBaseImpl::mgb_aligned_free(ptr); };
+            m_env.cpu_env().dispatch(do_free);
+        }
+    }
+
+    void* alloc_host(size_t size) override {
+        if (m_worker_queue) {
+            m_worker_queue->check_exception();
+        }
+        return CompNodeBaseImpl::alloc_host(size);
+    }
+
+    void free_host(void* ptr) {
+        if (check_global_finalized("free_host()")) {
+            CompNodeBaseImpl::mgb_aligned_free(ptr);
+            return;
+        }
+        if (m_worker_queue) {
+            m_worker_queue->check_exception();
+        }
+        CompNodeBaseImpl::mgb_aligned_free(ptr);
+    }
+
+    void copy_to_host(void* host_ptr, const void* device_ptr,
+                      size_t size) override {
+        if (m_worker_queue) {
+            m_worker_queue->check_exception();
+        }
+        CompNodeBaseImpl::copy_to_host(host_ptr, device_ptr, size);
+    }
+
+    void copy_to_device(void* device_ptr, const void* host_ptr,
+                        size_t size) override {
+        if (m_worker_queue) {
+            m_worker_queue->check_exception();
+        }
+        CompNodeBaseImpl::copy_to_device(device_ptr, host_ptr, size);
+    }
+
+    void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
+                      size_t size) override {
+        //! copy to default_cpu
+        if (dest_impl->same_type<CpuCompNode::CompNodeNoRecorderImpl>()) {
+            CompNodeBaseImpl::peer_copy_to(dest_impl, dest, src, size);
+            return;
+        }
+
+        if (!dest_impl->same_type<CpuCompNode::CompNodeRecorderImpl>()) {
+            if (dest_impl->env().property().type == DeviceType::ATLAS) {
+#if MGB_ATLAS
+                dest_impl->copy_to_device(dest, src, size);
+                return;
+#else
+                mgb_throw(MegBrainError,
+                          "Atlas comp_node used but "
+                          "MGB_ATLAS not enabled");
+#endif
+            } else if (dest_impl->env().property().type ==
+                       DeviceType::CAMBRICON) {
+#if MGB_CAMBRICON
+                dest_impl->copy_to_device(dest, src, size);
+                return;
+#else
+                mgb_throw(MegBrainError,
+                          "Cambricon comp_node used but "
+                          "MGB_CAMBRICON not enabled");
+#endif
+            }
+            else {
+                mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT,
+                           "currently only peer copy from default cpu comp "
+                           "nodes "
+                           "is implemented");
+            }
+        }
+        dest_impl->copy_to_device(dest, src, size);
+    }
+
+    std::unique_ptr<Event> create_event(size_t flags) override {
+        if (m_worker_queue) {
+            m_worker_queue->check_exception();
+        }
+        if (sm_cur_recorder) {
+            return std::make_unique<CompSeqRecEventImpl>(this, flags);
+        } else {
+            return std::make_unique<CpuEventImpl>(this, flags);
+        }
+    }
+
+    void sync() override {
+        if (sm_cur_recorder) {
+            sm_cur_recorder->on_sync(this);
+        } else if (m_worker_queue) {
+            m_worker_queue->wait_all_task_finish();
+        }
+        if (m_thread_pool) {
+            m_thread_pool->deactive();
+        }
+    }
+
+    std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
+            cg::ComputingGraph*) override {
+        return std::make_unique<SeqRecorderImpl>(&sm_cur_recorder,
+                                                 m_thread_pool.get(), this);
+    }
+
+    SeqRecorderImpl* cur_recorder() const override { return sm_cur_recorder; }
+
+    void add_callback(Task&& task) override {
+        if (!check_global_finalized("add_callback()")) {
+            CompNodeBaseImpl::add_callback(std::move(task));
+        } else {
+            task();
+        }
+    }
+};
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeRecorderImpl);
+#if !defined(IOS) && MGB_HAVE_THREAD
+thread_local CpuCompNode::SeqRecorderImpl*
+        CompNodeRecorderImpl::sm_cur_recorder = nullptr;
+#endif
 
 /* ======================== CpuCompNode ======================== */
 struct CpuCompNode::Pool {
     static constexpr int MAX_NR_COMP_NODE = 1024;
-    struct CpuCompNodeImplDeleter {
-        void operator ()(CpuCompNodeImpl *p) {
-            p->~CpuCompNodeImpl();
-        }
+    struct CompNodeRecorderImplDeleter {
+        void operator()(CompNodeRecorderImpl* p) { p->~CompNodeRecorderImpl(); }
     };
     std::recursive_mutex mtx;
     // use global memory pool to ensuare object memory accessible even after
     // global finalize
-    std::aligned_storage_t<sizeof(CpuCompNodeImpl), alignof(CpuCompNodeImpl)>
-            impl_storage[MAX_NR_COMP_NODE];
+    std::aligned_storage_t<sizeof(CompNodeRecorderImpl),
+                           alignof(CompNodeRecorderImpl)>
+            impl_storage[MAX_NR_COMP_NODE];
     size_t nr_used_impl_storage = 0;
 
-    std::unordered_map<CompNode::LocatorPairHashKey,
-            std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>,
-            CompNode::LocatorPairHashKey::Hash> locator2impl;
+    std::unordered_map<
+            CompNode::LocatorPairHashKey,
+            std::unique_ptr<CompNodeRecorderImpl, CompNodeRecorderImplDeleter>,
+            CompNode::LocatorPairHashKey::Hash>
+            locator2impl;
     ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> physical2queue;
-    std::unordered_map<CompNode::LocatorPairHashKey,
-            std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>,
-            CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread;
+    std::unordered_map<
+            CompNode::LocatorPairHashKey,
+            std::unique_ptr<CompNodeRecorderImpl, CompNodeRecorderImplDeleter>,
+            CompNode::LocatorPairHashKey::Hash>
+            locator2impl_multi_thread;
     ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>>
             physical2queue_multithead;
 };
 CpuCompNode::Pool* CpuCompNode::sm_pool;
 Spinlock CpuCompNode::sm_pool_mtx;
 
-void CpuCompNode::foreach(thin_function<void(CompNode)> callback) {
+void CpuCompNode::foreach (thin_function<void(CompNode)> callback) {
     if (!sm_pool)
         return;
 
-    for (size_t i = 0; ; ++ i) {
+    for (size_t i = 0;; ++i) {
         CompNode cur;
         {
             MGB_LOCK_GUARD(sm_pool->mtx);
             if (i >= sm_pool->nr_used_impl_storage)
                 return;
             cur = make_comp_node_from_impl(
                     reinterpret_cast<CompNodeRecorderImpl*>(
                             &sm_pool->impl_storage[i]));
         }
         callback(cur);
     }
@@ -781,7 +903,7 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
             // use static storage so object can be safely accessed even after
             // global finalize
             static std::aligned_storage_t<sizeof(Pool), alignof(Pool)> storage;
-            sm_pool = new(&storage) Pool;
+            sm_pool = new (&storage) Pool;
         }
     }
     mgb_assert(locator.device >= 0 ||
@@ -800,23 +922,22 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
                    locator_logical.type == CompNode::DeviceType::MULTITHREAD);
     }
     if (locator.type == DeviceType::CPU) {
-        auto &&pqueue_weak =
+        auto&& pqueue_weak =
                 sm_pool->physical2queue[{locator.device, locator.stream}];
         auto pqueue = pqueue_weak.lock();
         if (!pqueue) {
             pqueue = std::make_shared<WorkerQueue>(locator);
             pqueue_weak = pqueue;
         }
-        auto&& pimpl = sm_pool->locator2impl[{locator,
-                locator_logical}];
+        auto&& pimpl = sm_pool->locator2impl[{locator, locator_logical}];
         if (!pimpl) {
             mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
                        "too many cpu comp nodes; max %d allowed",
                        Pool::MAX_NR_COMP_NODE);
             pimpl.reset(new (
                     &sm_pool->impl_storage[sm_pool->nr_used_impl_storage++])
-                    CpuCompNodeImpl{locator, locator_logical,
-                                    pqueue});
+                                CompNodeRecorderImpl{locator, locator_logical,
+                                                     pqueue});
         }
         log_comp_node_created(locator, locator_logical);
         return pimpl.get();
@@ -829,16 +950,16 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
             pqueue = std::make_shared<WorkerQueue>(locator);
             pqueue_weak = pqueue;
         }
-        auto&& pimpl = sm_pool->locator2impl_multi_thread[{
-                locator, locator_logical}];
+        auto&& pimpl =
+                sm_pool->locator2impl_multi_thread[{locator, locator_logical}];
         if (!pimpl) {
             mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
                        "too many cpu multithread comp nodes; max %d allowed",
                        Pool::MAX_NR_COMP_NODE);
             pimpl.reset(new (
                     &sm_pool->impl_storage[sm_pool->nr_used_impl_storage++])
-                    CpuCompNodeImpl{locator, locator_logical,
-                                    pqueue});
+                                CompNodeRecorderImpl{locator, locator_logical,
+                                                     pqueue});
         }
         log_comp_node_created(locator, locator_logical);
         return pimpl.get();
@@ -850,25 +971,12 @@ void CpuCompNode::sync_all() {
         return;
 
     MGB_LOCK_GUARD(sm_pool->mtx);
-    for (auto &&i: sm_pool->locator2impl)
+    for (auto&& i : sm_pool->locator2impl)
         i.second->sync();
     for (auto&& i : sm_pool->locator2impl_multi_thread)
         i.second->sync();
 }
 
-bool CpuCompNode::CompNodeImpl::check_global_finalized(const char* reason) {
-    MGB_MARK_USED_VAR(reason);
-    if (this != sm_default_cpu_comp_node_ptr && !sm_pool) {
-        static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT;
-        if (!warn_printed.test_and_set()) {
-            mgb_log_debug("cpu comp node method called after global finalize: "
-                          "reason=%s", reason);
-        }
-        return true;
-    }
-    return false;
-}
-
 /* ======================== CompNode methods ======================== */
 // CompNode get by default_cpu() is different from the CompNode which is
 // produced by CompNode::load("cpu:default")
@@ -878,9 +986,7 @@
 // CpuCompNode::Pool
 CompNode CompNode::default_cpu() {
     static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}};
-    static auto empty_queue =
-            std::make_shared<CpuCompNode::WorkerQueue>(locator);
-    static CpuCompNodeImpl impl{locator, locator, empty_queue};
+    static CompNodeNoRecorderImpl impl{locator, locator};
     return &impl;
 }
@@ -890,22 +996,20 @@ bool CompNode::enable_affinity_for_cpu(bool flag) {
     return old;
 }
 
 /* ======================== EventImpl ======================== */
 double CpuCompNode::CpuDispatchableBase::EventImpl::do_elapsed_time_until(
-        EventImplHelper &end) {
-    auto &&f1 = static_cast<EventImpl&>(end).m_prev_finish_time;
+        EventImplHelper& end) {
+    auto&& f1 = static_cast<EventImpl&>(end).m_prev_finish_time;
     return m_prev_finish_time.time_until_secs(f1);
 }
 
 #if MGB_HAVE_THREAD
 void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(
-        Impl *cn_impl) {
+        Impl* cn_impl) {
     {
         auto locator = m_comp_node_impl->locator();
         if (locator.device == Locator::DEVICE_CPU_DEFAULT &&
-            !static_cast<CpuCompNode::CompNodeImpl*>(m_comp_node_impl)
+            !static_cast<CpuCompNode::CompNodeRecorderImpl*>(m_comp_node_impl)
                      ->cur_recorder()) {
             auto v0 = m_record_nr_req.load(std::memory_order_relaxed),
                  v1 = m_record_nr_finish.load(std::memory_order_relaxed);
@@ -934,14 +1038,14 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(
             mgb_throw(MegBrainError,
                       "Atlas comp_node used but MGB_ATLAS not enabled");
 #endif
-    } else if (cn_impl->env().property().type == CompNode::DeviceType::CAMBRICON) {
+    } else if (cn_impl->env().property().type ==
+               CompNode::DeviceType::CAMBRICON) {
 #if MGB_CAMBRICON
         return m_comp_node_impl->sync();
 #else
         mgb_throw(MegBrainError,
                   "Cambricon comp_node used but MGB_CAMBRICON not enabled");
 #endif
     }
 
     auto version = m_record_nr_req.load(std::memory_order_relaxed);
@@ -991,14 +1095,15 @@ bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() {
 }
 
 void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
-    for (size_t i = 0, it = SCQueueSynchronizer::get_default_max_spin() / 20; i < it; ++i) {
+    for (size_t i = 0, it = SCQueueSynchronizer::get_default_max_spin() / 20;
+         i < it; ++i) {
         if (finished()) {
             return;
         }
     }
 
     m_dev_wait_nr_waiter.fetch_add(1, std::memory_order_release);
-    for (; ; ) {
+    for (;;) {
         std::unique_lock<std::mutex> lock{m_dev_wait_mtx};
         if (finished()) {
             break;
@@ -1011,23 +1116,23 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
 CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept {
     auto check_all_finished = [this]() {
         return do_finished() &&
                !m_dev_wait_nr_waiter.load(std::memory_order_acquire);
     };
     if (!check_all_finished()) {
-        mgb_log_debug("event %p has unfinished callbacks when destructed; "
-                      "waiting ...", this);
+        mgb_log_debug(
+                "event %p has unfinished callbacks when destructed; "
+                "waiting ...",
+                this);
         while (!check_all_finished()) {
             std::this_thread::yield();
         }
     }
 }
 
 #else  // MGB_HAVE_THREAD
 
-void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
-}
+void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {}
 
-void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl*) {
-}
+void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl*) {}
 
 void CpuCompNode::CpuDispatchableBase::EventImpl::do_record() {
     if (m_create_flags & Flags::NEED_TIMER) {
@@ -1035,8 +1140,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_record() {
     }
 }
 
-void CpuCompNode::CpuDispatchableBase::EventImpl::on_finish() {
-}
+void CpuCompNode::CpuDispatchableBase::EventImpl::on_finish() {}
 
 bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() {
     return true;
@@ -1046,5 +1150,4 @@ CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept = default;
 #endif  // MGB_HAVE_THREAD
 
-
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -54,7 +54,9 @@ namespace mgb {
     void add_callback(Task&& task) override;
 };
 
-class CompNodeImpl;
+class CompNodeBaseImpl;
+class CompNodeNoRecorderImpl;
+class CompNodeRecorderImpl;
 
 static void foreach(thin_function<void(CompNode)> callback);
 static void finalize();
......
@@ -100,6 +100,26 @@ void run_comp_seq_rec_basic_level2(CompNode cn) {
         MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3) << "iter " << iter;
     }
     ASSERT_EQ(executed.size(), 2u);
+
+    //! test default_cpu with record2
+    {
+        HostTensorND hz;
+        graph = ComputingGraph::make();
+        x = opr::Host2DeviceCopy::make(*graph, host_x);
+        y = opr::Host2DeviceCopy::make(*graph, host_y);
+        z = opr::ConvBias::make(x, y, param);
+        z = opr::GetVarShape::make(z);
+        graph->options().comp_node_seq_record_level = 2;
+        graph->options().var_sanity_check_first_run = false;
+        auto func = graph->compile({make_callback_copy(z, hz, true)});
+        ComputingGraph::assert_destroy(graph);
+        func->execute();
+        ASSERT_TRUE(hz.comp_node() == cn);
+        ASSERT_EQ(hz.ptr<int>()[0], 3);
+        ASSERT_EQ(hz.ptr<int>()[1], 6);
+        ASSERT_EQ(hz.ptr<int>()[2], 8);
+        ASSERT_EQ(hz.ptr<int>()[3], 6);
+    }
 }
 
 void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) {
......
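The new test above drives `comp_node_seq_record_level = 2`: the graph is compiled, destroyed via `assert_destroy`, and the recorded kernel sequence is replayed on `execute()`. A toy model of that record-and-replay idea (ours, not the MegBrain API; the real `SeqRecorderImpl` additionally manages thread pools, events, and fake-exec):

#include <cstdio>
#include <functional>
#include <vector>

// Toy recorder: during the first run the dispatched kernels are captured;
// afterwards the graph can be dropped and the captured sequence replayed.
struct SeqRecorder {
    std::vector<std::function<void()>> tasks;
    void dispatch(std::function<void()> t) { tasks.push_back(std::move(t)); }
    void replay() const {
        for (auto&& t : tasks) t();
    }
};

int main() {
    SeqRecorder rec;
    int x = 1, y = 2, z = 0;
    // "Compile + first run": kernels are recorded rather than re-dispatched.
    rec.dispatch([&] { z = x + y; });
    rec.dispatch([&] { z *= 2; });
    // The graph object could be destroyed here; only the recorder is needed.
    rec.replay();
    std::printf("z = %d\n", z);  // z = 6
    return 0;
}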