From d70c9780c44bda78008376e28622cd09578e951c Mon Sep 17 00:00:00 2001 From: liuqi Date: Thu, 22 Feb 2018 17:50:45 +0800 Subject: [PATCH] Move RunMetadata to public api for benchmarking model. --- mace/core/future.h | 15 +-------------- mace/core/mace.cc | 5 +++-- mace/core/net.cc | 1 - mace/core/public/mace.h | 19 ++++++++++++++++++- mace/core/runtime/opencl/opencl_runtime.cc | 4 +--- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/mace/core/future.h b/mace/core/future.h index 41956f07..abae9fee 100644 --- a/mace/core/future.h +++ b/mace/core/future.h @@ -11,20 +11,7 @@ namespace mace { -struct CallStats { - int64_t start_micros; - int64_t end_micros; -}; - -struct OperatorStats { - std::string operator_name; - std::string type; - CallStats stats; -}; - -struct RunMetadata { - std::vector op_stats; -}; +class CallStats; // Wait the call to finish and get the stats if param is not nullptr struct StatsFuture { diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 2af888c0..0d54c895 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -558,7 +558,8 @@ MaceEngine::~MaceEngine() { }; bool MaceEngine::Run(const float *input, const std::vector &input_shape, - float *output) { + float *output, + RunMetadata *run_metadata) { MACE_CHECK(output != nullptr, "output ptr cannot be NULL"); Tensor *input_tensor = ws_->GetTensor("mace_input_node:0"); Tensor *output_tensor = ws_->GetTensor("mace_output_node:0"); @@ -571,7 +572,7 @@ bool MaceEngine::Run(const float *input, if (device_type_ == HEXAGON) { hexagon_controller_->ExecuteGraph(*input_tensor, output_tensor); } else { - if (!net_->Run()) { + if (!net_->Run(run_metadata)) { LOG(FATAL) << "Net run failed"; } } diff --git a/mace/core/net.cc b/mace/core/net.cc index 9afb008f..aeafcc20 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -3,7 +3,6 @@ // #include "mace/core/net.h" -#include "mace/core/workspace.h" #include "mace/utils/utils.h" #include "mace/utils/memory_logging.h" diff --git a/mace/core/public/mace.h b/mace/core/public/mace.h index f68b60fd..e7036ffe 100644 --- a/mace/core/public/mace.h +++ b/mace/core/public/mace.h @@ -334,6 +334,22 @@ class NetDef { uint32_t has_bits_; }; +struct CallStats { + int64_t start_micros; + int64_t end_micros; +}; + +struct OperatorStats { + std::string operator_name; + std::string type; + CallStats stats; +}; + +struct RunMetadata { + std::vector op_stats; +}; + + class Workspace; class NetBase; class OperatorRegistry; @@ -346,7 +362,8 @@ class MaceEngine { ~MaceEngine(); bool Run(const float *input, const std::vector &input_shape, - float *output); + float *output, + RunMetadata *run_metadata = nullptr); MaceEngine(const MaceEngine &) = delete; MaceEngine &operator=(const MaceEngine &) = delete; diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index 5b7ccdd8..f0041180 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -8,11 +8,9 @@ #include #include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/utils/logging.h" +#include "mace/core/public/mace.h" #include "mace/utils/tuner.h" -#include - namespace mace { namespace { -- GitLab