提交 d70c9780 编写于 作者: L liuqi

Move RunMetadata to public api for benchmarking model.

上级 3f803f84
...@@ -11,20 +11,7 @@ ...@@ -11,20 +11,7 @@
namespace mace { namespace mace {
struct CallStats { class CallStats;
int64_t start_micros;
int64_t end_micros;
};
struct OperatorStats {
std::string operator_name;
std::string type;
CallStats stats;
};
struct RunMetadata {
std::vector<OperatorStats> op_stats;
};
// Wait the call to finish and get the stats if param is not nullptr // Wait the call to finish and get the stats if param is not nullptr
struct StatsFuture { struct StatsFuture {
......
...@@ -558,7 +558,8 @@ MaceEngine::~MaceEngine() { ...@@ -558,7 +558,8 @@ MaceEngine::~MaceEngine() {
}; };
bool MaceEngine::Run(const float *input, bool MaceEngine::Run(const float *input,
const std::vector<index_t> &input_shape, const std::vector<index_t> &input_shape,
float *output) { float *output,
RunMetadata *run_metadata) {
MACE_CHECK(output != nullptr, "output ptr cannot be NULL"); MACE_CHECK(output != nullptr, "output ptr cannot be NULL");
Tensor *input_tensor = ws_->GetTensor("mace_input_node:0"); Tensor *input_tensor = ws_->GetTensor("mace_input_node:0");
Tensor *output_tensor = ws_->GetTensor("mace_output_node:0"); Tensor *output_tensor = ws_->GetTensor("mace_output_node:0");
...@@ -571,7 +572,7 @@ bool MaceEngine::Run(const float *input, ...@@ -571,7 +572,7 @@ bool MaceEngine::Run(const float *input,
if (device_type_ == HEXAGON) { if (device_type_ == HEXAGON) {
hexagon_controller_->ExecuteGraph(*input_tensor, output_tensor); hexagon_controller_->ExecuteGraph(*input_tensor, output_tensor);
} else { } else {
if (!net_->Run()) { if (!net_->Run(run_metadata)) {
LOG(FATAL) << "Net run failed"; LOG(FATAL) << "Net run failed";
} }
} }
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
// //
#include "mace/core/net.h" #include "mace/core/net.h"
#include "mace/core/workspace.h"
#include "mace/utils/utils.h" #include "mace/utils/utils.h"
#include "mace/utils/memory_logging.h" #include "mace/utils/memory_logging.h"
......
...@@ -334,6 +334,22 @@ class NetDef { ...@@ -334,6 +334,22 @@ class NetDef {
uint32_t has_bits_; uint32_t has_bits_;
}; };
struct CallStats {
int64_t start_micros;
int64_t end_micros;
};
struct OperatorStats {
std::string operator_name;
std::string type;
CallStats stats;
};
struct RunMetadata {
std::vector<OperatorStats> op_stats;
};
class Workspace; class Workspace;
class NetBase; class NetBase;
class OperatorRegistry; class OperatorRegistry;
...@@ -346,7 +362,8 @@ class MaceEngine { ...@@ -346,7 +362,8 @@ class MaceEngine {
~MaceEngine(); ~MaceEngine();
bool Run(const float *input, bool Run(const float *input,
const std::vector<int64_t> &input_shape, const std::vector<int64_t> &input_shape,
float *output); float *output,
RunMetadata *run_metadata = nullptr);
MaceEngine(const MaceEngine &) = delete; MaceEngine(const MaceEngine &) = delete;
MaceEngine &operator=(const MaceEngine &) = delete; MaceEngine &operator=(const MaceEngine &) = delete;
......
...@@ -8,11 +8,9 @@ ...@@ -8,11 +8,9 @@
#include <mutex> #include <mutex>
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/utils/logging.h" #include "mace/core/public/mace.h"
#include "mace/utils/tuner.h" #include "mace/utils/tuner.h"
#include <CL/opencl.h>
namespace mace { namespace mace {
namespace { namespace {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册