diff --git a/docs/user_guide/basic_usage.rst b/docs/user_guide/basic_usage.rst index c9bde6fccabe0ed91f971452b22b46d4f0862366..051dc7094496a11174559aa4f57916d8d6039cd1 100644 --- a/docs/user_guide/basic_usage.rst +++ b/docs/user_guide/basic_usage.rst @@ -344,8 +344,10 @@ Please refer to \ ``mace/examples/example.cc``\ for full usage. The following li // Create Engine from model file create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - model_data_file.c_str(), + CreateMaceEngineFromProto(model_graph_proto, + model_graph_proto_size, + model_weights_data, + model_weights_data_size, input_names, output_names, device_type, diff --git a/mace/benchmark/benchmark_model.cc b/mace/benchmark/benchmark_model.cc index 82849d6761d6583b2dbc8e685a7c55cdae2011a9..9897015ad41e77871f7851da88c6369ef0baaffd 100644 --- a/mace/benchmark/benchmark_model.cc +++ b/mace/benchmark/benchmark_model.cc @@ -277,15 +277,16 @@ int Main(int argc, char **argv) { std::shared_ptr engine; MaceStatus create_engine_status; // Create Engine - const char *model_data_file_ptr = - FLAGS_model_data_file.empty() ? nullptr : FLAGS_model_data_file.c_str(); + std::vector model_graph_data; + if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { + LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; + } - std::vector model_pb_data; - if (FLAGS_model_file != "") { - if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) { - LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; - } + std::vector model_weights_data; + if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { + LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file; } + #ifdef MODEL_GRAPH_FORMAT_CODE create_engine_status = CreateMaceEngineFromCode(FLAGS_model_name, @@ -296,8 +297,10 @@ int Main(int argc, char **argv) { &engine); #else create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - model_data_file_ptr, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, diff --git a/mace/examples/cli/example.cc b/mace/examples/cli/example.cc index 181c7c8d15342ff67140beec1817fc29d43778a9..88f822acf81160b638f6f9d14a0ee2da78b4d35c 100644 --- a/mace/examples/cli/example.cc +++ b/mace/examples/cli/example.cc @@ -225,13 +225,19 @@ bool RunModel(const std::vector &input_names, config, &engine); #else - std::vector model_pb_data; - if (!ReadBinaryFile(&model_pb_data, FLAGS_model_file)) { + std::vector model_graph_data; + if (!ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { std::cerr << "Failed to read file: " << FLAGS_model_file << std::endl; } + std::vector model_weights_data; + if (!ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { + std::cerr << "Failed to read file: " << FLAGS_model_data_file << std::endl; + } create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - FLAGS_model_data_file, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc index 73c51fff58b5d1e2fa68fda5afba1c2267af6521..e9256b9fdd88813fbcadd8ce211c328e6fceb1c2 100644 --- a/mace/libmace/mace.cc +++ b/mace/libmace/mace.cc @@ -562,7 +562,7 @@ MaceStatus MaceEngine::Impl::Init( MaceEngine::Impl::~Impl() { LOG(INFO) << "Destroying MaceEngine"; - if (device_type_ == DeviceType::CPU && model_data_ != nullptr) { + if (model_data_ != nullptr) { UnloadModelData(model_data_, model_data_size_); } #ifdef MACE_ENABLE_HEXAGON @@ -771,6 +771,34 @@ MaceStatus MaceEngine::Run(const std::map &inputs, return impl_->Run(inputs, outputs, nullptr); } +MaceStatus CreateMaceEngineFromProto( + const unsigned char *model_graph_proto, + const size_t model_graph_proto_size, + const unsigned char *model_weights_data, + const size_t model_weights_data_size, + const std::vector &input_nodes, + const std::vector &output_nodes, + const MaceEngineConfig &config, + std::shared_ptr *engine) { + // TODO(heliangliang) Add buffer range checking + MACE_UNUSED(model_weights_data_size); + LOG(INFO) << "Create MaceEngine from model graph proto and weights data"; + + if (engine == nullptr) { + return MaceStatus::MACE_INVALID_ARGS; + } + + auto net_def = std::make_shared(); + net_def->ParseFromArray(model_graph_proto, model_graph_proto_size); + + engine->reset(new mace::MaceEngine(config)); + MaceStatus status = (*engine)->Init( + net_def.get(), input_nodes, output_nodes, model_weights_data); + + return status; +} + +// Deprecated, will be removed in future version. MaceStatus CreateMaceEngineFromProto( const std::vector &model_pb, const std::string &model_data_file, @@ -779,6 +807,7 @@ MaceStatus CreateMaceEngineFromProto( const MaceEngineConfig &config, std::shared_ptr *engine) { LOG(INFO) << "Create MaceEngine from model pb"; + LOG(WARNING) << "Function deprecated, please change to the new API"; // load model if (engine == nullptr) { return MaceStatus::MACE_INVALID_ARGS; diff --git a/mace/public/mace.h b/mace/public/mace.h index 854df1377abc77777cf72d72768bf821c779c93e..912867f74b60b613439a7f545e0b2d2fab335454 100644 --- a/mace/public/mace.h +++ b/mace/public/mace.h @@ -341,7 +341,36 @@ class MACE_API MaceEngine { MaceEngine &operator=(const MaceEngine &) = delete; }; +/// \brief Create MaceEngine from model graph proto and weights data +/// +/// Create MaceEngine object +/// +/// \param model_graph_proto[in]: the content of model graph proto +/// \param model_graph_proto_size[in]: the size of model graph proto +/// \param model_weights_data[in]: the content of model weights data, the +/// returned engine will refer to this buffer +/// if CPU runtime is used. In this case, the +/// buffer should keep alive. +/// \param model_weights_data_size[in]: the size of model weights data +/// \param input_nodes[in]: the array of input nodes' name +/// \param output_nodes[in]: the array of output nodes' name +/// \param config[in]: configurations for MaceEngine. +/// \param engine[out]: output MaceEngine object +/// \return MaceStatus::MACE_SUCCESS for success, +/// MaceStatus::MACE_INVALID_ARGS for wrong arguments, +/// MaceStatus::MACE_OUT_OF_RESOURCES for resources is out of range. +MACE_API MaceStatus CreateMaceEngineFromProto( + const unsigned char *model_graph_proto, + const size_t model_graph_proto_size, + const unsigned char *model_weights_data, + const size_t model_weights_data_size, + const std::vector &input_nodes, + const std::vector &output_nodes, + const MaceEngineConfig &config, + std::shared_ptr *engine); + /// \brief Create MaceEngine from files (model file + data file) +/// Deprecated, will be removed in future version /// /// Create MaceEngine object /// @@ -360,7 +389,7 @@ MACE_API MaceStatus CreateMaceEngineFromProto( const std::vector &input_nodes, const std::vector &output_nodes, const MaceEngineConfig &config, - std::shared_ptr *engine); + std::shared_ptr *engine) __attribute__((deprecated)); } // namespace mace diff --git a/mace/tools/validation/mace_run.cc b/mace/tools/validation/mace_run.cc index cfd3c4a46b5021193ca3c284cafeed1312cb5a96..53e98f118f9e6274c6d485a48eaefd334062c73e 100644 --- a/mace/tools/validation/mace_run.cc +++ b/mace/tools/validation/mace_run.cc @@ -240,13 +240,20 @@ bool RunModel(const std::string &model_name, } #endif // MACE_ENABLE_OPENCL - std::vector model_pb_data; + std::vector model_graph_data; if (FLAGS_model_file != "") { - if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) { + if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; } } + std::vector model_weights_data; + if (FLAGS_model_data_file != "") { + if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { + LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file; + } + } + std::shared_ptr engine; MaceStatus create_engine_status; @@ -265,8 +272,10 @@ bool RunModel(const std::string &model_name, #else (void)(model_name); create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - FLAGS_model_data_file, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, @@ -338,8 +347,10 @@ bool RunModel(const std::string &model_name, &engine); #else create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - FLAGS_model_data_file, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, @@ -378,8 +389,10 @@ bool RunModel(const std::string &model_name, &engine); #else create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - FLAGS_model_data_file, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config,