Commit a3fe4a89 authored by 刘琦

Merge branch 'fix_mace_run_for_code_model' into 'master'

fix mace_run & benchmark in code model format

See merge request !960
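The core of this change: when the model graph is built into code (MODEL_GRAPH_FORMAT_CODE), mace_run and the benchmark now read the weights file into memory themselves, skip the read when the path is empty (e.g. when the weights are embedded as well), and pass the buffer to a new CreateMaceEngineFromCode overload instead of handing it a file path. A minimal caller-side sketch of that pattern follows; the include paths, the wrapper function name, and the local file-reading helper are illustrative assumptions, not part of this diff.

#include <fstream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "mace/public/mace.h"                         // assumed public header path
#include "mace/codegen/engine/mace_engine_factory.h"  // generated factory (assumed path)

namespace {
// Local stand-in for the ReadBinaryFile helper added in this diff.
bool ReadBinaryFile(std::vector<unsigned char> *data,
                    const std::string &filename) {
  std::ifstream ifs(filename, std::ios::in | std::ios::binary);
  if (!ifs.is_open()) return false;
  data->assign(std::istreambuf_iterator<char>(ifs),
               std::istreambuf_iterator<char>());
  return !ifs.fail();
}
}  // namespace

// Hypothetical wrapper showing the new call pattern for the code-model format.
mace::MaceStatus CreateEngineForCodeModel(
    const std::string &model_name,
    const std::string &model_data_file,  // may be "" when weights are embedded
    const std::vector<std::string> &input_names,
    const std::vector<std::string> &output_names,
    const mace::MaceEngineConfig &config,
    std::shared_ptr<mace::MaceEngine> *engine) {
  // Only read the weights file when a path was actually given; with embedded
  // weights the buffer stays empty and data()/size() are harmless.
  std::vector<unsigned char> model_weights_data;
  if (model_data_file != "") {
    if (!ReadBinaryFile(&model_weights_data, model_data_file)) {
      return mace::MaceStatus::MACE_INVALID_ARGS;
    }
  }
  // New overload: in-memory weights instead of a file path.
  return mace::CreateMaceEngineFromCode(model_name,
                                        model_weights_data.data(),
                                        model_weights_data.size(),
                                        input_names,
                                        output_names,
                                        config,
                                        engine);
}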
@@ -278,19 +278,24 @@ int Main(int argc, char **argv) {
MaceStatus create_engine_status;
// Create Engine
std::vector<unsigned char> model_graph_data;
if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
if (FLAGS_model_file != "") {
if (!mace::ReadBinaryFile(&model_graph_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
}
}
std::vector<unsigned char> model_weights_data;
if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file;
if (FLAGS_model_data_file != "") {
if (!mace::ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_data_file;
}
}
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status =
CreateMaceEngineFromCode(FLAGS_model_name,
model_data_file_ptr,
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
......
@@ -59,6 +59,29 @@ std::vector<std::string> Split(const std::string &str, char delims) {
} // namespace str_util
namespace {
bool ReadBinaryFile(std::vector<unsigned char> *data,
const std::string &filename) {
std::ifstream ifs(filename, std::ios::in | std::ios::binary);
if (!ifs.is_open()) {
return false;
}
ifs.seekg(0, ifs.end);
size_t length = ifs.tellg();
ifs.seekg(0, ifs.beg);
data->reserve(length);
data->insert(data->begin(), std::istreambuf_iterator<char>(ifs),
std::istreambuf_iterator<char>());
if (ifs.fail()) {
return false;
}
ifs.close();
return true;
}
} // namespace
void ParseShape(const std::string &str, std::vector<int64_t> *shape) {
std::string tmp = str;
while (!tmp.empty()) {
@@ -142,30 +165,6 @@ DEFINE_int32(gpu_priority_hint, 1, "0:DEFAULT/1:LOW/2:NORMAL/3:HIGH");
DEFINE_int32(omp_num_threads, -1, "num of openmp threads");
DEFINE_int32(cpu_affinity_policy, 1,
"0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY");
#ifndef MODEL_GRAPH_FORMAT_CODE
namespace {
bool ReadBinaryFile(std::vector<unsigned char> *data,
const std::string &filename) {
std::ifstream ifs(filename, std::ios::in | std::ios::binary);
if (!ifs.is_open()) {
return false;
}
ifs.seekg(0, ifs.end);
size_t length = ifs.tellg();
ifs.seekg(0, ifs.beg);
data->reserve(length);
data->insert(data->begin(), std::istreambuf_iterator<char>(ifs),
std::istreambuf_iterator<char>());
if (ifs.fail()) {
return false;
}
ifs.close();
return true;
}
} // namespace
#endif
bool RunModel(const std::vector<std::string> &input_names,
const std::vector<std::vector<int64_t>> &input_shapes,
@@ -212,6 +211,16 @@ bool RunModel(const std::vector<std::string> &input_names,
// Create Engine
std::shared_ptr<mace::MaceEngine> engine;
MaceStatus create_engine_status;
std::vector<unsigned char> model_graph_data;
if (!ReadBinaryFile(&model_graph_data, FLAGS_model_file)) {
std::cerr << "Failed to read file: " << FLAGS_model_file << std::endl;
}
std::vector<unsigned char> model_weights_data;
if (!ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) {
std::cerr << "Failed to read file: " << FLAGS_model_data_file << std::endl;
}
// Only choose one of the two types based on the `model_graph_format`
// in the model deployment file (.yml).
#ifdef MODEL_GRAPH_FORMAT_CODE
@@ -219,20 +228,13 @@ bool RunModel(const std::vector<std::string> &input_names,
// to model_data_file parameter.
create_engine_status =
CreateMaceEngineFromCode(FLAGS_model_name,
FLAGS_model_data_file,
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
&engine);
#else
std::vector<unsigned char> model_graph_data;
if (!ReadBinaryFile(&model_graph_data, FLAGS_model_file)) {
std::cerr << "Failed to read file: " << FLAGS_model_file << std::endl;
}
std::vector<unsigned char> model_weights_data;
if (!ReadBinaryFile(&model_weights_data, FLAGS_model_data_file)) {
std::cerr << "Failed to read file: " << FLAGS_model_data_file << std::endl;
}
create_engine_status =
CreateMaceEngineFromProto(model_graph_data.data(),
model_graph_data.size(),
......
@@ -62,7 +62,7 @@ std::map<std::string, int> model_name_map {
/// \param engine[out]: output MaceEngine object
/// \return MaceStatus::MACE_SUCCESS for success, MACE_INVALID_ARGS for wrong arguments,
/// MACE_OUT_OF_RESOURCES when resources are out of range.
MaceStatus CreateMaceEngineFromCode(
__attribute__((deprecated)) MaceStatus CreateMaceEngineFromCode(
const std::string &model_name,
const std::string &model_data_file,
const std::vector<std::string> &input_nodes,
@@ -101,5 +101,48 @@ MaceStatus CreateMaceEngineFromCode(
return status;
}
MaceStatus CreateMaceEngineFromCode(
const std::string &model_name,
const unsigned char *model_weights_data,
const size_t model_weights_data_size,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const MaceEngineConfig &config,
std::shared_ptr<MaceEngine> *engine) {
// load model
if (engine == nullptr) {
return MaceStatus::MACE_INVALID_ARGS;
}
std::shared_ptr<NetDef> net_def;
{% if embed_model_data %}
const unsigned char * model_data;
(void)model_weights_data;
{% endif %}
// TODO(yejianwu) Add buffer range checking
(void)model_weights_data_size;
MaceStatus status = MaceStatus::MACE_SUCCESS;
switch (model_name_map[model_name]) {
{% for i in range(model_tags |length) %}
case {{ i }}:
net_def = mace::{{model_tags[i]}}::CreateNet();
engine->reset(new mace::MaceEngine(config));
{% if embed_model_data %}
model_data = mace::{{model_tags[i]}}::LoadModelData();
status = (*engine)->Init(net_def.get(), input_nodes, output_nodes,
model_data);
{% else %}
status = (*engine)->Init(net_def.get(), input_nodes, output_nodes,
model_weights_data);
{% endif %}
break;
{% endfor %}
default:
status = MaceStatus::MACE_INVALID_ARGS;
}
return status;
}
} // namespace mace
#endif // MACE_CODEGEN_ENGINE_MACE_ENGINE_FACTORY_H_
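Because the Jinja2 control flow obscures what the factory actually compiles to, here is a sketch of what the new overload might render to for a single hypothetical model tag, mobilenet_v1, with embed_model_data disabled; only the tag name is invented, everything else follows the template above.

// Hypothetical rendering for model_tags = ["mobilenet_v1"], embed_model_data = false.
MaceStatus CreateMaceEngineFromCode(
    const std::string &model_name,
    const unsigned char *model_weights_data,
    const size_t model_weights_data_size,
    const std::vector<std::string> &input_nodes,
    const std::vector<std::string> &output_nodes,
    const MaceEngineConfig &config,
    std::shared_ptr<MaceEngine> *engine) {
  // load model
  if (engine == nullptr) {
    return MaceStatus::MACE_INVALID_ARGS;
  }
  std::shared_ptr<NetDef> net_def;
  // TODO(yejianwu) Add buffer range checking
  (void)model_weights_data_size;
  MaceStatus status = MaceStatus::MACE_SUCCESS;
  switch (model_name_map[model_name]) {
    case 0:
      net_def = mace::mobilenet_v1::CreateNet();
      engine->reset(new mace::MaceEngine(config));
      // Weights come from the caller's buffer rather than an embedded array.
      status = (*engine)->Init(net_def.get(), input_nodes, output_nodes,
                               model_weights_data);
      break;
    default:
      status = MaceStatus::MACE_INVALID_ARGS;
  }
  return status;
}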
@@ -264,7 +264,8 @@ bool RunModel(const std::string &model_name,
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status =
CreateMaceEngineFromCode(model_name,
FLAGS_model_data_file,
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
@@ -340,7 +341,8 @@ bool RunModel(const std::string &model_name,
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status =
CreateMaceEngineFromCode(model_name,
FLAGS_model_data_file,
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
@@ -382,7 +384,8 @@ bool RunModel(const std::string &model_name,
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status =
CreateMaceEngineFromCode(model_name,
FLAGS_model_data_file,
model_weights_data.data(),
model_weights_data.size(),
input_names,
output_names,
config,
......
@@ -192,6 +192,14 @@ class DeviceWrapper:
if model_graph_format == ModelFormat.file:
mace_model_path = layers_validate_file if layers_validate_file \
else "%s/%s.pb" % (mace_model_dir, model_tag)
model_data_file = ""
if not embed_model_data:
if self.system == SystemType.host:
model_data_file = "%s/%s.data" % (mace_model_dir, model_tag)
else:
model_data_file = "%s/%s.data" % (self.data_dir, model_tag)
if self.system == SystemType.host:
libmace_dynamic_lib_path = \
os.path.dirname(libmace_dynamic_library_path)
@@ -214,8 +222,7 @@ class DeviceWrapper:
output_file_name),
"--input_dir=%s" % input_dir,
"--output_dir=%s" % output_dir,
"--model_data_file=%s/%s.data" % (mace_model_dir,
model_tag),
"--model_data_file=%s" % model_data_file,
"--device=%s" % device_type,
"--round=%s" % running_round,
"--restart_round=%s" % restart_round,
@@ -229,7 +236,7 @@ class DeviceWrapper:
stdout=subprocess.PIPE)
out, err = p.communicate()
self.stdout = err + out
six.print_(self.stdout)
six.print_(self.stdout.decode('UTF-8'))
six.print_("Running finished!\n")
elif self.system in [SystemType.android, SystemType.arm_linux]:
self.rm(self.data_dir)
@@ -304,7 +311,7 @@ class DeviceWrapper:
"--output_file=%s/%s" % (self.data_dir, output_file_name),
"--input_dir=%s" % input_dir,
"--output_dir=%s" % output_dir,
"--model_data_file=%s/%s.data" % (self.data_dir, model_tag),
"--model_data_file=%s" % model_data_file,
"--device=%s" % device_type,
"--round=%s" % running_round,
"--restart_round=%s" % restart_round,
@@ -753,6 +760,14 @@ class DeviceWrapper:
mace_model_path = ''
if model_graph_format == ModelFormat.file:
mace_model_path = '%s/%s.pb' % (mace_model_dir, model_tag)
model_data_file = ""
if not embed_model_data:
if self.system == SystemType.host:
model_data_file = "%s/%s.data" % (mace_model_dir, model_tag)
else:
model_data_file = "%s/%s.data" % (self.data_dir, model_tag)
if abi == ABIType.host:
libmace_dynamic_lib_dir_path = \
os.path.dirname(libmace_dynamic_library_path)
@@ -768,8 +783,7 @@ class DeviceWrapper:
'--input_shape=%s' % ':'.join(input_shapes),
'--output_shape=%s' % ':'.join(output_shapes),
'--input_file=%s/%s' % (model_output_dir, input_file_name),
'--model_data_file=%s/%s.data' % (mace_model_dir,
model_tag),
"--model_data_file=%s" % model_data_file,
'--device=%s' % device_type,
'--omp_num_threads=%s' % omp_num_threads,
'--cpu_affinity_policy=%s' % cpu_affinity_policy,
@@ -822,7 +836,7 @@ class DeviceWrapper:
'--input_shape=%s' % ':'.join(input_shapes),
'--output_shape=%s' % ':'.join(output_shapes),
'--input_file=%s/%s' % (self.data_dir, input_file_name),
'--model_data_file=%s/%s.data' % (self.data_dir, model_tag),
"--model_data_file=%s" % model_data_file,
'--device=%s' % device_type,
'--omp_num_threads=%s' % omp_num_threads,
'--cpu_affinity_policy=%s' % cpu_affinity_policy,
......