diff --git a/mace/core/future.h b/mace/core/future.h index 41956f07985f4f6335a4645262ce7b1a737d7865..abae9feed57ad5866ebd2c3228776ad0f52e007b 100644 --- a/mace/core/future.h +++ b/mace/core/future.h @@ -11,20 +11,7 @@ namespace mace { -struct CallStats { - int64_t start_micros; - int64_t end_micros; -}; - -struct OperatorStats { - std::string operator_name; - std::string type; - CallStats stats; -}; - -struct RunMetadata { - std::vector op_stats; -}; +class CallStats; // Wait the call to finish and get the stats if param is not nullptr struct StatsFuture { diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 2af888c02655b2953d4804ed2462103207130cad..0d54c89563210e02dee44538f3164ff23e89943c 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -558,7 +558,8 @@ MaceEngine::~MaceEngine() { }; bool MaceEngine::Run(const float *input, const std::vector &input_shape, - float *output) { + float *output, + RunMetadata *run_metadata) { MACE_CHECK(output != nullptr, "output ptr cannot be NULL"); Tensor *input_tensor = ws_->GetTensor("mace_input_node:0"); Tensor *output_tensor = ws_->GetTensor("mace_output_node:0"); @@ -571,7 +572,7 @@ bool MaceEngine::Run(const float *input, if (device_type_ == HEXAGON) { hexagon_controller_->ExecuteGraph(*input_tensor, output_tensor); } else { - if (!net_->Run()) { + if (!net_->Run(run_metadata)) { LOG(FATAL) << "Net run failed"; } } diff --git a/mace/core/net.cc b/mace/core/net.cc index 9afb008ff03b4e4980f198d1f8bd32ebb0ccb42f..aeafcc203667b28e0d00e85eddae84c6e7db3f0d 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -3,7 +3,6 @@ // #include "mace/core/net.h" -#include "mace/core/workspace.h" #include "mace/utils/utils.h" #include "mace/utils/memory_logging.h" diff --git a/mace/core/public/mace.h b/mace/core/public/mace.h index f68b60fde36d6c52d647ed8198e05151d02757c6..e7036ffe00cbff3ff0eb20a7617a036a2e36d33f 100644 --- a/mace/core/public/mace.h +++ b/mace/core/public/mace.h @@ -334,6 +334,22 @@ class NetDef { uint32_t has_bits_; }; +struct CallStats { + int64_t start_micros; + int64_t end_micros; +}; + +struct OperatorStats { + std::string operator_name; + std::string type; + CallStats stats; +}; + +struct RunMetadata { + std::vector op_stats; +}; + + class Workspace; class NetBase; class OperatorRegistry; @@ -346,7 +362,8 @@ class MaceEngine { ~MaceEngine(); bool Run(const float *input, const std::vector &input_shape, - float *output); + float *output, + RunMetadata *run_metadata = nullptr); MaceEngine(const MaceEngine &) = delete; MaceEngine &operator=(const MaceEngine &) = delete; diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index 5b7ccdd8a7f24aec1247fe0a60b22c1b915f37eb..f0041180ed1570a0b452e63f8bb6fbc124161f92 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -8,11 +8,9 @@ #include #include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/utils/logging.h" +#include "mace/core/public/mace.h" #include "mace/utils/tuner.h" -#include - namespace mace { namespace {