提交 8fddd808 编写于 作者: M Megvii Engine Team

fix(profiler): respect record_device option

GitOrigin-RevId: 7c9a8cfba773061218fa3cda98fcc305333e7945
上级 59d59766
......@@ -724,14 +724,14 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
// Before execute
for (auto&& [device, kernel_id]: kernels) {
MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
MGB_RECORD_EVENT(RecordDeviceEvent, Timer::record_device(device));
MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
}
// Apply op
// Here std::move is REQUIRED for removing duplicated references.
auto outputs = apply_on_physical_tensor(apply_on_physical_tensor, *cmd.op, inputs);
// After execute
for (auto&& [device, kernel_id]: kernels) {
MGB_RECORD_EVENT(RecordDeviceEvent, Timer::record_device(device));
MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
}
// End profiling operator
......@@ -1009,9 +1009,9 @@ void ChannelImpl::process_one_task(Command& icmd) {
using T = std::decay_t<decltype(cmd)>;
if constexpr (std::is_same_v<T, Put>) {
MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
MGB_RECORD_EVENT(RecordDeviceEvent, Timer::record_device(cmd.value.comp_node()));
MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(cmd.value.comp_node()));
auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
MGB_RECORD_EVENT(RecordDeviceEvent, Timer::record_device(cmd.value.comp_node()));
MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(cmd.value.comp_node()));
produce_tensor(cmd.dest, std::move(value));
MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
sample_on_device(cmd.dest->desc.comp_node, false);
......@@ -1136,7 +1136,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
if (Profiler::get_option("sample_rate", 0)) {
sample_on_device(device, true);
}
MGB_RECORD_EVENT(RecordDeviceEvent, Timer::record_device(device));
MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
});
MGB_RECORD_EVENT(StartProfileFinishEvent);
} else if constexpr (std::is_same_v<T, StopProfile>) {
......
......@@ -302,6 +302,8 @@ struct ChromeTimelineEventVisitor: EventVisitor<ChromeTimelineEventVisitor> {
} else if constexpr (std::is_same_v<TEvent, TensorGetPropEvent>) {
new_host_event("TensorGetProp", 'X')
.dur(0).args(current_tensor->detail(current->time));
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropEvent>) {
new_host_event("TensorWaitProp", 'B');
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropFinishEvent>) {
new_host_event(pid_str, 'f')
.id(event.tensor_id)
......
......@@ -26,7 +26,7 @@ ProfilerPlugin::ProfilerPlugin(cg::ComputingGraph* graph): PluginBase(graph) {
// reset
mgb_assert(!event.graph->options().imperative_proxy_graph);
CompNode::foreach([](CompNode device){
Profiler::record<RecordDeviceEvent>(Timer::record_device(device));
MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
});
if (m_opr_dict.empty() && m_var_dict.empty()) {
init_seq(event.exec);
......@@ -47,22 +47,21 @@ ProfilerPlugin::ProfilerPlugin(cg::ComputingGraph* graph): PluginBase(graph) {
Profiler::record<ScopeEvent>("DispatchOprs");
event.exec->iter_opr_seq([this](OperatorNodeBase* opr) -> bool{
auto& opr_info = get_opr_info(opr);
SmallVector<uint64_t> inputs;
for (auto input: opr->input()) {
inputs.push_back(get_var_info(input).id);
}
SmallVector<uint64_t> outputs;
for (auto output: opr->output()) {
outputs.push_back(get_var_info(output).id);
auto& var_id = get_var_info(output).id;
var_id = Profiler::next_id();
Profiler::record<TensorDeclareEvent>(var_id, output->name());
}
auto opr_name = opr->dyn_typeinfo()->name;
auto copy_params = [params = opr_info.params] { return *params; };
SmallVector<uint64_t> inputs, outputs;
for (auto input: opr->input()) {
inputs.push_back(get_var_info(input).id);
}
for (auto output: opr->output()) {
auto& var_id = get_var_info(output).id;
var_id = Profiler::next_id();
Profiler::record<TensorDeclareEvent>(var_id, output->name());
outputs.push_back(get_var_info(output).id);
}
Profiler::record<OpDispatchEvent>(opr_info.id, opr_name, copy_params, inputs, outputs);
Profiler::record<OpDispatchEvent>(opr_info.id = Profiler::next_id(), opr_name, copy_params, inputs, outputs);
return true;
});
Profiler::record<ScopeFinishEvent>("DispatchOprs");
......@@ -128,12 +127,12 @@ ProfilerPlugin::ProfilerPlugin(cg::ComputingGraph* graph): PluginBase(graph) {
auto on_before_kern = [this](BeforeKernel const& event) {
OperatorNodeBase* opr = event.opr;
Profiler::record<KernelLaunchEvent>(get_opr_info(opr).id, get_opr_info(opr).id, event.comp_node);
Profiler::record<RecordDeviceEvent>(Timer::record_device(event.comp_node));
MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(event.comp_node));
};
auto on_after_kern = [this](AfterKernel const& event) {
OperatorNodeBase* opr = event.opr;
Profiler::record<RecordDeviceEvent>(Timer::record_device(event.comp_node));
Profiler::record<KernelLaunchEvent>(get_opr_info(opr).id, get_opr_info(opr).id, event.comp_node);
MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(event.comp_node));
Profiler::record<KernelLaunchFinishEvent>(get_opr_info(opr).id, get_opr_info(opr).id, event.comp_node);
};
auto on_graph_compile = [this](const CompSeqOrderDetermined&) {
m_opr_dict.clear();
......@@ -182,7 +181,6 @@ void ProfilerPlugin::init_seq(cg::AsyncExecutable *comp_seq) {
ProfilerPlugin::OprInfo& ProfilerPlugin::register_opr(cg::OperatorNodeBase *opr) {
OprInfo info;
info.id = Profiler::next_id();
auto params = std::make_shared<std::unordered_map<std::string, std::string>>();
auto params_json = opr->to_json();
for (auto&& [k, v]: params_json->cast_final<json::Object>().get_impl()) {
......
......@@ -233,5 +233,10 @@ public:
mgb::imperative::Profiler::record<type>(type{__VA_ARGS__}); \
} \
// MGB_RECORD_EVENT_IF(expr, type, ...): conditionally record a profiler event.
// Constructs `type{__VA_ARGS__}` and passes it to Profiler::record<type> only
// when Profiler::is_profiling() is true and `expr` is also true. Because of
// the short-circuit `&&`, `expr` is not evaluated unless profiling is active,
// and the event arguments are not evaluated unless both conditions hold.
// NOTE(review): this expands to a bare `if { ... }` statement (there is no
// do { ... } while (0) wrapper), so avoid using it as the unbraced body of an
// if/else at call sites.
// NOTE(review): the final line of the definition ends with a line
// continuation, so the line following the macro must stay blank.
#define MGB_RECORD_EVENT_IF(expr, type, ...) \
if (mgb::imperative::Profiler::is_profiling() && (expr)) { \
mgb::imperative::Profiler::record<type>(type{__VA_ARGS__}); \
} \
} // namespace imperative
} // namespace mgb
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册