Commit bb2507b2 authored by dingminghui, committed by jackzhang235

feat(graph): support measuring hardware time

Parent ada6790e
@@ -23,6 +23,12 @@
 #include "lite/core/tensor.h"
 #include "lite/kernels/mlu/bridges/tensor.h"
+#define PRINT_HW_TIME false
+
+#if PRINT_HW_TIME
+#include <iostream>  // for the std::cout timing summary in ~Graph()
+#include <mutex>
+#endif
 namespace paddle {
 namespace lite {
 namespace subgraph {
@@ -32,7 +38,13 @@ namespace mlu {
 // to the MLU IR graph
 class Graph {
  public:
-  Graph() { CNML_CALL(cnmlCreateFusionOp(&fusion_op_)); }
+  Graph() {
+    CNML_CALL(cnmlCreateFusionOp(&fusion_op_));
+#if PRINT_HW_TIME
+    CNRT_CALL(cnrtCreateNotifier(&notifier_start_));
+    CNRT_CALL(cnrtCreateNotifier(&notifier_end_));
+#endif
+  }
 
   ~Graph() {
     FreeConstData();
@@ -40,6 +52,16 @@ class Graph {
     for (auto op : ops_) {
       CNML_CALL(cnmlDestroyBaseOp(&op));
     }
+#if PRINT_HW_TIME
+    CNRT_CALL(cnrtDestroyNotifier(&notifier_start_));
+    CNRT_CALL(cnrtDestroyNotifier(&notifier_end_));
+    double total_time = 0;
+    for (auto& f : time_log_) {
+      total_time += f;
+    }
+    // Guard against division by zero when Compute() was never called.
+    if (!time_log_.empty()) {
+      std::cout << "cnml hardware time for " << time_log_.size()
+                << " processes: " << total_time / time_log_.size() << std::endl;
+    }
+#endif
   }
 
   // Data node
@@ -90,6 +112,10 @@ class Graph {
   }
 
   void Compute(cnrtInvokeFuncParam_t forward_param, cnrtQueue_t que) {
+#if PRINT_HW_TIME
+    thread_local float hw_time;
+    CNRT_CALL(cnrtPlaceNotifier(notifier_start_, que));
+#endif
     CNML_CALL(cnmlComputeFusionOpForward_V3(fusion_op_,
                                             input_addrs_.data(),
                                             input_addrs_.size(),
@@ -97,7 +123,18 @@
                                             output_addrs_.size(),
                                             &forward_param,
                                             que));
+#if PRINT_HW_TIME
+    CNRT_CALL(cnrtPlaceNotifier(notifier_end_, que));
+#endif
     CNRT_CALL(cnrtSyncQueue(que));
+#if PRINT_HW_TIME
+    CNRT_CALL(cnrtNotifierDuration(notifier_start_, notifier_end_, &hw_time));
+    hw_time /= 1000.0f;  // cnrtNotifierDuration reports microseconds; convert to ms
+    DLOG(INFO) << "cnml hardware time " << hw_time << "ms";
+    std::lock_guard<std::mutex> lk(time_mut_);
+    time_log_.push_back(hw_time);
+#endif
   }
 
   template <typename T>
@@ -203,6 +240,11 @@ class Graph {
   std::vector<cnmlBaseOp_t> ops_;
   cnmlFusionOp_t fusion_op_;
   std::vector<void*> const_data_storage_;
+#if PRINT_HW_TIME
+  cnrtNotifier_t notifier_start_{}, notifier_end_{};
+  std::mutex time_mut_;
+  std::vector<float> time_log_;
+#endif
 };
 
 }  // namespace mlu
......
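For reference, the timing added here is the standard CNRT notifier idiom: place a start notifier on the queue, enqueue the device work, place an end notifier, drain the queue, then read the elapsed hardware time. Below is a minimal sketch of that pattern factored into an RAII helper. HwTimer and its methods are hypothetical names for illustration, not part of this commit; only the cnrt* calls that appear in the diff are used, and the CNRT_CALL error-checking macro is omitted.

#include <mutex>
#include <vector>

#include "cnrt.h"  // CNRT header, assumed to declare cnrtNotifier_t / cnrtQueue_t

// Hypothetical helper wrapping the notifier pattern used by Graph.
class HwTimer {
 public:
  HwTimer() {
    cnrtCreateNotifier(&start_);
    cnrtCreateNotifier(&end_);
  }
  ~HwTimer() {
    cnrtDestroyNotifier(&start_);
    cnrtDestroyNotifier(&end_);
  }

  // Mark the point on the queue where the timed work begins.
  void Begin(cnrtQueue_t que) { cnrtPlaceNotifier(start_, que); }

  // Mark the end, sync so both notifiers have fired, and record the
  // elapsed hardware time in milliseconds.
  void End(cnrtQueue_t que) {
    cnrtPlaceNotifier(end_, que);
    cnrtSyncQueue(que);
    float us = 0.0f;
    cnrtNotifierDuration(start_, end_, &us);  // reported in microseconds
    std::lock_guard<std::mutex> lk(mut_);
    log_.push_back(us / 1000.0f);
  }

  // Average over all recorded runs, mirroring the ~Graph() summary.
  float AverageMs() const {
    std::lock_guard<std::mutex> lk(mut_);
    if (log_.empty()) return 0.0f;
    double total = 0;
    for (float ms : log_) total += ms;
    return static_cast<float>(total / log_.size());
  }

 private:
  cnrtNotifier_t start_{}, end_{};
  mutable std::mutex mut_;
  std::vector<float> log_;
};

Calling Begin(que) before cnmlComputeFusionOpForward_V3 and End(que) after it reproduces what Graph::Compute does under PRINT_HW_TIME; the mutex-guarded log matches the commit's time_log_ aggregation across threads.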