提交 221b0c0d 编写于 作者: Y yejianwu

update api and format code

上级 cefa11d4
...@@ -288,12 +288,12 @@ int Main(int argc, char **argv) { ...@@ -288,12 +288,12 @@ int Main(int argc, char **argv) {
model_pb_data); model_pb_data);
} else { } else {
create_engine_status = create_engine_status =
CreateMaceEngine(FLAGS_model_name, CreateMaceEngineFromCode(FLAGS_model_name,
model_data_file_ptr, model_data_file_ptr,
input_names, input_names,
output_names, output_names,
device_type, device_type,
&engine); &engine);
} }
if (create_engine_status != MaceStatus::MACE_SUCCESS) { if (create_engine_status != MaceStatus::MACE_SUCCESS) {
LOG(FATAL) << "Create engine error, please check the arguments"; LOG(FATAL) << "Create engine error, please check the arguments";
......
...@@ -42,11 +42,11 @@ namespace mace { ...@@ -42,11 +42,11 @@ namespace mace {
#ifdef MACE_CPU_MODEL_TAG #ifdef MACE_CPU_MODEL_TAG
namespace MACE_CPU_MODEL_TAG { namespace MACE_CPU_MODEL_TAG {
extern const unsigned char *LoadModelData(const char *model_data_file); extern const unsigned char *LoadModelData(const std::string &model_data_file);
extern void UnloadModelData(const unsigned char *model_data); extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}); extern NetDef CreateNet();
extern const std::string ModelChecksum(); extern const std::string ModelChecksum();
...@@ -60,7 +60,7 @@ extern const unsigned char *LoadModelData(const char *model_data_file); ...@@ -60,7 +60,7 @@ extern const unsigned char *LoadModelData(const char *model_data_file);
extern void UnloadModelData(const unsigned char *model_data); extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}); extern NetDef CreateNet();
extern const std::string ModelChecksum(); extern const std::string ModelChecksum();
...@@ -74,7 +74,7 @@ extern const unsigned char *LoadModelData(const char *model_data_file); ...@@ -74,7 +74,7 @@ extern const unsigned char *LoadModelData(const char *model_data_file);
extern void UnloadModelData(const unsigned char *model_data); extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}); extern NetDef CreateNet();
extern const std::string ModelChecksum(); extern const std::string ModelChecksum();
......
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
#include <errno.h> #include <errno.h>
#include <fcntl.h> #include <fcntl.h>
#include <memory>
#include <string.h> #include <string.h>
#include <sys/mman.h> #include <sys/mman.h>
#include <unistd.h> #include <unistd.h>
#include <memory>
#include "mace/core/net.h" #include "mace/core/net.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
...@@ -276,15 +277,14 @@ MaceStatus MaceEngine::Run(const std::map<std::string, MaceTensor> &inputs, ...@@ -276,15 +277,14 @@ MaceStatus MaceEngine::Run(const std::map<std::string, MaceTensor> &inputs,
return impl_->Run(inputs, outputs, nullptr); return impl_->Run(inputs, outputs, nullptr);
} }
namespace { const unsigned char *LoadModelData(const std::string &model_data_file,
const unsigned char *LoadModelData(const char *model_data_file) { const size_t &data_size) {
int fd = open(model_data_file, O_RDONLY); int fd = open(model_data_file.c_str(), O_RDONLY);
MACE_CHECK(fd >= 0, "Failed to open model data file ", MACE_CHECK(fd >= 0, "Failed to open model data file ",
model_data_file, ", error code: ", errno); model_data_file, ", error code: ", errno);
const unsigned char *model_data = const unsigned char *model_data = static_cast<const unsigned char *>(
static_cast<const unsigned char *>(mmap(nullptr, 2453764, mmap(nullptr, data_size, PROT_READ, MAP_PRIVATE, fd, 0));
PROT_READ, MAP_PRIVATE, fd, 0));
MACE_CHECK(model_data != MAP_FAILED, "Failed to map model data file ", MACE_CHECK(model_data != MAP_FAILED, "Failed to map model data file ",
model_data_file, ", error code: ", errno); model_data_file, ", error code: ", errno);
...@@ -295,37 +295,45 @@ const unsigned char *LoadModelData(const char *model_data_file) { ...@@ -295,37 +295,45 @@ const unsigned char *LoadModelData(const char *model_data_file) {
return model_data; return model_data;
} }
void UnloadModelData(const unsigned char *model_data) { void UnloadModelData(const unsigned char *model_data,
const size_t &data_size) {
int ret = munmap(const_cast<unsigned char *>(model_data), int ret = munmap(const_cast<unsigned char *>(model_data),
2453764); data_size);
MACE_CHECK(ret == 0, "Failed to unmap model data file, error code: ", errno); MACE_CHECK(ret == 0, "Failed to unmap model data file, error code: ", errno);
} }
} // namespace
MaceStatus CreateMaceEngineFromPB(const char *model_data_file, MaceStatus CreateMaceEngineFromPB(const std::string &model_data_file,
const std::vector<std::string> &input_nodes, const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes, const std::vector<std::string> &output_nodes,
const DeviceType device_type, const DeviceType device_type,
std::shared_ptr<MaceEngine> *engine, std::shared_ptr<MaceEngine> *engine,
const std::vector<unsigned char> model_pb) { const std::vector<unsigned char> &model_pb) {
LOG(INFO) << "Create MaceEngine from model pb"; LOG(INFO) << "Create MaceEngine from model pb";
// load model // load model
if (engine == nullptr) { if (engine == nullptr) {
return MaceStatus::MACE_INVALID_ARGS; return MaceStatus::MACE_INVALID_ARGS;
} }
const unsigned char * model_data = nullptr;
model_data = LoadModelData(model_data_file);
NetDef net_def; NetDef net_def;
net_def.ParseFromArray(&model_pb[0], model_pb.size()); net_def.ParseFromArray(&model_pb[0], model_pb.size());
index_t model_data_size = 0;
for (auto &const_tensor : net_def.tensors()) {
model_data_size = std::max(
model_data_size,
static_cast<index_t>(const_tensor.offset() +
const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type())));
}
const unsigned char *model_data = nullptr;
model_data = LoadModelData(model_data_file, model_data_size);
engine->reset( engine->reset(
new mace::MaceEngine(&net_def, device_type, input_nodes, output_nodes, new mace::MaceEngine(&net_def, device_type, input_nodes, output_nodes,
model_data)); model_data));
if (device_type == DeviceType::GPU || device_type == DeviceType::HEXAGON) { if (device_type == DeviceType::GPU || device_type == DeviceType::HEXAGON) {
UnloadModelData(model_data); UnloadModelData(model_data, model_data_size);
} }
return MACE_SUCCESS; return MACE_SUCCESS;
} }
......
...@@ -169,15 +169,13 @@ bool RunModel(const std::vector<std::string> &input_names, ...@@ -169,15 +169,13 @@ bool RunModel(const std::vector<std::string> &input_names,
MaceStatus create_engine_status; MaceStatus create_engine_status;
// Create Engine // Create Engine
int64_t t0 = NowMicros(); int64_t t0 = NowMicros();
const char *model_data_file_ptr =
FLAGS_model_data_file.empty() ? nullptr : FLAGS_model_data_file.c_str();
if (FLAGS_model_file != "") { if (FLAGS_model_file != "") {
std::vector<unsigned char> model_pb_data; std::vector<unsigned char> model_pb_data;
if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) { if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
} }
create_engine_status = create_engine_status =
CreateMaceEngineFromPB(model_data_file_ptr, CreateMaceEngineFromPB(FLAGS_model_data_file,
input_names, input_names,
output_names, output_names,
device_type, device_type,
...@@ -185,12 +183,12 @@ bool RunModel(const std::vector<std::string> &input_names, ...@@ -185,12 +183,12 @@ bool RunModel(const std::vector<std::string> &input_names,
model_pb_data); model_pb_data);
} else { } else {
create_engine_status = create_engine_status =
CreateMaceEngine(model_name, CreateMaceEngineFromCode(model_name,
model_data_file_ptr, FLAGS_model_data_file,
input_names, input_names,
output_names, output_names,
device_type, device_type,
&engine); &engine);
} }
if (create_engine_status != MaceStatus::MACE_SUCCESS) { if (create_engine_status != MaceStatus::MACE_SUCCESS) {
......
...@@ -106,12 +106,18 @@ class MaceEngine { ...@@ -106,12 +106,18 @@ class MaceEngine {
MaceEngine &operator=(const MaceEngine &) = delete; MaceEngine &operator=(const MaceEngine &) = delete;
}; };
MaceStatus CreateMaceEngineFromPB(const char *model_data_file, const unsigned char *LoadModelData(const std::string &model_data_file,
const size_t &data_size);
void UnloadModelData(const unsigned char *model_data,
const size_t &data_size);
MaceStatus CreateMaceEngineFromPB(const std::string &model_data_file,
const std::vector<std::string> &input_nodes, const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes, const std::vector<std::string> &output_nodes,
const DeviceType device_type, const DeviceType device_type,
std::shared_ptr<MaceEngine> *engine, std::shared_ptr<MaceEngine> *engine,
const std::vector<unsigned char> model_pb); const std::vector<unsigned char> &model_pb);
} // namespace mace } // namespace mace
......
...@@ -29,11 +29,11 @@ namespace mace { ...@@ -29,11 +29,11 @@ namespace mace {
{% for tag in model_tags %} {% for tag in model_tags %}
namespace {{tag}} { namespace {{tag}} {
extern const unsigned char *LoadModelData(const char *model_data_file); extern const unsigned char *LoadModelData(const std::string &model_data_file);
extern void UnloadModelData(const unsigned char *model_data); extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}); extern NetDef CreateNet();
extern const std::string ModelName(); extern const std::string ModelName();
extern const std::string ModelChecksum(); extern const std::string ModelChecksum();
...@@ -51,9 +51,9 @@ std::map<std::string, int> model_name_map { ...@@ -51,9 +51,9 @@ std::map<std::string, int> model_name_map {
}; };
} // namespace } // namespace
MaceStatus CreateMaceEngine( MaceStatus CreateMaceEngineFromCode(
const std::string &model_name, const std::string &model_name,
const char *model_data_file, const std::string &model_data_file,
const std::vector<std::string> &input_nodes, const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes, const std::vector<std::string> &output_nodes,
const DeviceType device_type, const DeviceType device_type,
...@@ -86,9 +86,9 @@ MaceStatus CreateMaceEngine( ...@@ -86,9 +86,9 @@ MaceStatus CreateMaceEngine(
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
{% else %} {% else %}
MaceStatus CreateMaceEngine( MaceStatus CreateMaceEngineFromCode(
const std::string &model_name, const std::string &model_name,
const char *model_data_file, const std::string &model_data_file,
const std::vector<std::string> &input_nodes, const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes, const std::vector<std::string> &output_nodes,
const DeviceType device_type, const DeviceType device_type,
......
...@@ -127,13 +127,11 @@ void CreateMemoryArena(mace::MemoryArena *mem_arena) { ...@@ -127,13 +127,11 @@ void CreateMemoryArena(mace::MemoryArena *mem_arena) {
namespace {{tag}} { namespace {{tag}} {
NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}) { NetDef CreateNet() {
MACE_LATENCY_LOGGER(1, "Create net {{ net.name }}"); MACE_LATENCY_LOGGER(1, "Create net {{ net.name }}");
NetDef net_def; NetDef net_def;
{% if model_type == 'source' %}
MACE_UNUSED(model_pb);
net_def.set_name("{{ net.name}}"); net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}"); net_def.set_version("{{ net.version }}");
...@@ -150,11 +148,6 @@ NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}) { ...@@ -150,11 +148,6 @@ NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}) {
CreateOutputInfo(net_def); CreateOutputInfo(net_def);
{% endif %} {% endif %}
{% else %}
net_def.ParseFromArray(&model_pb[0], model_pb.size());
{% endif %}
return net_def; return net_def;
} }
......
...@@ -28,7 +28,7 @@ const unsigned char *LoadModelData(const char *model_data_file); ...@@ -28,7 +28,7 @@ const unsigned char *LoadModelData(const char *model_data_file);
void UnloadModelData(const unsigned char *model_data); void UnloadModelData(const unsigned char *model_data);
NetDef CreateNet(const unsigned char *model_data); NetDef CreateNet();
const std::string ModelName(); const std::string ModelName();
......
...@@ -22,16 +22,6 @@ ...@@ -22,16 +22,6 @@
#include "mace/utils/env_time.h" #include "mace/utils/env_time.h"
#include "mace/utils/logging.h" #include "mace/utils/logging.h"
{% if not embed_model_data %}
#include <errno.h>
#include <fcntl.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>
{% endif %}
namespace mace { namespace mace {
namespace {{tag}} { namespace {{tag}} {
...@@ -41,34 +31,18 @@ alignas(4) const unsigned char model_data[{{ model_data_size }}] = { ...@@ -41,34 +31,18 @@ alignas(4) const unsigned char model_data[{{ model_data_size }}] = {
}; };
{% endif %} {% endif %}
const unsigned char *LoadModelData(const char *model_data_file) { const unsigned char *LoadModelData(const std::string &model_data_file) {
{% if embed_model_data %} {% if embed_model_data %}
MACE_UNUSED(model_data_file); MACE_UNUSED(model_data_file);
return model_data; return model_data;
{% else %} {% else %}
int fd = open(model_data_file, O_RDONLY); return mace::LoadModelData(model_data_file, {{ model_data_size }});
MACE_CHECK(fd >= 0, "Failed to open model data file ",
model_data_file, ", error code: ", errno);
const unsigned char *model_data =
static_cast<const unsigned char *>(mmap(nullptr, {{ model_data_size }},
PROT_READ, MAP_PRIVATE, fd, 0));
MACE_CHECK(model_data != MAP_FAILED, "Failed to map model data file ",
model_data_file, ", error code: ", errno);
int ret = close(fd);
MACE_CHECK(ret == 0, "Failed to close model data file ",
model_data_file, ", error code: ", errno);
return model_data;
{% endif %} {% endif %}
} }
void UnloadModelData(const unsigned char *model_data) { void UnloadModelData(const unsigned char *model_data) {
{% if not embed_model_data %} {% if not embed_model_data %}
int ret = munmap(const_cast<unsigned char *>(model_data), mace::UnloadModelData(model_data, {{ model_data_size }});
{{ model_data_size }});
MACE_CHECK(ret == 0, "Failed to unmap model data file, error code: ", errno);
{% else %} {% else %}
MACE_UNUSED(model_data); MACE_UNUSED(model_data);
{% endif %} {% endif %}
......
...@@ -231,15 +231,13 @@ bool RunModel(const std::string &model_name, ...@@ -231,15 +231,13 @@ bool RunModel(const std::string &model_name,
MaceStatus create_engine_status; MaceStatus create_engine_status;
// Create Engine // Create Engine
int64_t t0 = NowMicros(); int64_t t0 = NowMicros();
const char *model_data_file_ptr =
FLAGS_model_data_file.empty() ? nullptr : FLAGS_model_data_file.c_str();
if (FLAGS_model_file != "") { if (FLAGS_model_file != "") {
std::vector<unsigned char> model_pb_data; std::vector<unsigned char> model_pb_data;
if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) { if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file; LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
} }
create_engine_status = create_engine_status =
CreateMaceEngineFromPB(model_data_file_ptr, CreateMaceEngineFromPB(FLAGS_model_data_file,
input_names, input_names,
output_names, output_names,
device_type, device_type,
...@@ -247,12 +245,12 @@ bool RunModel(const std::string &model_name, ...@@ -247,12 +245,12 @@ bool RunModel(const std::string &model_name,
model_pb_data); model_pb_data);
} else { } else {
create_engine_status = create_engine_status =
CreateMaceEngine(model_name, CreateMaceEngineFromCode(model_name,
model_data_file_ptr, FLAGS_model_data_file,
input_names, input_names,
output_names, output_names,
device_type, device_type,
&engine); &engine);
} }
int64_t t1 = NowMicros(); int64_t t1 = NowMicros();
......
...@@ -942,7 +942,8 @@ def merge_libs(target_soc, ...@@ -942,7 +942,8 @@ def merge_libs(target_soc,
sh.cp("-f", glob.glob("%s/*.data" % model_output_dir), sh.cp("-f", glob.glob("%s/*.data" % model_output_dir),
model_data_dir) model_data_dir)
if model_load_type == "source": if model_load_type == "source":
sh.cp("-f", glob.glob("%s/*.h" % model_output_dir), model_header_dir) sh.cp("-f", glob.glob("%s/*.h" % model_output_dir),
model_header_dir)
for model_name in mace_model_dirs_kv: for model_name in mace_model_dirs_kv:
sh.cp("-f", "%s/%s.pb" % (mace_model_dirs_kv[model_name], model_name), sh.cp("-f", "%s/%s.pb" % (mace_model_dirs_kv[model_name], model_name),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册