From d7713c1328c196de4eb0d61ef77f1f072aa57905 Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Fri, 28 Dec 2018 16:54:18 +0800 Subject: [PATCH] Update CreateMaceEngineFromProto API --- docs/user_guide/basic_usage.rst | 6 ++++-- mace/benchmark/benchmark_model.cc | 21 ++++++++++++--------- mace/examples/cli/example.cc | 14 ++++++++++---- mace/libmace/mace.cc | 31 ++++++++++++++++++++++++++++++- mace/public/mace.h | 31 ++++++++++++++++++++++++++++++- mace/tools/validation/mace_run.cc | 29 +++++++++++++++++++++-------- 6 files changed, 107 insertions(+), 25 deletions(-) diff --git a/docs/user_guide/basic_usage.rst b/docs/user_guide/basic_usage.rst index c9bde6fc..051dc709 100644 --- a/docs/user_guide/basic_usage.rst +++ b/docs/user_guide/basic_usage.rst @@ -344,8 +344,10 @@ Please refer to \ ``mace/examples/example.cc``\ for full usage. The following li // Create Engine from model file create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - model_data_file.c_str(), + CreateMaceEngineFromProto(model_graph_proto, + model_graph_proto_size, + model_weights_data, + model_weights_data_size, input_names, output_names, device_type, diff --git a/mace/benchmark/benchmark_model.cc b/mace/benchmark/benchmark_model.cc index bcb9ae75..961f2be3 100644 --- a/mace/benchmark/benchmark_model.cc +++ b/mace/benchmark/benchmark_model.cc @@ -277,15 +277,16 @@ int Main(int argc, char **argv) { std::shared_ptr engine; MaceStatus create_engine_status; // Create Engine - const char *model_data_file_ptr = - FLAGS_model_data_file.empty() ? nullptr : FLAGS_model_data_file.c_str(); + std::vector model_graph_data; + if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { + LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; + } - std::vector model_pb_data; - if (FLAGS_model_file != "") { - if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) { - LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; - } + std::vector model_weights_data; + if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { + LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file; } + #ifdef MODEL_GRAPH_FORMAT_CODE create_engine_status = CreateMaceEngineFromCode(FLAGS_model_name, @@ -296,8 +297,10 @@ int Main(int argc, char **argv) { &engine); #else create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - model_data_file_ptr, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, diff --git a/mace/examples/cli/example.cc b/mace/examples/cli/example.cc index 73adbb75..92cfeb7c 100644 --- a/mace/examples/cli/example.cc +++ b/mace/examples/cli/example.cc @@ -213,13 +213,19 @@ bool RunModel(const std::vector &input_names, config, &engine); #else - std::vector model_pb_data; - if (!ReadBinaryFile(&model_pb_data, FLAGS_model_file)) { + std::vector model_graph_data; + if (!ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { std::cerr << "Failed to read file: " << FLAGS_model_file << std::endl; } + std::vector model_weights_data; + if (!ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { + std::cerr << "Failed to read file: " << FLAGS_model_data_file << std::endl; + } create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - FLAGS_model_data_file, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc index 047cdf8e..7ca217eb 100644 --- a/mace/libmace/mace.cc +++ b/mace/libmace/mace.cc @@ -520,7 +520,7 @@ MaceStatus MaceEngine::Impl::Init( MaceEngine::Impl::~Impl() { LOG(INFO) << "Destroying MaceEngine"; - if (device_type_ == DeviceType::CPU && model_data_ != nullptr) { + if (model_data_ != nullptr) { UnloadModelData(model_data_, model_data_size_); } #ifdef MACE_ENABLE_HEXAGON @@ -729,6 +729,34 @@ MaceStatus MaceEngine::Run(const std::map &inputs, return impl_->Run(inputs, outputs, nullptr); } +MaceStatus CreateMaceEngineFromProto( + const unsigned char *model_graph_proto, + const size_t model_graph_proto_size, + const unsigned char *model_weights_data, + const size_t model_weights_data_size, + const std::vector &input_nodes, + const std::vector &output_nodes, + const MaceEngineConfig &config, + std::shared_ptr *engine) { + // TODO(heliangliang) Add buffer range checking + MACE_UNUSED(model_weights_data_size); + LOG(INFO) << "Create MaceEngine from model graph proto and weights data"; + + if (engine == nullptr) { + return MaceStatus::MACE_INVALID_ARGS; + } + + auto net_def = std::make_shared(); + net_def->ParseFromArray(model_graph_proto, model_graph_proto_size); + + engine->reset(new mace::MaceEngine(config)); + MaceStatus status = (*engine)->Init( + net_def.get(), input_nodes, output_nodes, model_weights_data); + + return status; +} + +// Deprecated, will be removed in future version. MaceStatus CreateMaceEngineFromProto( const std::vector &model_pb, const std::string &model_data_file, @@ -737,6 +765,7 @@ MaceStatus CreateMaceEngineFromProto( const MaceEngineConfig &config, std::shared_ptr *engine) { LOG(INFO) << "Create MaceEngine from model pb"; + LOG(WARNING) << "Function deprecated, please change to the new API"; // load model if (engine == nullptr) { return MaceStatus::MACE_INVALID_ARGS; diff --git a/mace/public/mace.h b/mace/public/mace.h index a7b2a13e..3d81677e 100644 --- a/mace/public/mace.h +++ b/mace/public/mace.h @@ -318,7 +318,36 @@ class MACE_API MaceEngine { MaceEngine &operator=(const MaceEngine &) = delete; }; +/// \brief Create MaceEngine from model graph proto and weights data +/// +/// Create MaceEngine object +/// +/// \param model_graph_proto[in]: the content of model graph proto +/// \param model_graph_proto_size[in]: the size of model graph proto +/// \param model_weights_data[in]: the content of model weights data, the +/// returned engine will refer to this buffer +/// if CPU runtime is used. In this case, the +/// buffer should keep alive. +/// \param model_weights_data_size[in]: the size of model weights data +/// \param input_nodes[in]: the array of input nodes' name +/// \param output_nodes[in]: the array of output nodes' name +/// \param config[in]: configurations for MaceEngine. +/// \param engine[out]: output MaceEngine object +/// \return MaceStatus::MACE_SUCCESS for success, +/// MaceStatus::MACE_INVALID_ARGS for wrong arguments, +/// MaceStatus::MACE_OUT_OF_RESOURCES for resources is out of range. +MACE_API MaceStatus CreateMaceEngineFromProto( + const unsigned char *model_graph_proto, + const size_t model_graph_proto_size, + const unsigned char *model_weights_data, + const size_t model_weights_data_size, + const std::vector &input_nodes, + const std::vector &output_nodes, + const MaceEngineConfig &config, + std::shared_ptr *engine); + /// \brief Create MaceEngine from files (model file + data file) +/// Deprecated, will be removed in future version /// /// Create MaceEngine object /// @@ -337,7 +366,7 @@ MACE_API MaceStatus CreateMaceEngineFromProto( const std::vector &input_nodes, const std::vector &output_nodes, const MaceEngineConfig &config, - std::shared_ptr *engine); + std::shared_ptr *engine) __attribute__((deprecated)); } // namespace mace diff --git a/mace/tools/validation/mace_run.cc b/mace/tools/validation/mace_run.cc index 08ffdebe..0403a53c 100644 --- a/mace/tools/validation/mace_run.cc +++ b/mace/tools/validation/mace_run.cc @@ -240,13 +240,20 @@ bool RunModel(const std::string &model_name, } #endif // MACE_ENABLE_OPENCL - std::vector model_pb_data; + std::vector model_graph_data; if (FLAGS_model_file != "") { - if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) { + if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) { LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; } } + std::vector model_weights_data; + if (FLAGS_model_data_file != "") { + if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) { + LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file; + } + } + std::shared_ptr engine; MaceStatus create_engine_status; @@ -265,8 +272,10 @@ bool RunModel(const std::string &model_name, #else (void)(model_name); create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - FLAGS_model_data_file, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, @@ -338,8 +347,10 @@ bool RunModel(const std::string &model_name, &engine); #else create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - FLAGS_model_data_file, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, @@ -378,8 +389,10 @@ bool RunModel(const std::string &model_name, &engine); #else create_engine_status = - CreateMaceEngineFromProto(model_pb_data, - FLAGS_model_data_file, + CreateMaceEngineFromProto(model_graph_data.data(), + model_graph_data.size(), + model_weights_data.data(), + model_weights_data.size(), input_names, output_names, config, -- GitLab