提交 f6f79b83 编写于 作者: W wuchenghui

fix benchmark & support global_avg_pool convert

上级 a9a76d57
...@@ -16,7 +16,11 @@ ...@@ -16,7 +16,11 @@
namespace mace { namespace mace {
namespace MACE_MODEL_TAG { namespace MACE_MODEL_TAG {
extern NetDef CreateNet(); extern const unsigned char *LoadModelData(const char *model_data_file);
extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const unsigned char *model_data);
extern const std::string ModelChecksum(); extern const std::string ModelChecksum();
...@@ -164,6 +168,9 @@ DEFINE_bool(show_type, true, "whether to list stats by op type"); ...@@ -164,6 +168,9 @@ DEFINE_bool(show_type, true, "whether to list stats by op type");
DEFINE_bool(show_summary, true, "whether to show a summary of the stats"); DEFINE_bool(show_summary, true, "whether to show a summary of the stats");
DEFINE_bool(show_flops, true, "whether to estimate the model's FLOPs"); DEFINE_bool(show_flops, true, "whether to estimate the model's FLOPs");
DEFINE_int32(warmup_runs, 1, "how many runs to initialize model"); DEFINE_int32(warmup_runs, 1, "how many runs to initialize model");
DEFINE_string(model_data_file,
"",
"model data file name, used when EMBED_MODEL_DATA set to 0");
int Main(int argc, char **argv) { int Main(int argc, char **argv) {
gflags::SetUsageMessage("some usage message"); gflags::SetUsageMessage("some usage message");
...@@ -212,7 +219,9 @@ int Main(int argc, char **argv) { ...@@ -212,7 +219,9 @@ int Main(int argc, char **argv) {
device_type = OPENCL; device_type = OPENCL;
} }
NetDef net_def = mace::MACE_MODEL_TAG::CreateNet(); const unsigned char *model_data =
mace::MACE_MODEL_TAG::LoadModelData(FLAGS_model_data_file.c_str());
NetDef net_def = mace::MACE_MODEL_TAG::CreateNet(model_data);
int64_t input_size = std::accumulate(input_shape.begin(), int64_t input_size = std::accumulate(input_shape.begin(),
input_shape.end(), 1, std::multiplies<int64_t>()); input_shape.end(), 1, std::multiplies<int64_t>());
...@@ -235,6 +244,9 @@ int Main(int argc, char **argv) { ...@@ -235,6 +244,9 @@ int Main(int argc, char **argv) {
// Init model // Init model
std::cout << "Run init" << std::endl; std::cout << "Run init" << std::endl;
mace::MaceEngine engine(&net_def, device_type); mace::MaceEngine engine(&net_def, device_type);
if (device_type == DeviceType::OPENCL || device_type == DeviceType::HEXAGON) {
mace::MACE_MODEL_TAG::UnloadModelData(model_data);
}
std::cout << "Warm up" << std::endl; std::cout << "Warm up" << std::endl;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册