Commit ccc79090 authored by dingminghui, committed by jackzhang235

refactor: reduce duplicated code and fix tensor dump error

Parent 4bb98d71
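
The "duplicated code" in the commit message is the notifier-based hardware-timing code in the two Graph::Compute overloads, which the first file below folds into a MEASURE_HWTIME_START / MEASURE_HWTIME_END macro pair. For readers without an MLU toolchain, here is a minimal, self-contained sketch of the same start/end timing-macro pattern, using std::chrono in place of the CNRT notifier and queue APIs; every name in it is illustrative and not part of the patch.

// Sketch only: std::chrono stands in for cnrtPlaceNotifier/cnrtNotifierDuration.
#include <chrono>
#include <iostream>
#include <mutex>
#include <vector>

static thread_local std::chrono::steady_clock::time_point start_tp;
static std::mutex time_mut_;
static std::vector<float> time_log_;

#define MEASURE_TIME_START()                     \
  do {                                           \
    start_tp = std::chrono::steady_clock::now(); \
  } while (0)

#define MEASURE_TIME_END()                                                   \
  do {                                                                       \
    auto end_tp = std::chrono::steady_clock::now();                          \
    float ms =                                                               \
        std::chrono::duration<float, std::milli>(end_tp - start_tp).count(); \
    std::cout << "hardware time " << ms << "ms" << std::endl;                \
    std::lock_guard<std::mutex> lk(time_mut_);                               \
    time_log_.push_back(ms);                                                 \
  } while (0)

int main() {
  MEASURE_TIME_START();
  // ... measured work; in the real code this is the fused-op forward call ...
  MEASURE_TIME_END();
  return 0;
}
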
@@ -150,6 +150,23 @@ class Graph {
     CNML_CALL(cnmlCompileFusionOp_V2(fusion_op_));
   }
+#define MEASURE_HWTIME_START(que)                       \
+  do {                                                  \
+    CNRT_CALL(cnrtPlaceNotifier(notifier_start_, que)); \
+  } while (0)
+#define MEASURE_HWTIME_END(que)                                                \
+  do {                                                                         \
+    thread_local float hw_time;                                                \
+    CNRT_CALL(cnrtPlaceNotifier(notifier_end_, que));                          \
+    CNRT_CALL(cnrtSyncQueue(que));                                             \
+    CNRT_CALL(cnrtNotifierDuration(notifier_start_, notifier_end_, &hw_time)); \
+    hw_time /= 1000.0f;                                                        \
+    DLOG(INFO) << "cnml hardware time " << hw_time << "ms" << std::endl;       \
+    std::lock_guard<std::mutex> lk(time_mut_);                                 \
+    time_log_.push_back(hw_time);                                              \
+  } while (0)
   void Compute(cnrtInvokeFuncParam_t forward_param, cnrtQueue_t que) {
     input_addrs_.resize(input_tensors_.size());
     output_addrs_.resize(output_tensors_.size());
@@ -161,8 +178,7 @@ class Graph {
     }
 #if PRINT_HW_TIME
-    thread_local float hw_time;
-    CNRT_CALL(cnrtPlaceNotifier(notifier_start_, que));
+    MEASURE_HWTIME_START(que);
 #endif
     CNML_CALL(cnmlComputeFusionOpForward_V3(fusion_op_,
                                             input_addrs_.data(),
@@ -172,18 +188,11 @@ class Graph {
                                             &forward_param,
                                             que));
 #if PRINT_HW_TIME
-    CNRT_CALL(cnrtPlaceNotifier(notifier_end_, que));
-    CNRT_CALL(cnrtSyncQueue(que));
-    CNRT_CALL(cnrtNotifierDuration(notifier_start_, notifier_end_, &hw_time));
-    hw_time /= 1000.0f;
-    DLOG(INFO) << "cnml hardware time " << hw_time << "ms" << std::endl;
-    std::lock_guard<std::mutex> lk(time_mut_);
-    time_log_.push_back(hw_time);
+    MEASURE_HWTIME_END(que);
 #endif
   }
-  void Compute(cnrtInvokeFuncParam_t forward_param,
-               cnrtQueue_t que,
+  void Compute(cnrtQueue_t que,
                const std::vector<std::shared_ptr<MLUTensor>>& in,
                const std::vector<std::shared_ptr<MLUTensor>>& out) {
     std::vector<cnmlTensor_t> in_tensor;
@@ -200,8 +209,7 @@ class Graph {
     }
 #if PRINT_HW_TIME
-    thread_local float hw_time;
-    CNRT_CALL(cnrtPlaceNotifier(notifier_start_, que));
+    MEASURE_HWTIME_START(que);
 #endif
     /* Because of using cnmlSetTensorDimMutable, cnmlComputeFusionOpForward_V3
      * -> cnmlComputeFusionOpForward_V4 */
@@ -215,15 +223,11 @@ class Graph {
                                             que,
                                             NULL));
 #if PRINT_HW_TIME
-    CNRT_CALL(cnrtPlaceNotifier(notifier_end_, que));
-    CNRT_CALL(cnrtSyncQueue(que));
-    CNRT_CALL(cnrtNotifierDuration(notifier_start_, notifier_end_, &hw_time));
-    hw_time /= 1000.0f;
-    DLOG(INFO) << "cnml hardware time " << hw_time << "ms" << std::endl;
-    std::lock_guard<std::mutex> lk(time_mut_);
-    time_log_.push_back(hw_time);
+    MEASURE_HWTIME_END(que);
 #endif
   }
+#undef MEASURE_HWTIME_START
+#undef MEASURE_HWTIME_END
   template <typename T>
   void* RegisterConstData(size_t len) {
......
@@ -56,12 +56,6 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
   CNRT_CALL(cnrtInit(0));
   lite::SetMluDevice(0);
   cnrtQueue_t queue_;
-  cnrtInvokeFuncParam_t forward_param;
-  u32_t affinity = 1;
-  int data_param = 1;
-  forward_param.data_parallelism = &data_param;
-  forward_param.affinity = &affinity;
-  forward_param.end = CNRT_PARAM_END;
   CNRT_CALL(cnrtCreateQueue(&queue_));
   cnrtDev_t dev_handle;
   CNRT_CALL(cnrtGetDeviceHandle(&dev_handle, 0));
@@ -113,10 +107,7 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
   }
   graph.Compile(CNML_MLU270, 1);
-  graph.Compute(forward_param,
-                queue_,
-                *(graph.MutableInputs()),
-                *(graph.MutableOutputs()));
+  graph.Compute(queue_, *(graph.MutableInputs()), *(graph.MutableOutputs()));
   CNRT_CALL(cnrtSyncQueue(queue_));
   for (auto& output_name : output_var_names) {
......
@@ -330,12 +330,6 @@ class SubgraphEngine : public subgraph::Engine {
     // prepare input and output memory
     auto& mlu_context = this->ctx_->template As<MLUContext>();
     auto exec_queue = mlu_context.exec_queue();
-    u32_t affinity = mlu_context.affinity();
-    cnrtInvokeFuncParam_t forward_param = mlu_context.forward_param();
-    int data_param = 1;
-    forward_param.data_parallelism = &data_param;
-    forward_param.affinity = &affinity;
-    forward_param.end = CNRT_PARAM_END;
     auto graph = shape_graph_map_[inputs_shape_];
     auto* graph_input = graph->MutableInputs();
@@ -402,7 +396,7 @@ class SubgraphEngine : public subgraph::Engine {
       }
       shape_tensor_map_out_[all_inputs_shape_] = graph_out;
       }
-      graph->Compute(forward_param, exec_queue, graph_in, graph_out);
+      graph->Compute(exec_queue, graph_in, graph_out);
     } else {
       for (size_t i = 0; i < origin_itensors_.size(); ++i) {
         graph_input->at(i)->set_mlu_ptr(
@@ -413,36 +407,49 @@ class SubgraphEngine : public subgraph::Engine {
         graph_output->at(i)->set_mlu_ptr(
             GetOutputDataPtr(origin_otensors_[i], !disable_mlu_cast));
       }
+      // Only cnmlComputeFusionOpForward_V3 needs a cnrtInvokeFuncParam_t.
+      cnrtInvokeFuncParam_t forward_param = mlu_context.forward_param();
+      int data_param = 1;
+      forward_param.data_parallelism = &data_param;
+      u32_t affinity = mlu_context.affinity();
+      forward_param.affinity = &affinity;
+      forward_param.end = CNRT_PARAM_END;
       graph->Compute(forward_param, exec_queue);
+#ifdef MLU_DUMP_SUBGRAPH_IO
+      // Graph nodes hold compile-time tensors while batchsize-mutable is
+      // set; only when batchsize-mutable is disabled do the graph nodes hold
+      // runtime data, so dumping is done in this branch only.
+      // =========== DUMP ===================
+      for (auto input_name : input_names_) {
+        auto input_tensor =
+            shape_graph_map_[inputs_shape_]->GetNode(input_name);
+        auto dump_name = input_name;
+        while (dump_name.find("/") != std::string::npos) {
+          dump_name = dump_name.replace(dump_name.find("/"), 1, "_");
+        }
+        VLOG(6) << "dump_name: " << dump_name;
+        input_tensor->ToFile(dump_name);
+      }
+      for (auto output_name : output_names_) {
+        if (shape_graph_map_[inputs_shape_]->HasNode(output_name)) {
+          auto output_tensor =
+              shape_graph_map_[inputs_shape_]->GetNode(output_name);
+          auto dump_name = output_name;
+          while (dump_name.find("/") != std::string::npos) {
+            dump_name = dump_name.replace(dump_name.find("/"), 1, "_");
+          }
+          VLOG(6) << "dump_name: " << dump_name;
+          output_tensor->ToFile(dump_name);
+        } else {
+          VLOG(6) << "graph does not have " << output_name << " as output"
+                  << std::endl;
+        }
+      }
+#endif
+      // =========== DUMP END ================
     }
-    // // =========== DUMP ===================
-    // for (auto input_name : input_names_) {
-    //   auto input_tensor =
-    //       shape_graph_map_[inputs_shape_]->GetNode(input_name);
-    //   auto dump_name = input_name;
-    //   while (dump_name.find("/") != std::string::npos) {
-    //     dump_name = dump_name.replace(dump_name.find("/"), 1, "_");
-    //   }
-    //   VLOG(6) << "dump_name: " << dump_name;
-    //   input_tensor->ToFile(dump_name);
-    // }
-    // for (auto output_name : output_names_) {
-    //   if (shape_graph_map_[inputs_shape_]->HasNode(output_name)) {
-    //     auto output_tensor =
-    //         shape_graph_map_[inputs_shape_]->GetNode(output_name);
-    //     auto dump_name = output_name;
-    //     while (dump_name.find("/") != std::string::npos) {
-    //       dump_name = dump_name.replace(dump_name.find("/"), 1, "_");
-    //     }
-    //     VLOG(6) << "dump_name: " << dump_name;
-    //     output_tensor->ToFile(dump_name);
-    //   } else {
-    //     VLOG(6) << "graph does not have " << output_name << " as output"
-    //             << std::endl;
-    //   }
-    // }
-    // // =========== DUMP END ================
     return 0;
   }
......
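
A note on the "tensor dump error" part of the commit message: the dump logic above is now compiled only under MLU_DUMP_SUBGRAPH_IO, runs only in the branch where batchsize-mutable is disabled (the only case where the graph nodes hold runtime data), and sanitizes tensor names before calling ToFile. The snippet below is a standalone sketch of just that sanitization step; the sample tensor name is hypothetical, and the assumption that '/' would otherwise be treated as a path separator in the dump file name is an inference, not stated in the patch.

#include <iostream>
#include <string>

int main() {
  // Hypothetical tensor name; operator output names often contain '/'.
  std::string dump_name = "batch_norm_0/scale";
  // Same loop as in the dump code above: replace every '/' with '_' so the
  // name can be used directly as a file name.
  while (dump_name.find("/") != std::string::npos) {
    dump_name = dump_name.replace(dump_name.find("/"), 1, "_");
  }
  std::cout << "dump_name: " << dump_name << std::endl;  // batch_norm_0_scale
  return 0;
}
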