diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index c3ed10d7a52d7d9c13bc3a9819354a7975b4c446..6e448ace963b16dba3858b240d2b9b765fb9137d 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -339,16 +339,17 @@ void ChannelImpl::dispatch_kernel( auto& state = get_channel_state(); auto& options = state.options; + auto name = op->trait()->make_name(*op); + auto _ = StackManager::Guard{name, &state.stack_manager}; + auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); MGB_RECORD_EVENT(ShapeInferEvent, validated); SmallVector output_infos; output_infos.reserve(output_descs.size()); - uint64_t apply_id = Profiler::next_id(); outputs->reserve(output_descs.size()); - for (int i = 0; i < output_descs.size(); ++i) { auto&& desc = output_descs[i]; auto info = alloc(); @@ -361,31 +362,28 @@ void ChannelImpl::dispatch_kernel( output_infos.push_back(info); outputs->push_back(reinterpret_cast(info)); } - auto op_info_getter = [op] { - std::unordered_map op_info; - auto props = OpDef::props(*op); - for (auto&& [key, value] : props) { - op_info[key] = value; - } - return op_info; - }; + ApplyOp cmd{ + Profiler::next_id(), std::move(op), std::move(input_infos), + std::move(output_infos), validated}; if (Profiler::is_profiling()) { - auto name = op->trait()->make_name(*op); - auto _ = StackManager::Guard{name, &state.stack_manager}; + auto op_info_getter = [op = cmd.op] { + std::unordered_map op_info; + auto props = OpDef::props(*op); + for (auto&& [key, value] : props) { + op_info[key] = value; + } + return op_info; + }; MGB_RECORD_EVENT( - OpDispatchEvent, apply_id, name, op_info_getter, - tinfo_to_tid(std::move(input_infos)), - tinfo_to_tid(std::move(output_infos)), state.stack_manager.dump()); + OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs), + tinfo_to_tid(cmd.outputs), state.stack_manager.dump()); m_worker.add_task( - {Profiler::next_id(), - ApplyOp{apply_id, std::move(op), std::move(input_infos), - std::move(output_infos), validated}, + {Profiler::next_id(), std::move(cmd), get_channel_state().stack_manager.dump()}); } else { m_worker.add_task({ Profiler::next_id(), - ApplyOp{apply_id, std::move(op), std::move(input_infos), - std::move(output_infos), validated}, + std::move(cmd), }); } if (!validated && options.async_level == 1) {