Commit 88574565 authored by Megvii Engine Team

fix(mgb/core): use a thread-local recorder to fix multiple threads using the same comp node with recording enabled

GitOrigin-RevId: 7d3daa866c114f77c312783ed7431cbaaddecdee
Parent 3246ee5e
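The core of the change: the comp node's single recorder pointer, previously guarded by a mutex, becomes a thread_local pointer, so several threads can each drive their own recorder on the same comp node without serializing on a lock. A minimal standalone sketch of that pattern (ToyRecorder and tls_cur are illustrative names, not MegEngine code):

    #include <cassert>
    #include <thread>
    #include <vector>

    // Toy stand-in for SeqRecorderImpl: the "current recorder" slot is
    // thread-local, so concurrent recorders on the same comp node object no
    // longer contend on a shared member pointer protected by a mutex.
    struct ToyRecorder {
        static thread_local ToyRecorder* tls_cur;
        ToyRecorder() {
            assert(!tls_cur);  // mirrors mgb_assert(!*m_self_pointer)
            tls_cur = this;
        }
        void stop() {
            assert(tls_cur == this);
            tls_cur = nullptr;  // no m_self_pointer_mtx->unlock() needed
        }
    };
    thread_local ToyRecorder* ToyRecorder::tls_cur = nullptr;

    int main() {
        std::vector<std::thread> threads;
        for (int i = 0; i < 4; ++i)
            threads.emplace_back([] {
                ToyRecorder rec;  // each thread installs its own recorder
                assert(ToyRecorder::tls_cur == &rec);
                rec.stop();
            });
        for (auto& t : threads)
            t.join();
    }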
@@ -102,17 +102,23 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
bool m_fake_exec = false, m_synchronized = false, m_stopped = false,
m_first_replay = true;
SeqRecorderImpl** const m_self_pointer;
-     std::mutex* const m_self_pointer_mtx;
std::vector<TaskElem> m_tasks;
ThreadPool* m_thread_pool = nullptr;
+     const CompNode m_record_compnode;
+     /*!
+      * \brief check that every recorded task belongs to this recorder's own
+      * CompNode, to avoid hooking tasks of other CompNodes into the recorder
+      */
+     void check_the_same_comp_node(const CompNode& comp_node) const;
public:
-     SeqRecorderImpl(SeqRecorderImpl** self_pointer,
-                     std::mutex* const self_pointer_mtx, ThreadPool* thread_pool)
+     SeqRecorderImpl(SeqRecorderImpl** self_pointer, ThreadPool* thread_pool,
+                     const CompNode& comp_node)
: m_self_pointer{self_pointer},
-               m_self_pointer_mtx{self_pointer_mtx},
-               m_thread_pool{thread_pool} {
+               m_thread_pool{thread_pool},
+               m_record_compnode{comp_node} {
mgb_assert(!*m_self_pointer);
*m_self_pointer = this;
}
@@ -123,23 +129,25 @@ public:
}
}
-     void enter_fake_exec() override {
+     void enter_fake_exec(const CompNode& comp_node) override {
+         check_the_same_comp_node(comp_node);
mgb_assert(!m_stopped && !m_fake_exec);
m_fake_exec = true;
}
-     void exit_fake_exec() override {
+     void exit_fake_exec(const CompNode& comp_node) override {
+         check_the_same_comp_node(comp_node);
mgb_assert(!m_stopped && m_fake_exec);
mgb_assert(m_tasks.empty());
m_fake_exec = false;
m_synchronized = false;
}
-     void stop() override {
+     void stop(const CompNode& comp_node = {}) override {
+         check_the_same_comp_node(comp_node);
mgb_assert(*m_self_pointer == this);
mgb_assert(!m_fake_exec);
*m_self_pointer = nullptr;
-         m_self_pointer_mtx->unlock();
m_stopped = true;
}
@@ -175,25 +183,32 @@ public:
});
}
-     void on_alloc() {
+     void on_alloc(const CompNode& comp_node) {
+         check_the_same_comp_node(comp_node);
mgb_assert(m_fake_exec,
"alloc is disallowed during comp node seq recording");
}
-     void on_free() {
+     void on_free(const CompNode& comp_node) {
+         check_the_same_comp_node(comp_node);
mgb_assert(m_fake_exec,
"free is disallowed during comp node seq recording");
}
-     void on_sync() { m_synchronized = true; }
+     void on_sync(const CompNode& comp_node) {
+         check_the_same_comp_node(comp_node);
+         m_synchronized = true;
+     }
-     void dispatch(Task&& task) {
+     void dispatch(Task&& task, const CompNode& comp_node) {
mgb_assert(!m_synchronized,
"no more tasks should be dispatched after synchronization");
auto kern = [task](size_t, size_t) { task(); };
-         dispatch_allow_after_sync({std::move(kern), static_cast<size_t>(1_z)});
+         dispatch_allow_after_sync({std::move(kern), static_cast<size_t>(1_z)},
+                                   comp_node);
}
-     void dispatch_allow_after_sync(Task&& task) {
+     void dispatch_allow_after_sync(Task&& task, const CompNode& comp_node) {
+         check_the_same_comp_node(comp_node);
mgb_assert(!m_stopped,
"dispatch should not be called after recording is stopped");
if (!m_fake_exec) {
@@ -201,151 +216,28 @@ public:
m_tasks.push_back({std::move(kern), static_cast<size_t>(1_z)});
}
}
-     void dispatch(TaskElem&& task_elem) {
+     void dispatch(TaskElem&& task_elem, const CompNode& comp_node) {
mgb_assert(!m_synchronized,
"no more tasks should be dispatched after synchronization");
-         dispatch_allow_after_sync(std::move(task_elem));
+         dispatch_allow_after_sync(std::move(task_elem), comp_node);
}
-     void dispatch_allow_after_sync(TaskElem&& task_elem) {
+     void dispatch_allow_after_sync(TaskElem&& task_elem,
+                                    const CompNode& comp_node) {
+         check_the_same_comp_node(comp_node);
mgb_assert(!m_stopped,
"dispatch should not be called after recording is stopped");
if (!m_fake_exec) {
m_tasks.push_back(task_elem);
}
}
-     size_t nr_threads() {
+     size_t nr_threads(const CompNode& comp_node) {
+         check_the_same_comp_node(comp_node);
return m_thread_pool ? m_thread_pool->nr_threads() : 1_z;
}
ThreadPool* get_thread_pool() { return m_thread_pool; }
};
- //! implementation of CPUDispatcher that is passed to megdnn via megcore
- class CpuCompNode::WorkerQueue::DispatcherImpl final : public CPUDispatcher {
-     std::atomic_size_t m_nr_task{0};
-     std::shared_ptr<WorkerQueue> m_queue;
-     SeqRecorderImpl** const m_cur_recorder;
- public:
-     DispatcherImpl(const std::shared_ptr<WorkerQueue>& queue,
-                    SeqRecorderImpl** recorder)
-             : m_queue{queue}, m_cur_recorder{recorder} {}
-     void dispatch(Task&& task) override {
-         if (*m_cur_recorder) {
-             (*m_cur_recorder)->dispatch(std::move(task));
-         } else {
-             m_nr_task.fetch_add(1, std::memory_order_relaxed);
-             auto kern = [task](size_t, size_t) { task(); };
-             m_queue->add_task({kern, static_cast<size_t>(1_z)});
-         }
-     }
-     void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
-         if (*m_cur_recorder) {
-             (*m_cur_recorder)->dispatch({std::move(task), parallelism});
-         } else {
-             m_nr_task.fetch_add(1, std::memory_order_relaxed);
-             m_queue->add_task({std::move(task), parallelism});
-         }
-     }
-     void sync() override {
-         if (*m_cur_recorder) {
-             (*m_cur_recorder)->on_sync();
-         } else {
-             m_queue->wait_all_task_finish();
-         }
-     }
-     size_t nr_threads() override {
-         if (*m_cur_recorder) {
-             return (*m_cur_recorder)->nr_threads();
-         } else {
-             return m_queue->nr_threads();
-         }
-     }
-     size_t get_nr_dispatched_tasks() const override {
-         return m_nr_task;
-     }
-     void set_affinity(AffinityCallBack&& affinity_cb) override {
-         auto thread_pool = m_queue->get_thread_pool();
-         if (thread_pool) {
-             thread_pool->set_affinity(affinity_cb);
-         } else {
-             auto affinity_run = [affinity_cb](size_t, size_t) {
-                 affinity_cb(0);
-             };
-             m_queue->add_task({affinity_run, 1_z});
-         }
-     }
- };
- //! implementation of InplaceCPUDispatcher
- class InplaceCPUDispatcher final : public CPUDispatcher {
-     std::atomic_size_t m_nr_task{0};
-     ThreadPool* m_thread_pool = nullptr;
-     CpuCompNode::SeqRecorderImpl** const m_cur_recorder;
- public:
-     InplaceCPUDispatcher(CpuCompNode::SeqRecorderImpl** recorder,
-                          ThreadPool* thread_pool = nullptr)
-             : m_thread_pool(thread_pool), m_cur_recorder(recorder) {}
-     void dispatch(Task&& task) override {
-         if (*m_cur_recorder) {
-             (*m_cur_recorder)->dispatch(std::move(task));
-         } else if (m_thread_pool) {
-             m_nr_task.fetch_add(1, std::memory_order_relaxed);
-             auto kern = [task](size_t, size_t) { task(); };
-             m_thread_pool->add_task({kern, static_cast<size_t>(1_z)});
-         } else {
-             m_nr_task.fetch_add(1, std::memory_order_relaxed);
-             task();
-         }
-     }
-     void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
-         if (*m_cur_recorder) {
-             (*m_cur_recorder)->dispatch({std::move(task), parallelism});
-         } else if (m_thread_pool) {
-             m_nr_task.fetch_add(1, std::memory_order_relaxed);
-             m_thread_pool->add_task({task, parallelism});
-         } else {
-             m_nr_task.fetch_add(1, std::memory_order_relaxed);
-             for (size_t i = 0; i < parallelism; i++) {
-                 task(i, 0);
-             }
-         }
-     }
-     size_t nr_threads() override {
-         return m_thread_pool ? m_thread_pool->nr_threads() : 1_z;
-     }
-     void sync() override {
-         if (*m_cur_recorder) {
-             (*m_cur_recorder)->on_sync();
-         } else if (m_thread_pool) {
-             m_thread_pool->deactive();
-         }
-     }
-     size_t get_nr_dispatched_tasks() const override { return m_nr_task; }
-     void set_affinity(AffinityCallBack&& affinity_cb) override {
-         if (*m_cur_recorder) {
-             (*m_cur_recorder)->get_thread_pool()->set_affinity(affinity_cb);
-         } else if (m_thread_pool) {
-             m_thread_pool->set_affinity(affinity_cb);
-         } else {
-             affinity_cb(0);
-         }
-     }
- };
class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
@@ -353,8 +245,7 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
class CompSeqRecEventImpl;
class CpuEventImpl;
-     SeqRecorderImpl* m_cur_recorder = nullptr;
-     std::mutex m_cur_recorder_mtx;
+     static thread_local SeqRecorderImpl* sm_cur_recorder;
std::shared_ptr<WorkerQueue> m_worker_queue;
Locator m_locator, m_locator_logical;
std::unique_ptr<ThreadPool> m_thread_pool;
@@ -375,49 +266,10 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
public:
CompNodeImpl(const Locator& locator, const Locator& locator_logical,
-                  const std::shared_ptr<WorkerQueue>& worker_queue)
-             : CpuDispatchableBase(static_free_device, static_free_host),
-               m_worker_queue{worker_queue},
-               m_locator(locator),
-               m_locator_logical(locator_logical) {
-         auto cn = make_comp_node_from_impl(this);
-         if (locator.type == DeviceType::MULTITHREAD) {
-             m_thread_pool = std::unique_ptr<ThreadPool>(new ThreadPool(
-                     static_cast<size_t>(locator.nr_threads)));
-             mgb_assert(m_thread_pool, "ThreadPool creation failed");
-         }
-         if (locator.type == DeviceType::CPU) {
-             if (locator.device == Locator::DEVICE_CPU_DEFAULT) {
-                 sm_default_cpu_comp_node_ptr = this;
-                 m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
-                                        &m_cur_recorder)},
-                                cn);
-             } else {
-                 m_env.init_cpu(
-                         {std::make_shared<WorkerQueue::DispatcherImpl>(
-                                 m_worker_queue, &m_cur_recorder)},
-                         cn);
-             }
-         } else if (locator.type == DeviceType::MULTITHREAD) {
-             if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
-                 m_env.init_cpu(
-                         {std::make_shared<InplaceCPUDispatcher>(
-                                 &m_cur_recorder, m_thread_pool.get())},
-                         cn);
-             } else {
-                 m_worker_queue->attach_thread_pool(m_thread_pool.get());
-                 m_env.init_cpu(
-                         {std::make_shared<WorkerQueue::DispatcherImpl>(
-                                 m_worker_queue, &m_cur_recorder)},
-                         cn);
-             }
-         }
-     }
+                  const std::shared_ptr<WorkerQueue>& worker_queue);
~CompNodeImpl() {
-         if (m_cur_recorder) {
-             m_cur_recorder->stop();
+         if (sm_cur_recorder) {
+             sm_cur_recorder->stop();
}
if (m_worker_queue) {
// synchronize before fini
@@ -462,17 +314,17 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
}
void* alloc_device(size_t size) override {
-         if (m_cur_recorder) {
-             m_cur_recorder->on_alloc();
+         if (sm_cur_recorder) {
+             sm_cur_recorder->on_alloc(this);
}
return mgb_aligned_alloc(size);
}
void free_device(void *ptr) {
-         if (m_cur_recorder || check_global_finalized("free_device()")) {
+         if (sm_cur_recorder || check_global_finalized("free_device()")) {
mgb_aligned_free(ptr);
-             if (m_cur_recorder) {
-                 m_cur_recorder->on_free();
+             if (sm_cur_recorder) {
+                 sm_cur_recorder->on_free(this);
}
return;
} else {
@@ -557,8 +409,8 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
std::unique_ptr<Event> create_event(size_t flags) override;
void sync() override {
-         if (m_cur_recorder) {
-             m_cur_recorder->on_sync();
+         if (sm_cur_recorder) {
+             sm_cur_recorder->on_sync(this);
} else if (m_worker_queue) {
m_worker_queue->wait_all_task_finish();
}
@@ -590,13 +442,12 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
cg::ComputingGraph*) override {
-         m_cur_recorder_mtx.lock();
-         return std::make_unique<SeqRecorderImpl>(
-                 &m_cur_recorder, &m_cur_recorder_mtx, m_thread_pool.get());
+         return std::make_unique<SeqRecorderImpl>(&sm_cur_recorder,
+                                                  m_thread_pool.get(), this);
}
//! current sequence recorder
-     SeqRecorderImpl* cur_recorder() const { return m_cur_recorder; }
+     SeqRecorderImpl* cur_recorder() const { return sm_cur_recorder; }
void add_callback(Task &&task) override {
if (!check_global_finalized("add_callback()")) {
@@ -608,6 +459,179 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CpuCompNodeImpl);
CpuCompNodeImpl* CpuCompNodeImpl::sm_default_cpu_comp_node_ptr;
+ thread_local CpuCompNode::SeqRecorderImpl* CpuCompNodeImpl::sm_cur_recorder =
+         nullptr;
+ void CpuCompNode::SeqRecorderImpl::check_the_same_comp_node(
+         const CompNode& comp_node) const {
+     if (mgb_unlikely(comp_node.valid())) {
+         mgb_assert(m_record_compnode == comp_node,
+                    "CompNode %s can't hook in CompNode %s when recording\n",
+                    comp_node.locator().to_string().c_str(),
+                    m_record_compnode.locator().to_string().c_str());
+     }
+ }
+ //! implementation of CPUDispatcher that is passed to megdnn via megcore
+ class CpuCompNode::WorkerQueue::DispatcherImpl final : public CPUDispatcher {
+     std::atomic_size_t m_nr_task{0};
+     std::shared_ptr<WorkerQueue> m_queue;
+     CpuCompNode::CompNodeImpl* const m_comp_node;
+ public:
+     DispatcherImpl(const std::shared_ptr<WorkerQueue>& queue,
+                    CpuCompNode::CompNodeImpl* comp_node)
+             : m_queue{queue}, m_comp_node{comp_node} {}
+     void dispatch(Task&& task) override {
+         if (auto recorder = m_comp_node->cur_recorder()) {
+             recorder->dispatch(std::move(task), m_comp_node);
+         } else {
+             m_nr_task.fetch_add(1, std::memory_order_relaxed);
+             auto kern = [task](size_t, size_t) { task(); };
+             m_queue->add_task({kern, static_cast<size_t>(1_z)});
+         }
+     }
+     void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
+         if (auto recorder = m_comp_node->cur_recorder()) {
+             recorder->dispatch({std::move(task), parallelism}, m_comp_node);
+         } else {
+             m_nr_task.fetch_add(1, std::memory_order_relaxed);
+             m_queue->add_task({std::move(task), parallelism});
+         }
+     }
+     void sync() override {
+         if (auto recorder = m_comp_node->cur_recorder()) {
+             recorder->on_sync(m_comp_node);
+         } else {
+             m_queue->wait_all_task_finish();
+         }
+     }
+     size_t nr_threads() override {
+         if (auto recorder = m_comp_node->cur_recorder()) {
+             return recorder->nr_threads(m_comp_node);
+         } else {
+             return m_queue->nr_threads();
+         }
+     }
+     size_t get_nr_dispatched_tasks() const override { return m_nr_task; }
+     void set_affinity(AffinityCallBack&& affinity_cb) override {
+         auto thread_pool = m_queue->get_thread_pool();
+         if (thread_pool) {
+             thread_pool->set_affinity(affinity_cb);
+         } else {
+             auto affinity_run = [affinity_cb](size_t, size_t) {
+                 affinity_cb(0);
+             };
+             m_queue->add_task({affinity_run, 1_z});
+         }
+     }
+ };
+ //! implementation of InplaceCPUDispatcher
+ class InplaceCPUDispatcher final : public CPUDispatcher {
+     std::atomic_size_t m_nr_task{0};
+     ThreadPool* m_thread_pool = nullptr;
+     CpuCompNode::CompNodeImpl* const m_comp_node;
+ public:
+     InplaceCPUDispatcher(CpuCompNode::CompNodeImpl* comp_node,
+                          ThreadPool* thread_pool = nullptr)
+             : m_thread_pool(thread_pool), m_comp_node(comp_node) {}
+     void dispatch(Task&& task) override {
+         if (auto recorder = m_comp_node->cur_recorder()) {
+             recorder->dispatch(std::move(task), m_comp_node);
+         } else if (m_thread_pool) {
+             m_nr_task.fetch_add(1, std::memory_order_relaxed);
+             auto kern = [task](size_t, size_t) { task(); };
+             m_thread_pool->add_task({kern, static_cast<size_t>(1_z)});
+         } else {
+             m_nr_task.fetch_add(1, std::memory_order_relaxed);
+             task();
+         }
+     }
+     void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
+         if (auto recorder = m_comp_node->cur_recorder()) {
+             recorder->dispatch({std::move(task), parallelism}, m_comp_node);
+         } else if (m_thread_pool) {
+             m_nr_task.fetch_add(1, std::memory_order_relaxed);
+             m_thread_pool->add_task({task, parallelism});
+         } else {
+             m_nr_task.fetch_add(1, std::memory_order_relaxed);
+             for (size_t i = 0; i < parallelism; i++) {
+                 task(i, 0);
+             }
+         }
+     }
+     size_t nr_threads() override {
+         return m_thread_pool ? m_thread_pool->nr_threads() : 1_z;
+     }
+     void sync() override {
+         if (auto recorder = m_comp_node->cur_recorder()) {
+             recorder->on_sync(m_comp_node);
+         } else if (m_thread_pool) {
+             m_thread_pool->deactive();
+         }
+     }
+     size_t get_nr_dispatched_tasks() const override { return m_nr_task; }
+     void set_affinity(AffinityCallBack&& affinity_cb) override {
+         if (auto recorder = m_comp_node->cur_recorder()) {
+             recorder->get_thread_pool()->set_affinity(affinity_cb);
+         } else if (m_thread_pool) {
+             m_thread_pool->set_affinity(affinity_cb);
+         } else {
+             affinity_cb(0);
+         }
+     }
+ };
+ CpuCompNode::CompNodeImpl::CompNodeImpl(
+         const Locator& locator, const Locator& locator_logical,
+         const std::shared_ptr<WorkerQueue>& worker_queue)
+         : CpuDispatchableBase(static_free_device, static_free_host),
+           m_worker_queue{worker_queue},
+           m_locator(locator),
+           m_locator_logical(locator_logical) {
+     auto cn = make_comp_node_from_impl(this);
+     if (locator.type == DeviceType::MULTITHREAD) {
+         m_thread_pool = std::unique_ptr<ThreadPool>(
+                 new ThreadPool(static_cast<size_t>(locator.nr_threads)));
+         mgb_assert(m_thread_pool, "ThreadPool creation failed");
+     }
+     if (locator.type == DeviceType::CPU) {
+         if (locator.device == Locator::DEVICE_CPU_DEFAULT) {
+             sm_default_cpu_comp_node_ptr = this;
+             m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(this)}, cn);
+         } else {
+             m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
+                                    m_worker_queue, this)},
+                            cn);
+         }
+     } else if (locator.type == DeviceType::MULTITHREAD) {
+         if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
+             m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
+                                    this, m_thread_pool.get())},
+                            cn);
+         } else {
+             m_worker_queue->attach_thread_pool(m_thread_pool.get());
+             m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
+                                    m_worker_queue, this)},
+                            cn);
+         }
+     }
+ }
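The check_the_same_comp_node() guard above deliberately lets invalid comp nodes through: stop() takes a default-constructed CompNode, and an unbound argument like that skips the comparison, while a valid but foreign comp node fires the assertion. A self-contained toy sketch of just that logic (ToyCompNode/ToyGuard are hypothetical stand-ins, not MegEngine API):

    #include <cassert>
    #include <cstdio>

    // Hypothetical stand-in for CompNode; id 0 plays the role of "invalid".
    struct ToyCompNode {
        int id = 0;
        bool valid() const { return id != 0; }
        bool operator==(const ToyCompNode& rhs) const { return id == rhs.id; }
    };

    struct ToyGuard {
        ToyCompNode m_record_compnode;
        // Mirrors check_the_same_comp_node(): an invalid comp node (e.g. the
        // default argument of stop()) bypasses the check; a valid comp node
        // must match the one the recorder was created on.
        void check(const ToyCompNode& cn) const {
            if (cn.valid())
                assert(cn == m_record_compnode && "foreign comp node hooked");
        }
    };

    int main() {
        ToyGuard guard{ToyCompNode{1}};
        guard.check(ToyCompNode{1});  // same comp node: passes
        guard.check(ToyCompNode{});   // invalid comp node: check skipped
        // guard.check(ToyCompNode{2});  // would fire the assertion
        std::puts("checks passed");
    }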
class CpuCompNodeImpl::CompSeqRecEventImpl final
: public CpuDispatchableBase::EventImpl {
@@ -618,7 +642,7 @@ class CpuCompNodeImpl::CompSeqRecEventImpl final
incr_nr_req();
on_finish();
};
-             rec->dispatch_allow_after_sync(callback);
+             rec->dispatch_allow_after_sync(callback, m_comp_node_impl);
} else {
EventImpl::do_record();
}
@@ -674,7 +698,7 @@ std::unique_ptr<CompNode::Event> CpuCompNodeImpl::create_event(size_t flags) {
if (m_worker_queue) {
m_worker_queue->check_exception();
}
-     if (m_cur_recorder) {
+     if (sm_cur_recorder) {
return std::make_unique<CompSeqRecEventImpl>(this, flags);
} else {
return std::make_unique<CpuEventImpl>(this, flags);
@@ -78,14 +78,16 @@ class ComputingGraphImpl::ComputingSequence::ExecContext {
void warmup_for_fake_exec_with_recorder() {
// Rerun recorder to ensure that all internal caches stabilize
-         m_recorder->enter_fake_exec();
+         auto comp_node = *(m_comp_seq->m_used_comp_node.begin());
+         m_recorder->enter_fake_exec(comp_node);
m_comp_seq->m_exec_env.start_exec();
m_comp_seq->m_exec_env.wait_all();
-         m_recorder->exit_fake_exec();
+         m_recorder->exit_fake_exec(comp_node);
}
void stop_and_move_recorder() {
-         m_recorder->stop();
+         auto comp_node = *(m_comp_seq->m_used_comp_node.begin());
+         m_recorder->stop(comp_node);
if (m_fake_next_exec) {
m_owner_graph->options().fake_next_exec = false;
} else {
@@ -439,17 +441,22 @@ void ComputingGraphImpl::ComputingSequence::on_first_exec() {
m_used_comp_node.insert(j->comp_node());
}
-     // we maintain a recorder because events may depend on whether recorder
-     // is enabled
-     auto recorder = check_enable_comp_node_seq_recorder();
auto&& options = m_owner_graph->options();
+     //! The recorder in the comp node is thread-local, so the thread that
+     //! creates the recorder must be the same thread that executes; force
+     //! synchronous mode to guarantee this
+     if (m_enable_comp_node_seq_recorder) {
+         m_exec_env.set_async_level(0);
+     } else {
+         m_exec_env.set_async_level(options.async_exec_level);
+     }
if (options.async_exec_level) {
for (auto i : m_used_comp_node)
m_exec_env.add_comp_node(i);
}
+     // we maintain a recorder because events may depend on whether recorder
+     // is enabled
+     auto recorder = check_enable_comp_node_seq_recorder();
// create events for timing and sync
for (auto&& i : m_used_comp_node) {
size_t flag = 0;
@@ -32,39 +32,7 @@ namespace cg {
class ComputingGraph;
}
- /*!
-  * \brief record computation operations on a computing node
-  *
-  * This is used for fast execution of an identical computation sequence where
-  * only input/output data differ.
-  *
-  * When this object is created from a comp node, recording starts immediately.
-  * Call stop() when computation finishes, and call replay() when it needs to be
-  * re-executed.
-  *
-  * Implementations should hold a global lock on the comp node until stop() is
-  * called.
-  */
- class CompNodeSeqRecorder {
- public:
-     virtual ~CompNodeSeqRecorder() noexcept = default;
-     /*!
-      * \brief Enter fake-exec mode
-      *
-      * Memory allocation/free is only allowed in fake-exec mode, and kernels
-      * should not be actually recorded in this mode.
-      *
-      * This should be paired with exit_fake_exec()
-      */
-     virtual void enter_fake_exec() = 0;
-     //! Exit fake-exec mode
-     virtual void exit_fake_exec() = 0;
-     virtual void stop() = 0;
-     virtual void replay() = 0;
- };
+ class CompNodeSeqRecorder;
/*!
* \brief identifier for a memory node
@@ -563,18 +531,55 @@ class CompNode {
//! is needed
ImplBase *m_impl = nullptr;
-     CompNode(ImplBase *impl):
-         m_impl{impl}
-     {}
friend class CompNodeEnv;
friend struct HashTrait<CompNode>;
friend class CompNodeImplHelper;
+ public:
+     CompNode(ImplBase* impl) : m_impl{impl} {}
};
MGB_DEF_ENUM_CLASS_BIT_OPR(CompNode::Flag)
+ /*!
+  * \brief record computation operations on a computing node
+  *
+  * This is used for fast execution of an identical computation sequence where
+  * only input/output data differ.
+  *
+  * When this object is created from a comp node, recording starts immediately.
+  * Call stop() when computation finishes, and call replay() when it needs to be
+  * re-executed.
+  *
+  * Implementations should be thread safe with respect to the comp node, so
+  * that multiple threads can record on the same comp node simultaneously;
+  * this is achieved by keeping the recorder thread-local in the comp node.
+  *
+  * Note: once recording stops, the recorder is independent of the comp node,
+  * so tasks dispatched into the recorder must not rely on comp node methods,
+  * and replay() runs on the calling user thread.
+  */
+ class CompNodeSeqRecorder {
+ public:
+     virtual ~CompNodeSeqRecorder() noexcept = default;
+     /*!
+      * \brief Enter fake-exec mode
+      *
+      * Memory allocation/free is only allowed in fake-exec mode, and kernels
+      * should not be actually recorded in this mode.
+      *
+      * This should be paired with exit_fake_exec()
+      */
+     virtual void enter_fake_exec(const CompNode& comp_node) = 0;
+     //! Exit fake-exec mode
+     virtual void exit_fake_exec(const CompNode& comp_node) = 0;
+     virtual void stop(const CompNode& comp_node) = 0;
+     virtual void replay() = 0;
+ };
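To make the documented lifecycle concrete, here is a self-contained toy sketch (ToySeqRecorder is hypothetical; plain std::function tasks stand in for recorded kernels): tasks accumulate while recording, nothing is recorded during fake-exec, and replay() re-runs the sequence on whichever thread calls it.

    #include <cassert>
    #include <cstdio>
    #include <functional>
    #include <vector>

    class ToySeqRecorder {
        std::vector<std::function<void()>> m_tasks;
        bool m_fake_exec = false, m_stopped = false;
    public:
        // Fake-exec warms up internal caches; nothing may be recorded here.
        void enter_fake_exec() { assert(!m_fake_exec); m_fake_exec = true; }
        void exit_fake_exec() { assert(m_fake_exec); m_fake_exec = false; }
        void dispatch(std::function<void()> task) {
            assert(!m_stopped);
            if (!m_fake_exec)
                m_tasks.push_back(std::move(task));
        }
        void stop() { m_stopped = true; }  // recorder now independent
        void replay() {                    // runs on the caller's thread
            assert(m_stopped);
            for (auto& t : m_tasks)
                t();
        }
    };

    int main() {
        ToySeqRecorder rec;
        rec.dispatch([] { std::puts("kernel 0"); });  // recorded
        rec.dispatch([] { std::puts("kernel 1"); });
        rec.stop();
        for (int i = 0; i < 3; ++i)
            rec.replay();  // fast re-execution of the identical sequence
    }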
/*!
* \brief event associated with a CompNode node, used for cross-device
* synchronization
@@ -471,6 +471,37 @@ void run<shape_dep_const_shape>(CompNode cn) {
MGB_ASSERT_TENSOR_EQ(y_expect, host_y);
}
+ //! run two recorders interleaved on a single thread
+ template <>
+ void run<multi_recorder_run>(CompNode cn) {
+     using ConvParam = opr::Convolution::Param;
+     ConvParam param;
+     param.sparse = ConvParam::Sparse::GROUP;
+     HostTensorGenerator<> gen;
+     std::vector<HostTensorND> host_z_v(2, HostTensorND());
+     std::vector<std::unique_ptr<mgb::cg::AsyncExecutable>> funcs;
+     auto host_x = gen({3, 4, 10, 8}, cn), host_y = gen({2, 3, 2, 3, 3}, cn);
+     auto gen_graph =
+             [&](int graph_id) -> std::unique_ptr<mgb::cg::AsyncExecutable> {
+         auto graph = ComputingGraph::make();
+         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+              y = opr::Host2DeviceCopy::make(*graph, host_y),
+              z = opr::Convolution::make(x, y, param);
+         graph->options().comp_node_seq_record_level = 1;
+         return graph->compile({make_callback_copy(z, host_z_v[graph_id])});
+     };
+     funcs.push_back(gen_graph(0));
+     funcs.push_back(gen_graph(1));
+     for (int iter = 0; iter < 10; ++iter) {
+         host_x->copy_from_fixlayout(*gen(host_x->shape(), cn));
+         funcs[0]->execute();
+         funcs[1]->execute();
+         auto expect = eval_conv_cpu<opr::Convolution>(*host_x, *host_y, param);
+         MGB_ASSERT_TENSOR_NEAR(expect, host_z_v[0], 1e-3) << "iter " << iter;
+         MGB_ASSERT_TENSOR_NEAR(expect, host_z_v[1], 1e-3) << "iter " << iter;
+     }
+ }
template <>
void run<void>(CompNode) {}
@@ -56,7 +56,7 @@ namespace seq_rec {
cb(dyn_elemwise_fake_exec) \
cb(level2) cb(level2_multi_holder) cb(level2_share_storage) \
cb(level2_exec_check) cb(sync_from_func) cb(cb_non_contig) \
-                 cb(shape_dep_const_shape)
+                 cb(shape_dep_const_shape) cb(multi_recorder_run)
// clang-format on
#define def_tags(name) \
@@ -12,6 +12,7 @@
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include "megbrain/system.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/test/helper.h"
@@ -20,6 +21,37 @@
using namespace mgb;
namespace{
+ template <typename Opr>
+ HostTensorND eval_conv(const std::shared_ptr<HostTensorND>& src,
+                        const std::shared_ptr<HostTensorND>& filter,
+                        const typename Opr::Param& param = {}) {
+     auto graph = ComputingGraph::make();
+     graph->options().log_level = 0;
+     SymbolVar x = opr::Host2DeviceCopy::make(*graph, src);
+     SymbolVar y = opr::Host2DeviceCopy::make(*graph, filter);
+     SymbolVar z = Opr::make(x, y, param);
+     HostTensorND host_z;
+     auto func = graph->compile({make_callback_copy(z, host_z)});
+     func->execute();
+     host_z.sync();
+     return host_z;
+ }
+ template <typename Opr>
+ HostTensorND eval_conv_cpu(const HostTensorND& xv, const HostTensorND& fv,
+                            const typename Opr::Param& param = {}) {
+     auto cn = CompNode::load("cpux");
+     auto src = std::make_shared<HostTensorND>(cn, xv.layout()),
+          filter = std::make_shared<HostTensorND>(cn, fv.layout());
+     memcpy(src->raw_ptr(), xv.raw_ptr(), xv.layout().span().dist_byte());
+     memcpy(filter->raw_ptr(), fv.raw_ptr(), fv.layout().span().dist_byte());
+     return eval_conv<Opr>(src, filter, param);
+ }
} // namespace
TEST(TestGraph, AsyncExecLevel) {
REQUIRE_GPU(1);
@@ -165,4 +197,35 @@ TEST(TestGraph, ParallelRun) {
i.join();
}
+ TEST(TestGraph, MultiThreadRecorder) {
+     using ConvParam = opr::Convolution::Param;
+     ConvParam param;
+     param.sparse = ConvParam::Sparse::GROUP;
+     HostTensorGenerator<> gen;
+     auto cn = CompNode::load("cpux");
+     auto host_x = gen({3, 4, 10, 8}, cn), host_y = gen({2, 3, 2, 3, 3}, cn);
+     auto worker = [&](int record_level) {
+         HostTensorND host_z;
+         auto graph = ComputingGraph::make();
+         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+              y = opr::Host2DeviceCopy::make(*graph, host_y),
+              z = opr::Convolution::make(x, y, param);
+         graph->options().comp_node_seq_record_level = record_level;
+         graph->options().var_sanity_check_first_run = false;
+         auto func = graph->compile({make_callback_copy(z, host_z)});
+         for (int i = 0; i < 5; i++) {
+             func->execute();
+         }
+         auto expect = eval_conv_cpu<opr::Convolution>(*host_x, *host_y, param);
+         MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3);
+     };
+     std::vector<std::thread> workers;
+     for (size_t i = 0; i < 4; ++i)
+         workers.emplace_back(worker, i % 2);
+     for (auto&& i : workers)
+         i.join();
+ }
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}