diff --git a/mace/benchmark/benchmark_model.cc b/mace/benchmark/benchmark_model.cc index 9897015ad41e77871f7851da88c6369ef0baaffd..445b7dc382298e455d5e3eb77b9575072cb612dd 100644 --- a/mace/benchmark/benchmark_model.cc +++ b/mace/benchmark/benchmark_model.cc @@ -278,19 +278,24 @@ int Main(int argc, char **argv) { MaceStatus create_engine_status; // Create Engine std::vector model_graph_data; - if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { - LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; + if (FLAGS_model_file != "") { + if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { + LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; + } } std::vector model_weights_data; - if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { - LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file; + if (FLAGS_model_data_file != "") { + if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { + LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file; + } } #ifdef MODEL_GRAPH_FORMAT_CODE create_engine_status = CreateMaceEngineFromCode(FLAGS_model_name, - model_data_file_ptr, + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, diff --git a/mace/examples/cli/example.cc b/mace/examples/cli/example.cc index 88f822acf81160b638f6f9d14a0ee2da78b4d35c..3a984f96aaf79b2b88f059b0315621b39a1f7002 100644 --- a/mace/examples/cli/example.cc +++ b/mace/examples/cli/example.cc @@ -59,6 +59,29 @@ std::vector Split(const std::string &str, char delims) { } // namespace str_util +namespace { +bool ReadBinaryFile(std::vector *data, + const std::string &filename) { + std::ifstream ifs(filename, std::ios::in | std::ios::binary); + if (!ifs.is_open()) { + return false; + } + ifs.seekg(0, ifs.end); + size_t length = ifs.tellg(); + ifs.seekg(0, ifs.beg); + + data->reserve(length); + data->insert(data->begin(), std::istreambuf_iterator(ifs), + std::istreambuf_iterator()); + if (ifs.fail()) { + return false; + } + ifs.close(); + + return true; +} +} // namespace + void ParseShape(const std::string &str, std::vector *shape) { std::string tmp = str; while (!tmp.empty()) { @@ -142,30 +165,6 @@ DEFINE_int32(gpu_priority_hint, 1, "0:DEFAULT/1:LOW/2:NORMAL/3:HIGH"); DEFINE_int32(omp_num_threads, -1, "num of openmp threads"); DEFINE_int32(cpu_affinity_policy, 1, "0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY"); -#ifndef MODEL_GRAPH_FORMAT_CODE -namespace { -bool ReadBinaryFile(std::vector *data, - const std::string &filename) { - std::ifstream ifs(filename, std::ios::in | std::ios::binary); - if (!ifs.is_open()) { - return false; - } - ifs.seekg(0, ifs.end); - size_t length = ifs.tellg(); - ifs.seekg(0, ifs.beg); - - data->reserve(length); - data->insert(data->begin(), std::istreambuf_iterator(ifs), - std::istreambuf_iterator()); - if (ifs.fail()) { - return false; - } - ifs.close(); - - return true; -} -} // namespace -#endif bool RunModel(const std::vector &input_names, const std::vector> &input_shapes, @@ -212,6 +211,16 @@ bool RunModel(const std::vector &input_names, // Create Engine std::shared_ptr engine; MaceStatus create_engine_status; + + std::vector model_graph_data; + if (!ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { + std::cerr << "Failed to read file: " << FLAGS_model_file << std::endl; + } + std::vector model_weights_data; + if (!ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { + std::cerr << "Failed to read file: " << FLAGS_model_data_file << std::endl; + } + // Only choose one of the two type based on the `model_graph_format` // in model deployment file(.yml). #ifdef MODEL_GRAPH_FORMAT_CODE @@ -219,20 +228,13 @@ bool RunModel(const std::vector &input_names, // to model_data_file parameter. create_engine_status = CreateMaceEngineFromCode(FLAGS_model_name, - FLAGS_model_data_file, + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, &engine); #else - std::vector model_graph_data; - if (!ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { - std::cerr << "Failed to read file: " << FLAGS_model_file << std::endl; - } - std::vector model_weights_data; - if (!ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { - std::cerr << "Failed to read file: " << FLAGS_model_data_file << std::endl; - } create_engine_status = CreateMaceEngineFromProto(model_graph_data.data(), model_graph_data.size(), diff --git a/mace/python/tools/mace_engine_factory.h.jinja2 b/mace/python/tools/mace_engine_factory.h.jinja2 index d540c49d375f4a913aaa93f8e8e7048e9b9cb7c2..e1502389118955c0d67c23d160de67e080c2a52d 100644 --- a/mace/python/tools/mace_engine_factory.h.jinja2 +++ b/mace/python/tools/mace_engine_factory.h.jinja2 @@ -62,7 +62,7 @@ std::map model_name_map { /// \param engine[out]: output MaceEngine object /// \return MaceStatus::MACE_SUCCESS for success, MACE_INVALID_ARGS for wrong arguments, /// MACE_OUT_OF_RESOURCES for resources is out of range. -MaceStatus CreateMaceEngineFromCode( +__attribute__((deprecated)) MaceStatus CreateMaceEngineFromCode( const std::string &model_name, const std::string &model_data_file, const std::vector &input_nodes, @@ -101,5 +101,48 @@ MaceStatus CreateMaceEngineFromCode( return status; } +MaceStatus CreateMaceEngineFromCode( + const std::string &model_name, + const unsigned char *model_weights_data, + const size_t model_weights_data_size, + const std::vector &input_nodes, + const std::vector &output_nodes, + const MaceEngineConfig &config, + std::shared_ptr *engine) { + // load model + if (engine == nullptr) { + return MaceStatus::MACE_INVALID_ARGS; + } + std::shared_ptr net_def; +{% if embed_model_data %} + const unsigned char * model_data; + (void)model_weights_data; +{% endif %} + // TODO(yejianwu) Add buffer range checking + (void)model_weights_data_size; + + MaceStatus status = MaceStatus::MACE_SUCCESS; + switch (model_name_map[model_name]) { +{% for i in range(model_tags |length) %} + case {{ i }}: + net_def = mace::{{model_tags[i]}}::CreateNet(); + engine->reset(new mace::MaceEngine(config)); +{% if embed_model_data %} + model_data = mace::{{model_tags[i]}}::LoadModelData(); + status = (*engine)->Init(net_def.get(), input_nodes, output_nodes, + model_data); +{% else %} + status = (*engine)->Init(net_def.get(), input_nodes, output_nodes, + model_weights_data); +{% endif %} + break; +{% endfor %} + default: + status = MaceStatus::MACE_INVALID_ARGS; + } + + return status; +} + } // namespace mace #endif // MACE_CODEGEN_ENGINE_MACE_ENGINE_FACTORY_H_ diff --git a/mace/tools/validation/mace_run.cc b/mace/tools/validation/mace_run.cc index 53e98f118f9e6274c6d485a48eaefd334062c73e..3e1c88741ab27ef6c2f3ed5144765448dee3ea00 100644 --- a/mace/tools/validation/mace_run.cc +++ b/mace/tools/validation/mace_run.cc @@ -264,7 +264,8 @@ bool RunModel(const std::string &model_name, #ifdef MODEL_GRAPH_FORMAT_CODE create_engine_status = CreateMaceEngineFromCode(model_name, - FLAGS_model_data_file, + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, @@ -340,7 +341,8 @@ bool RunModel(const std::string &model_name, #ifdef MODEL_GRAPH_FORMAT_CODE create_engine_status = CreateMaceEngineFromCode(model_name, - FLAGS_model_data_file, + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, @@ -382,7 +384,8 @@ bool RunModel(const std::string &model_name, #ifdef MODEL_GRAPH_FORMAT_CODE create_engine_status = CreateMaceEngineFromCode(model_name, - FLAGS_model_data_file, + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, diff --git a/tools/device.py b/tools/device.py index 117dc509179ffba3b6c07a6dd3296dbe7430a502..d90e1907fe4c85c7642bc70bd69e118a5e1eaa4f 100644 --- a/tools/device.py +++ b/tools/device.py @@ -192,6 +192,14 @@ class DeviceWrapper: if model_graph_format == ModelFormat.file: mace_model_path = layers_validate_file if layers_validate_file \ else "%s/%s.pb" % (mace_model_dir, model_tag) + + model_data_file = "" + if not embed_model_data: + if self.system == SystemType.host: + model_data_file = "%s/%s.data" % (mace_model_dir, model_tag) + else: + model_data_file = "%s/%s.data" % (self.data_dir, model_tag) + if self.system == SystemType.host: libmace_dynamic_lib_path = \ os.path.dirname(libmace_dynamic_library_path) @@ -214,8 +222,7 @@ class DeviceWrapper: output_file_name), "--input_dir=%s" % input_dir, "--output_dir=%s" % output_dir, - "--model_data_file=%s/%s.data" % (mace_model_dir, - model_tag), + "--model_data_file=%s" % model_data_file, "--device=%s" % device_type, "--round=%s" % running_round, "--restart_round=%s" % restart_round, @@ -229,7 +236,7 @@ class DeviceWrapper: stdout=subprocess.PIPE) out, err = p.communicate() self.stdout = err + out - six.print_(self.stdout) + six.print_(self.stdout.decode('UTF-8')) six.print_("Running finished!\n") elif self.system in [SystemType.android, SystemType.arm_linux]: self.rm(self.data_dir) @@ -304,7 +311,7 @@ class DeviceWrapper: "--output_file=%s/%s" % (self.data_dir, output_file_name), "--input_dir=%s" % input_dir, "--output_dir=%s" % output_dir, - "--model_data_file=%s/%s.data" % (self.data_dir, model_tag), + "--model_data_file=%s" % model_data_file, "--device=%s" % device_type, "--round=%s" % running_round, "--restart_round=%s" % restart_round, @@ -753,6 +760,14 @@ class DeviceWrapper: mace_model_path = '' if model_graph_format == ModelFormat.file: mace_model_path = '%s/%s.pb' % (mace_model_dir, model_tag) + + model_data_file = "" + if not embed_model_data: + if self.system == SystemType.host: + model_data_file = "%s/%s.data" % (mace_model_dir, model_tag) + else: + model_data_file = "%s/%s.data" % (self.data_dir, model_tag) + if abi == ABIType.host: libmace_dynamic_lib_dir_path = \ os.path.dirname(libmace_dynamic_library_path) @@ -768,8 +783,7 @@ class DeviceWrapper: '--input_shape=%s' % ':'.join(input_shapes), '--output_shape=%s' % ':'.join(output_shapes), '--input_file=%s/%s' % (model_output_dir, input_file_name), - '--model_data_file=%s/%s.data' % (mace_model_dir, - model_tag), + "--model_data_file=%s" % model_data_file, '--device=%s' % device_type, '--omp_num_threads=%s' % omp_num_threads, '--cpu_affinity_policy=%s' % cpu_affinity_policy, @@ -822,7 +836,7 @@ class DeviceWrapper: '--input_shape=%s' % ':'.join(input_shapes), '--output_shape=%s' % ':'.join(output_shapes), '--input_file=%s/%s' % (self.data_dir, input_file_name), - '--model_data_file=%s/%s.data' % (self.data_dir, model_tag), + "--model_data_file=%s" % model_data_file, '--device=%s' % device_type, '--omp_num_threads=%s' % omp_num_threads, '--cpu_affinity_policy=%s' % cpu_affinity_policy,