提交 53da5c79 编写于 作者: M Megvii Engine Team

feat(cg): add comp_seq_sync_device option

GitOrigin-RevId: c2199c59e9744352207e8cbd6d53616d968319bf
上级 e1c7b22f
...@@ -288,6 +288,7 @@ ComputingGraphHolder& get_computing_graph(std::shared_ptr<OpDef> compiled_op, Sm ...@@ -288,6 +288,7 @@ ComputingGraphHolder& get_computing_graph(std::shared_ptr<OpDef> compiled_op, Sm
cg_holder.graph->options().async_exec_level = 0; cg_holder.graph->options().async_exec_level = 0;
cg_holder.graph->options().graph_opt_level = compiled_op->cast_final_safe<CompiledOp>().gopt_level; cg_holder.graph->options().graph_opt_level = compiled_op->cast_final_safe<CompiledOp>().gopt_level;
cg_holder.graph->options().enable_var_mem_defragment = false; cg_holder.graph->options().enable_var_mem_defragment = false;
cg_holder.graph->options().comp_seq_sync_device = false;
cg_holder.graph->set_device_memory_allocator(cg_holder.allocator); cg_holder.graph->set_device_memory_allocator(cg_holder.allocator);
// cg_holder.graph->options().graph_opt.jit = 2; // cg_holder.graph->options().graph_opt.jit = 2;
VarNodeArray input_vars; VarNodeArray input_vars;
......
...@@ -385,21 +385,27 @@ void ComputingGraphImpl::ComputingSequence::do_wait(bool explicit_user_wait) { ...@@ -385,21 +385,27 @@ void ComputingGraphImpl::ComputingSequence::do_wait(bool explicit_user_wait) {
} }
} }
for (auto cn : m_used_comp_node) { bool sync_device = m_owner_graph->options().comp_seq_sync_device;
m_event_end.at(cn)->host_wait();
if (sync_device) {
for (auto cn : m_used_comp_node) {
m_event_end.at(cn)->host_wait();
}
} }
m_wait_finished = true; m_wait_finished = true;
#if MGB_NEED_MEGDNN_ASYNC_ERROR #if MGB_NEED_MEGDNN_ASYNC_ERROR
// FIXME: It CAN NOT work well if more than one ComputingSequence has been // FIXME: It CAN NOT work well if more than one ComputingSequence has been
// executed on the same compnode and got AsyncError concurrently, because // executed on the same compnode and got AsyncError concurrently, because
// only the first async error on each comp_node would be recorded. // only the first async error on each comp_node would be recorded.
for (auto&& cn : m_used_comp_node) { if (sync_device) {
auto error = cn.check_async_error(); for (auto&& cn : m_used_comp_node) {
if (error) { auto error = cn.check_async_error();
static_cast<const OperatorNodeExcExtraInfo*>(error->extra_info()) if (error) {
->opr() static_cast<const OperatorNodeExcExtraInfo*>(error->extra_info())
->owner_graph() ->opr()
->record_async_error(std::move(error)); ->owner_graph()
->record_async_error(std::move(error));
}
} }
} }
#endif #endif
......
...@@ -520,6 +520,9 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, ...@@ -520,6 +520,9 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
*/ */
bool no_force_inplace = false; bool no_force_inplace = false;
//! whether to sync comp_node when waiting for the computing sequence
bool comp_seq_sync_device = true;
//! add extra deps for the comp seq if a specific var is dependent //! add extra deps for the comp seq if a specific var is dependent
ThinHashMap<VarNode*, VarNodeArray> extra_vardeps; ThinHashMap<VarNode*, VarNodeArray> extra_vardeps;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册