未验证 提交 d6d745d2 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] add flag to control timer (#39241)

上级 41a64351
......@@ -611,27 +611,43 @@ bool DistModel::Run(const std::vector<DistModelTensor> &input_data,
DistModelTimer timer;
timer.tic();
double feed_elapse;
double fleet_exe_elapse;
double fetch_elapse;
if (!FeedData(input_data, scope_.get())) {
LOG(ERROR) << "DistModel failed at feeding data.";
return false;
}
double feed_elapse = timer.toc();
VLOG(3) << "Finish loading data, cost " << feed_elapse << "ms.";
if (config_.enable_timer) {
feed_elapse = timer.toc();
LOG(INFO) << "Finish loading data, cost " << feed_elapse << "ms.";
} else {
VLOG(3) << "Finish loading data.";
}
fleet_exe->Run(carrier_id_);
double fleet_exe_elapse = timer.toc();
VLOG(3) << "Finish FleetExe running, cost " << fleet_exe_elapse - feed_elapse
<< "ms.";
if (config_.enable_timer) {
fleet_exe_elapse = timer.toc();
LOG(INFO) << "Finish FleetExe running, cost "
<< fleet_exe_elapse - feed_elapse << "ms.";
} else {
VLOG(3) << "Finish FleetExe running.";
}
if (!FetchResults(output_data, scope_.get())) {
LOG(ERROR) << "DistModel failed at fetching result.";
return false;
}
double fetch_elapse = timer.toc();
VLOG(3) << "Finish fetching data, cost " << fetch_elapse - fleet_exe_elapse
<< "ms.";
VLOG(3) << "DistModel finish inf, cost " << fetch_elapse << "ms";
if (config_.enable_timer) {
fetch_elapse = timer.toc();
LOG(INFO) << "Finish fetching data, cost "
<< fetch_elapse - fleet_exe_elapse << "ms.";
LOG(INFO) << "DistModel finish inf, cost " << fetch_elapse << "ms";
} else {
VLOG(3) << "Finish fetching data.";
VLOG(3) << "DistModel finish inf.";
}
return true;
}
......
......@@ -52,6 +52,7 @@ struct DistModelConfig {
int64_t mp_ring_id{-1};
int64_t pp_upstream_ring_id{-1};
int64_t pp_downstream_ring_id{-1};
bool enable_timer{false};
};
class DistModel {
......
......@@ -154,6 +154,7 @@ void BindFleetExecutor(py::module* m) {
.def_readwrite("mp_degree", &DistModelConfig::mp_degree)
.def_readwrite("pp_degree", &DistModelConfig::pp_degree)
.def_readwrite("mp_ring_id", &DistModelConfig::mp_ring_id)
.def_readwrite("enable_timer", &DistModelConfig::enable_timer)
.def_readwrite("pp_upstream_ring_id",
&DistModelConfig::pp_upstream_ring_id)
.def_readwrite("pp_downstream_ring_id",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册