From f6f79b832a56e17ad269d3e45e40333a2a41c268 Mon Sep 17 00:00:00 2001 From: wuchenghui Date: Mon, 5 Mar 2018 17:18:37 +0800 Subject: [PATCH] fix benchmark & support global_avg_pool convert --- benchmark_model.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/benchmark_model.cc b/benchmark_model.cc index 316de827..45214803 100644 --- a/benchmark_model.cc +++ b/benchmark_model.cc @@ -16,7 +16,11 @@ namespace mace { namespace MACE_MODEL_TAG { -extern NetDef CreateNet(); +extern const unsigned char *LoadModelData(const char *model_data_file); + +extern void UnloadModelData(const unsigned char *model_data); + +extern NetDef CreateNet(const unsigned char *model_data); extern const std::string ModelChecksum(); @@ -164,6 +168,9 @@ DEFINE_bool(show_type, true, "whether to list stats by op type"); DEFINE_bool(show_summary, true, "whether to show a summary of the stats"); DEFINE_bool(show_flops, true, "whether to estimate the model's FLOPs"); DEFINE_int32(warmup_runs, 1, "how many runs to initialize model"); +DEFINE_string(model_data_file, + "", + "model data file name, used when EMBED_MODEL_DATA set to 0"); int Main(int argc, char **argv) { gflags::SetUsageMessage("some usage message"); @@ -212,7 +219,9 @@ int Main(int argc, char **argv) { device_type = OPENCL; } - NetDef net_def = mace::MACE_MODEL_TAG::CreateNet(); + const unsigned char *model_data = + mace::MACE_MODEL_TAG::LoadModelData(FLAGS_model_data_file.c_str()); + NetDef net_def = mace::MACE_MODEL_TAG::CreateNet(model_data); int64_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); @@ -235,6 +244,9 @@ int Main(int argc, char **argv) { // Init model std::cout << "Run init" << std::endl; mace::MaceEngine engine(&net_def, device_type); + if (device_type == DeviceType::OPENCL || device_type == DeviceType::HEXAGON) { + mace::MACE_MODEL_TAG::UnloadModelData(model_data); + } std::cout << "Warm up" << std::endl; -- GitLab