提交 d7713c13 编写于 作者: L Liangliang He

Update CreateMaceEngineFromProto API

上级 8a468a9f
......@@ -344,8 +344,10 @@ Please refer to \ ``mace/examples/example.cc``\ for full usage. The following li
// Create Engine from model file
create_engine_status =
CreateMaceEngineFromProto(model_pb_data,
model_data_file.c_str(),
CreateMaceEngineFromProto(model_graph_proto,
model_graph_proto_size,
model_weights_data,
model_weights_data_size,
input_names,
output_names,
device_type,
......
......@@ -277,15 +277,16 @@ int Main(int argc, char **argv) {
std::shared_ptr<mace::MaceEngine> engine;
MaceStatus create_engine_status;
// Create Engine
const char *model_data_file_ptr =
FLAGS_model_data_file.empty() ? nullptr : FLAGS_model_data_file.c_str();
std::vector<unsigned char> model_graph_data;
if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
}
std::vector<unsigned char> model_pb_data;
if (FLAGS_model_file != "") {
if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
}
std::vector<unsigned char> model_weights_data;
if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file;
}
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status =
CreateMaceEngineFromCode(FLAGS_model_name,
......@@ -296,8 +297,10 @@ int Main(int argc, char **argv) {
&engine);
#else
create_engine_status =
CreateMaceEngineFromProto(model_pb_data,
model_data_file_ptr,
CreateMaceEngineFromProto(model_graph_data.data(),
model_graph_data.size(),
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
......
......@@ -213,13 +213,19 @@ bool RunModel(const std::vector<std::string> &input_names,
config,
&engine);
#else
std::vector<unsigned char> model_pb_data;
if (!ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
std::vector<unsigned char> model_graph_data;
if (!ReadBinaryFile(&model_graph_data, FLAGS_model_file)) {
std::cerr << "Failed to read file: " << FLAGS_model_file << std::endl;
}
std::vector<unsigned char> model_weights_data;
if (!ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) {
std::cerr << "Failed to read file: " << FLAGS_model_data_file << std::endl;
}
create_engine_status =
CreateMaceEngineFromProto(model_pb_data,
FLAGS_model_data_file,
CreateMaceEngineFromProto(model_graph_data.data(),
model_graph_data.size(),
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
......
......@@ -520,7 +520,7 @@ MaceStatus MaceEngine::Impl::Init(
MaceEngine::Impl::~Impl() {
LOG(INFO) << "Destroying MaceEngine";
if (device_type_ == DeviceType::CPU && model_data_ != nullptr) {
if (model_data_ != nullptr) {
UnloadModelData(model_data_, model_data_size_);
}
#ifdef MACE_ENABLE_HEXAGON
......@@ -729,6 +729,34 @@ MaceStatus MaceEngine::Run(const std::map<std::string, MaceTensor> &inputs,
return impl_->Run(inputs, outputs, nullptr);
}
MaceStatus CreateMaceEngineFromProto(
const unsigned char *model_graph_proto,
const size_t model_graph_proto_size,
const unsigned char *model_weights_data,
const size_t model_weights_data_size,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const MaceEngineConfig &config,
std::shared_ptr<MaceEngine> *engine) {
// TODO(heliangliang) Add buffer range checking
MACE_UNUSED(model_weights_data_size);
LOG(INFO) << "Create MaceEngine from model graph proto and weights data";
if (engine == nullptr) {
return MaceStatus::MACE_INVALID_ARGS;
}
auto net_def = std::make_shared<NetDef>();
net_def->ParseFromArray(model_graph_proto, model_graph_proto_size);
engine->reset(new mace::MaceEngine(config));
MaceStatus status = (*engine)->Init(
net_def.get(), input_nodes, output_nodes, model_weights_data);
return status;
}
// Deprecated, will be removed in future version.
MaceStatus CreateMaceEngineFromProto(
const std::vector<unsigned char> &model_pb,
const std::string &model_data_file,
......@@ -737,6 +765,7 @@ MaceStatus CreateMaceEngineFromProto(
const MaceEngineConfig &config,
std::shared_ptr<MaceEngine> *engine) {
LOG(INFO) << "Create MaceEngine from model pb";
LOG(WARNING) << "Function deprecated, please change to the new API";
// load model
if (engine == nullptr) {
return MaceStatus::MACE_INVALID_ARGS;
......
......@@ -318,7 +318,36 @@ class MACE_API MaceEngine {
MaceEngine &operator=(const MaceEngine &) = delete;
};
/// \brief Create MaceEngine from model graph proto and weights data
///
/// Create MaceEngine object
///
/// \param model_graph_proto[in]: the content of model graph proto
/// \param model_graph_proto_size[in]: the size of model graph proto
/// \param model_weights_data[in]: the content of model weights data, the
/// returned engine will refer to this buffer
/// if CPU runtime is used. In this case, the
/// buffer should keep alive.
/// \param model_weights_data_size[in]: the size of model weights data
/// \param input_nodes[in]: the array of input nodes' name
/// \param output_nodes[in]: the array of output nodes' name
/// \param config[in]: configurations for MaceEngine.
/// \param engine[out]: output MaceEngine object
/// \return MaceStatus::MACE_SUCCESS for success,
/// MaceStatus::MACE_INVALID_ARGS for wrong arguments,
/// MaceStatus::MACE_OUT_OF_RESOURCES for resources is out of range.
MACE_API MaceStatus CreateMaceEngineFromProto(
const unsigned char *model_graph_proto,
const size_t model_graph_proto_size,
const unsigned char *model_weights_data,
const size_t model_weights_data_size,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const MaceEngineConfig &config,
std::shared_ptr<MaceEngine> *engine);
/// \brief Create MaceEngine from files (model file + data file)
/// Deprecated, will be removed in future version
///
/// Create MaceEngine object
///
......@@ -337,7 +366,7 @@ MACE_API MaceStatus CreateMaceEngineFromProto(
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const MaceEngineConfig &config,
std::shared_ptr<MaceEngine> *engine);
std::shared_ptr<MaceEngine> *engine) __attribute__((deprecated));
} // namespace mace
......
......@@ -240,13 +240,20 @@ bool RunModel(const std::string &model_name,
}
#endif // MACE_ENABLE_OPENCL
std::vector<unsigned char> model_pb_data;
std::vector<unsigned char> model_graph_data;
if (FLAGS_model_file != "") {
if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
}
}
std::vector<unsigned char> model_weights_data;
if (FLAGS_model_data_file != "") {
if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file;
}
}
std::shared_ptr<mace::MaceEngine> engine;
MaceStatus create_engine_status;
......@@ -265,8 +272,10 @@ bool RunModel(const std::string &model_name,
#else
(void)(model_name);
create_engine_status =
CreateMaceEngineFromProto(model_pb_data,
FLAGS_model_data_file,
CreateMaceEngineFromProto(model_graph_data.data(),
model_graph_data.size(),
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
......@@ -338,8 +347,10 @@ bool RunModel(const std::string &model_name,
&engine);
#else
create_engine_status =
CreateMaceEngineFromProto(model_pb_data,
FLAGS_model_data_file,
CreateMaceEngineFromProto(model_graph_data.data(),
model_graph_data.size(),
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
......@@ -378,8 +389,10 @@ bool RunModel(const std::string &model_name,
&engine);
#else
create_engine_status =
CreateMaceEngineFromProto(model_pb_data,
FLAGS_model_data_file,
CreateMaceEngineFromProto(model_graph_data.data(),
model_graph_data.size(),
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册