提交 53da5c79 编写于 作者: M Megvii Engine Team

feat(cg): add comp_seq_sync_device option

GitOrigin-RevId: c2199c59e9744352207e8cbd6d53616d968319bf
上级 e1c7b22f
......@@ -288,6 +288,7 @@ ComputingGraphHolder& get_computing_graph(std::shared_ptr<OpDef> compiled_op, Sm
cg_holder.graph->options().async_exec_level = 0;
cg_holder.graph->options().graph_opt_level = compiled_op->cast_final_safe<CompiledOp>().gopt_level;
cg_holder.graph->options().enable_var_mem_defragment = false;
cg_holder.graph->options().comp_seq_sync_device = false;
cg_holder.graph->set_device_memory_allocator(cg_holder.allocator);
// cg_holder.graph->options().graph_opt.jit = 2;
VarNodeArray input_vars;
......
......@@ -385,21 +385,27 @@ void ComputingGraphImpl::ComputingSequence::do_wait(bool explicit_user_wait) {
}
}
for (auto cn : m_used_comp_node) {
m_event_end.at(cn)->host_wait();
bool sync_device = m_owner_graph->options().comp_seq_sync_device;
if (sync_device) {
for (auto cn : m_used_comp_node) {
m_event_end.at(cn)->host_wait();
}
}
m_wait_finished = true;
#if MGB_NEED_MEGDNN_ASYNC_ERROR
// FIXME: This cannot work well if more than one ComputingSequence has been
// executed on the same comp_node and got an AsyncError concurrently, because
// only the first async error on each comp_node would be recorded.
for (auto&& cn : m_used_comp_node) {
auto error = cn.check_async_error();
if (error) {
static_cast<const OperatorNodeExcExtraInfo*>(error->extra_info())
->opr()
->owner_graph()
->record_async_error(std::move(error));
if (sync_device) {
for (auto&& cn : m_used_comp_node) {
auto error = cn.check_async_error();
if (error) {
static_cast<const OperatorNodeExcExtraInfo*>(error->extra_info())
->opr()
->owner_graph()
->record_async_error(std::move(error));
}
}
}
#endif
......
......@@ -520,6 +520,9 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
*/
bool no_force_inplace = false;
//! whether to sync comp_node when waiting for the computing sequence
bool comp_seq_sync_device = true;
//! add extra deps for the comp seq if a specific var is dependent
ThinHashMap<VarNode*, VarNodeArray> extra_vardeps;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册