提交 221b0c0d 编写于 作者: Y yejianwu

update api and format code

上级 cefa11d4
......@@ -288,12 +288,12 @@ int Main(int argc, char **argv) {
model_pb_data);
} else {
create_engine_status =
CreateMaceEngine(FLAGS_model_name,
model_data_file_ptr,
input_names,
output_names,
device_type,
&engine);
CreateMaceEngineFromCode(FLAGS_model_name,
model_data_file_ptr,
input_names,
output_names,
device_type,
&engine);
}
if (create_engine_status != MaceStatus::MACE_SUCCESS) {
LOG(FATAL) << "Create engine error, please check the arguments";
......
......@@ -42,11 +42,11 @@ namespace mace {
#ifdef MACE_CPU_MODEL_TAG
namespace MACE_CPU_MODEL_TAG {
extern const unsigned char *LoadModelData(const char *model_data_file);
extern const unsigned char *LoadModelData(const std::string &model_data_file);
extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const std::vector<unsigned char> &model_pb = {});
extern NetDef CreateNet();
extern const std::string ModelChecksum();
......@@ -60,7 +60,7 @@ extern const unsigned char *LoadModelData(const char *model_data_file);
extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const std::vector<unsigned char> &model_pb = {});
extern NetDef CreateNet();
extern const std::string ModelChecksum();
......@@ -74,7 +74,7 @@ extern const unsigned char *LoadModelData(const char *model_data_file);
extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const std::vector<unsigned char> &model_pb = {});
extern NetDef CreateNet();
extern const std::string ModelChecksum();
......
......@@ -14,11 +14,12 @@
#include <errno.h>
#include <fcntl.h>
#include <memory>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>
#include <memory>
#include "mace/core/net.h"
#include "mace/core/types.h"
#include "mace/public/mace.h"
......@@ -276,15 +277,14 @@ MaceStatus MaceEngine::Run(const std::map<std::string, MaceTensor> &inputs,
return impl_->Run(inputs, outputs, nullptr);
}
namespace {
const unsigned char *LoadModelData(const char *model_data_file) {
int fd = open(model_data_file, O_RDONLY);
const unsigned char *LoadModelData(const std::string &model_data_file,
const size_t &data_size) {
int fd = open(model_data_file.c_str(), O_RDONLY);
MACE_CHECK(fd >= 0, "Failed to open model data file ",
model_data_file, ", error code: ", errno);
const unsigned char *model_data =
static_cast<const unsigned char *>(mmap(nullptr, 2453764,
PROT_READ, MAP_PRIVATE, fd, 0));
const unsigned char *model_data = static_cast<const unsigned char *>(
mmap(nullptr, data_size, PROT_READ, MAP_PRIVATE, fd, 0));
MACE_CHECK(model_data != MAP_FAILED, "Failed to map model data file ",
model_data_file, ", error code: ", errno);
......@@ -295,37 +295,45 @@ const unsigned char *LoadModelData(const char *model_data_file) {
return model_data;
}
void UnloadModelData(const unsigned char *model_data) {
void UnloadModelData(const unsigned char *model_data,
const size_t &data_size) {
int ret = munmap(const_cast<unsigned char *>(model_data),
2453764);
data_size);
MACE_CHECK(ret == 0, "Failed to unmap model data file, error code: ", errno);
}
} // namespace
MaceStatus CreateMaceEngineFromPB(const char *model_data_file,
MaceStatus CreateMaceEngineFromPB(const std::string &model_data_file,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const DeviceType device_type,
std::shared_ptr<MaceEngine> *engine,
const std::vector<unsigned char> model_pb) {
const std::vector<unsigned char> &model_pb) {
LOG(INFO) << "Create MaceEngine from model pb";
// load model
if (engine == nullptr) {
return MaceStatus::MACE_INVALID_ARGS;
}
const unsigned char * model_data = nullptr;
model_data = LoadModelData(model_data_file);
NetDef net_def;
net_def.ParseFromArray(&model_pb[0], model_pb.size());
index_t model_data_size = 0;
for (auto &const_tensor : net_def.tensors()) {
model_data_size = std::max(
model_data_size,
static_cast<index_t>(const_tensor.offset() +
const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type())));
}
const unsigned char *model_data = nullptr;
model_data = LoadModelData(model_data_file, model_data_size);
engine->reset(
new mace::MaceEngine(&net_def, device_type, input_nodes, output_nodes,
model_data));
if (device_type == DeviceType::GPU || device_type == DeviceType::HEXAGON) {
UnloadModelData(model_data);
UnloadModelData(model_data, model_data_size);
}
return MACE_SUCCESS;
}
......
......@@ -169,15 +169,13 @@ bool RunModel(const std::vector<std::string> &input_names,
MaceStatus create_engine_status;
// Create Engine
int64_t t0 = NowMicros();
const char *model_data_file_ptr =
FLAGS_model_data_file.empty() ? nullptr : FLAGS_model_data_file.c_str();
if (FLAGS_model_file != "") {
std::vector<unsigned char> model_pb_data;
if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
}
create_engine_status =
CreateMaceEngineFromPB(model_data_file_ptr,
CreateMaceEngineFromPB(FLAGS_model_data_file,
input_names,
output_names,
device_type,
......@@ -185,12 +183,12 @@ bool RunModel(const std::vector<std::string> &input_names,
model_pb_data);
} else {
create_engine_status =
CreateMaceEngine(model_name,
model_data_file_ptr,
input_names,
output_names,
device_type,
&engine);
CreateMaceEngineFromCode(model_name,
FLAGS_model_data_file,
input_names,
output_names,
device_type,
&engine);
}
if (create_engine_status != MaceStatus::MACE_SUCCESS) {
......
......@@ -106,12 +106,18 @@ class MaceEngine {
MaceEngine &operator=(const MaceEngine &) = delete;
};
MaceStatus CreateMaceEngineFromPB(const char *model_data_file,
const unsigned char *LoadModelData(const std::string &model_data_file,
const size_t &data_size);
void UnloadModelData(const unsigned char *model_data,
const size_t &data_size);
MaceStatus CreateMaceEngineFromPB(const std::string &model_data_file,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const DeviceType device_type,
std::shared_ptr<MaceEngine> *engine,
const std::vector<unsigned char> model_pb);
const std::vector<unsigned char> &model_pb);
} // namespace mace
......
......@@ -29,11 +29,11 @@ namespace mace {
{% for tag in model_tags %}
namespace {{tag}} {
extern const unsigned char *LoadModelData(const char *model_data_file);
extern const unsigned char *LoadModelData(const std::string &model_data_file);
extern void UnloadModelData(const unsigned char *model_data);
extern NetDef CreateNet(const std::vector<unsigned char> &model_pb = {});
extern NetDef CreateNet();
extern const std::string ModelName();
extern const std::string ModelChecksum();
......@@ -51,9 +51,9 @@ std::map<std::string, int> model_name_map {
};
} // namespace
MaceStatus CreateMaceEngine(
MaceStatus CreateMaceEngineFromCode(
const std::string &model_name,
const char *model_data_file,
const std::string &model_data_file,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const DeviceType device_type,
......@@ -86,9 +86,9 @@ MaceStatus CreateMaceEngine(
return MaceStatus::MACE_SUCCESS;
}
{% else %}
MaceStatus CreateMaceEngine(
MaceStatus CreateMaceEngineFromCode(
const std::string &model_name,
const char *model_data_file,
const std::string &model_data_file,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const DeviceType device_type,
......
......@@ -127,13 +127,11 @@ void CreateMemoryArena(mace::MemoryArena *mem_arena) {
namespace {{tag}} {
NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}) {
NetDef CreateNet() {
MACE_LATENCY_LOGGER(1, "Create net {{ net.name }}");
NetDef net_def;
{% if model_type == 'source' %}
MACE_UNUSED(model_pb);
net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}");
......@@ -150,11 +148,6 @@ NetDef CreateNet(const std::vector<unsigned char> &model_pb = {}) {
CreateOutputInfo(net_def);
{% endif %}
{% else %}
net_def.ParseFromArray(&model_pb[0], model_pb.size());
{% endif %}
return net_def;
}
......
......@@ -28,7 +28,7 @@ const unsigned char *LoadModelData(const char *model_data_file);
void UnloadModelData(const unsigned char *model_data);
NetDef CreateNet(const unsigned char *model_data);
NetDef CreateNet();
const std::string ModelName();
......
......@@ -22,16 +22,6 @@
#include "mace/utils/env_time.h"
#include "mace/utils/logging.h"
{% if not embed_model_data %}
#include <errno.h>
#include <fcntl.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>
{% endif %}
namespace mace {
namespace {{tag}} {
......@@ -41,34 +31,18 @@ alignas(4) const unsigned char model_data[{{ model_data_size }}] = {
};
{% endif %}
const unsigned char *LoadModelData(const char *model_data_file) {
const unsigned char *LoadModelData(const std::string &model_data_file) {
{% if embed_model_data %}
MACE_UNUSED(model_data_file);
return model_data;
{% else %}
int fd = open(model_data_file, O_RDONLY);
MACE_CHECK(fd >= 0, "Failed to open model data file ",
model_data_file, ", error code: ", errno);
const unsigned char *model_data =
static_cast<const unsigned char *>(mmap(nullptr, {{ model_data_size }},
PROT_READ, MAP_PRIVATE, fd, 0));
MACE_CHECK(model_data != MAP_FAILED, "Failed to map model data file ",
model_data_file, ", error code: ", errno);
int ret = close(fd);
MACE_CHECK(ret == 0, "Failed to close model data file ",
model_data_file, ", error code: ", errno);
return model_data;
return mace::LoadModelData(model_data_file, {{ model_data_size }});
{% endif %}
}
void UnloadModelData(const unsigned char *model_data) {
{% if not embed_model_data %}
int ret = munmap(const_cast<unsigned char *>(model_data),
{{ model_data_size }});
MACE_CHECK(ret == 0, "Failed to unmap model data file, error code: ", errno);
mace::UnloadModelData(model_data, {{ model_data_size }});
{% else %}
MACE_UNUSED(model_data);
{% endif %}
......
......@@ -231,15 +231,13 @@ bool RunModel(const std::string &model_name,
MaceStatus create_engine_status;
// Create Engine
int64_t t0 = NowMicros();
const char *model_data_file_ptr =
FLAGS_model_data_file.empty() ? nullptr : FLAGS_model_data_file.c_str();
if (FLAGS_model_file != "") {
std::vector<unsigned char> model_pb_data;
if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
}
create_engine_status =
CreateMaceEngineFromPB(model_data_file_ptr,
CreateMaceEngineFromPB(FLAGS_model_data_file,
input_names,
output_names,
device_type,
......@@ -247,12 +245,12 @@ bool RunModel(const std::string &model_name,
model_pb_data);
} else {
create_engine_status =
CreateMaceEngine(model_name,
model_data_file_ptr,
input_names,
output_names,
device_type,
&engine);
CreateMaceEngineFromCode(model_name,
FLAGS_model_data_file,
input_names,
output_names,
device_type,
&engine);
}
int64_t t1 = NowMicros();
......
......@@ -942,7 +942,8 @@ def merge_libs(target_soc,
sh.cp("-f", glob.glob("%s/*.data" % model_output_dir),
model_data_dir)
if model_load_type == "source":
sh.cp("-f", glob.glob("%s/*.h" % model_output_dir), model_header_dir)
sh.cp("-f", glob.glob("%s/*.h" % model_output_dir),
model_header_dir)
for model_name in mace_model_dirs_kv:
sh.cp("-f", "%s/%s.pb" % (mace_model_dirs_kv[model_name], model_name),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册