diff --git a/mace_run.cc b/mace_run.cc index 87b7b5b4ca9ffdcc25b53fa4a053538bcc45f509..5576b1a64891e4d2585a9ae364173a99646d618a 100644 --- a/mace_run.cc +++ b/mace_run.cc @@ -11,6 +11,7 @@ * --output_shape=1,224,224,2 \ * --input_file=input_data \ * --output_file=mace.out \ + * --model_data_file=model_data.data \ * --device=OPENCL */ #include @@ -31,7 +32,11 @@ using namespace mace; namespace mace { namespace MACE_MODEL_TAG { -extern NetDef CreateNet(); +extern unsigned char *LoadModelData(const char *model_data_file); + +extern void UnloadModelData(unsigned char *model_data); + +extern NetDef CreateNet(const unsigned char *model_data); extern const std::string ModelChecksum(); @@ -133,6 +138,7 @@ DEFINE_string(input_shape, "1,224,224,3", "input shape, separated by comma"); DEFINE_string(output_shape, "1,224,224,2", "output shape, separated by comma"); DEFINE_string(input_file, "", "input file name"); DEFINE_string(output_file, "", "output file name"); +DEFINE_string(model_data_file, "", "model data file name, used when EMBED_MODEL_DATA set to 0"); DEFINE_string(device, "OPENCL", "CPU/NEON/OPENCL/HEXAGON"); DEFINE_int32(round, 1, "round"); DEFINE_int32(malloc_check_cycle, -1, "malloc debug check cycle, -1 to disable"); @@ -148,6 +154,7 @@ int main(int argc, char **argv) { << "output_shape: " << FLAGS_output_shape << std::endl << "input_file: " << FLAGS_input_file << std::endl << "output_file: " << FLAGS_output_file << std::endl + << "model_data_file: " << FLAGS_model_data_file << std::endl << "device: " << FLAGS_device << std::endl << "round: " << FLAGS_round << std::endl; @@ -158,7 +165,8 @@ int main(int argc, char **argv) { // load model int64_t t0 = NowMicros(); - NetDef net_def = mace::MACE_MODEL_TAG::CreateNet(); + unsigned char *model_data = mace::MACE_MODEL_TAG::LoadModelData(FLAGS_model_data_file.c_str()); + NetDef net_def = mace::MACE_MODEL_TAG::CreateNet(model_data); int64_t t1 = NowMicros(); std::cout << "CreateNetDef duration: " << t1 - t0 << " us" << std::endl; int64_t init_micros = t1 - t0; @@ -187,6 +195,9 @@ int main(int argc, char **argv) { std::cout << "Run init" << std::endl; t0 = NowMicros(); mace::MaceEngine engine(&net_def, device_type); + if (device_type == DeviceType::OPENCL || device_type == DeviceType::HEXAGON) { + mace::MACE_MODEL_TAG::UnloadModelData(model_data); + } t1 = NowMicros(); init_micros += t1 - t0; std::cout << "Net init duration: " << t1 - t0 << " us" << std::endl;