提交 4fdcc85e 编写于 作者: L liuqi

Change the mace_engine_creator.cc to mace_engnie_factory.h.

上级 dbf67ad9
...@@ -34,7 +34,7 @@ cc_binary( ...@@ -34,7 +34,7 @@ cc_binary(
":statistics", ":statistics",
"//external:gflags_nothreads", "//external:gflags_nothreads",
"//mace/codegen:generated_models", "//mace/codegen:generated_models",
"//mace/codegen:generated_mace_engine_creator", "//mace/codegen:generated_mace_engine_factory",
], ],
) )
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "mace/public/mace_runtime.h" #include "mace/public/mace_runtime.h"
#include "mace/utils/logging.h" #include "mace/utils/logging.h"
#include "mace/benchmark/statistics.h" #include "mace/benchmark/statistics.h"
#include "mace/codegen/engine/mace_engine_factory.h"
namespace mace { namespace mace {
namespace benchmark { namespace benchmark {
...@@ -174,7 +175,7 @@ bool Run(const std::string &title, ...@@ -174,7 +175,7 @@ bool Run(const std::string &title,
return true; return true;
} }
DEFINE_string(model_tag, "", "model tag"); DEFINE_string(model_name, "", "model name in yaml");
DEFINE_string(device, "CPU", "Device [CPU|GPU|DSP]"); DEFINE_string(device, "CPU", "Device [CPU|GPU|DSP]");
DEFINE_string(input_node, "input_node0,input_node1", DEFINE_string(input_node, "input_node0,input_node1",
"input nodes, separated by comma"); "input nodes, separated by comma");
...@@ -200,7 +201,7 @@ int Main(int argc, char **argv) { ...@@ -200,7 +201,7 @@ int Main(int argc, char **argv) {
gflags::SetUsageMessage("some usage message"); gflags::SetUsageMessage("some usage message");
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
LOG(INFO) << "Model tag: [" << FLAGS_model_tag << "]"; LOG(INFO) << "Model name: [" << FLAGS_model_name << "]";
LOG(INFO) << "Device: [" << FLAGS_device << "]"; LOG(INFO) << "Device: [" << FLAGS_device << "]";
LOG(INFO) << "gpu_perf_hint: [" << FLAGS_gpu_perf_hint << "]"; LOG(INFO) << "gpu_perf_hint: [" << FLAGS_gpu_perf_hint << "]";
LOG(INFO) << "gpu_priority_hint: [" << FLAGS_gpu_priority_hint << "]"; LOG(INFO) << "gpu_priority_hint: [" << FLAGS_gpu_priority_hint << "]";
...@@ -254,22 +255,39 @@ int Main(int argc, char **argv) { ...@@ -254,22 +255,39 @@ int Main(int argc, char **argv) {
} }
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
const char *kernel_path = getenv("MACE_CL_PROGRAM_PATH"); const char *kernel_path = getenv("MACE_INTERNAL_STORAGE_PATH");
const std::string kernel_file_path = const std::string kernel_file_path =
std::string(kernel_path == nullptr ? std::string(kernel_path == nullptr ?
"/data/local/tmp/mace_run/cl_program" : kernel_path); "/data/local/tmp/mace_run/interior" : kernel_path);
std::shared_ptr<KVStorageFactory> storage_factory( std::shared_ptr<KVStorageFactory> storage_factory(
new FileStorageFactory(kernel_file_path)); new FileStorageFactory(kernel_file_path));
SetKVStorageFactory(storage_factory); SetKVStorageFactory(storage_factory);
// Create Engine // Create Engine
std::unique_ptr<mace::MaceEngine> engine_ptr = std::shared_ptr<mace::MaceEngine> engine;
CreateMaceEngine(FLAGS_model_tag, MaceStatus create_engine_status;
input_names, // Create Engine
output_names, if (FLAGS_model_data_file.empty()) {
FLAGS_model_data_file.c_str(), create_engine_status =
device_type); CreateMaceEngine(FLAGS_model_name.c_str(),
nullptr,
input_names,
output_names,
device_type,
&engine);
} else {
create_engine_status =
CreateMaceEngine(FLAGS_model_name.c_str(),
FLAGS_model_data_file.c_str(),
input_names,
output_names,
device_type,
&engine);
}
if (create_engine_status != MaceStatus::MACE_SUCCESS) {
LOG(FATAL) << "Create engine error, please check the arguments";
}
std::map<std::string, mace::MaceTensor> inputs; std::map<std::string, mace::MaceTensor> inputs;
std::map<std::string, mace::MaceTensor> outputs; std::map<std::string, mace::MaceTensor> outputs;
...@@ -309,7 +327,7 @@ int Main(int argc, char **argv) { ...@@ -309,7 +327,7 @@ int Main(int argc, char **argv) {
int64_t num_warmup_runs = 0; int64_t num_warmup_runs = 0;
if (FLAGS_warmup_runs > 0) { if (FLAGS_warmup_runs > 0) {
bool status = bool status =
Run("Warm Up", engine_ptr.get(), inputs, &outputs, Run("Warm Up", engine.get(), inputs, &outputs,
FLAGS_warmup_runs, -1.0, FLAGS_warmup_runs, -1.0,
&warmup_time_us, &num_warmup_runs, nullptr); &warmup_time_us, &num_warmup_runs, nullptr);
if (!status) { if (!status) {
...@@ -320,7 +338,7 @@ int Main(int argc, char **argv) { ...@@ -320,7 +338,7 @@ int Main(int argc, char **argv) {
int64_t no_stat_time_us = 0; int64_t no_stat_time_us = 0;
int64_t no_stat_runs = 0; int64_t no_stat_runs = 0;
bool status = bool status =
Run("Run without statistics", engine_ptr.get(), inputs, &outputs, Run("Run without statistics", engine.get(), inputs, &outputs,
FLAGS_max_num_runs, max_benchmark_time_seconds, FLAGS_max_num_runs, max_benchmark_time_seconds,
&no_stat_time_us, &no_stat_runs, nullptr); &no_stat_time_us, &no_stat_runs, nullptr);
if (!status) { if (!status) {
...@@ -329,7 +347,7 @@ int Main(int argc, char **argv) { ...@@ -329,7 +347,7 @@ int Main(int argc, char **argv) {
int64_t stat_time_us = 0; int64_t stat_time_us = 0;
int64_t stat_runs = 0; int64_t stat_runs = 0;
status = Run("Run with statistics", engine_ptr.get(), inputs, &outputs, status = Run("Run with statistics", engine.get(), inputs, &outputs,
FLAGS_max_num_runs, max_benchmark_time_seconds, FLAGS_max_num_runs, max_benchmark_time_seconds,
&stat_time_us, &stat_runs, statistician.get()); &stat_time_us, &stat_runs, statistician.get());
if (!status) { if (!status) {
......
...@@ -35,12 +35,9 @@ cc_library( ...@@ -35,12 +35,9 @@ cc_library(
) )
cc_library( cc_library(
name = "generated_mace_engine_creator", name = "generated_mace_engine_factory",
srcs = ["engine/mace_engine_creator.cc"], hdrs = ["engine/mace_engine_factory.h"],
linkstatic = 1,
deps = [ deps = [
":generated_models",
"//mace/public", "//mace/public",
"//mace/utils",
], ],
) )
...@@ -9,5 +9,6 @@ cc_binary( ...@@ -9,5 +9,6 @@ cc_binary(
deps = [ deps = [
"//external:gflags_nothreads", "//external:gflags_nothreads",
"//mace/codegen:generated_models", "//mace/codegen:generated_models",
"//mace/codegen:generated_mace_engine_factory",
], ],
) )
...@@ -34,6 +34,8 @@ ...@@ -34,6 +34,8 @@
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
#include "mace/public/mace_runtime.h" #include "mace/public/mace_runtime.h"
// if convert model to code.
#include "mace/codegen/engine/mace_engine_factory.h"
#include "mace/utils/env_time.h" #include "mace/utils/env_time.h"
#include "mace/utils/logging.h" #include "mace/utils/logging.h"
...@@ -94,9 +96,9 @@ DeviceType ParseDeviceType(const std::string &device_str) { ...@@ -94,9 +96,9 @@ DeviceType ParseDeviceType(const std::string &device_str) {
} }
DEFINE_string(model_tag, DEFINE_string(model_name,
"", "",
"model tag in yaml file"); "model name in yaml file");
DEFINE_string(input_node, DEFINE_string(input_node,
"input_node0,input_node1", "input_node0,input_node1",
"input nodes, separated by comma"); "input nodes, separated by comma");
...@@ -149,22 +151,38 @@ bool RunModel(const std::vector<std::string> &input_names, ...@@ -149,22 +151,38 @@ bool RunModel(const std::vector<std::string> &input_names,
// DO NOT USE tmp directory. // DO NOT USE tmp directory.
// Please use APP's own directory and make sure the directory exists. // Please use APP's own directory and make sure the directory exists.
// Just call once // Just call once
const std::string kernel_file_path = const std::string internal_storage_path =
"/data/local/tmp/mace_run/cl"; "/data/local/tmp/mace_run/interior";
// Config internal kv storage factory. // Config internal kv storage factory.
std::shared_ptr<KVStorageFactory> storage_factory( std::shared_ptr<KVStorageFactory> storage_factory(
new FileStorageFactory(kernel_file_path)); new FileStorageFactory(internal_storage_path));
SetKVStorageFactory(storage_factory); SetKVStorageFactory(storage_factory);
// Create Engine // Create Engine
std::unique_ptr<mace::MaceEngine> engine = std::shared_ptr<mace::MaceEngine> engine;
CreateMaceEngine(FLAGS_model_tag, MaceStatus create_engine_status;
input_names, // Create Engine
output_names, if (FLAGS_model_data_file.empty()) {
FLAGS_model_data_file.c_str(), create_engine_status =
device_type); CreateMaceEngine(FLAGS_model_name.c_str(),
nullptr,
input_names,
output_names,
device_type,
&engine);
} else {
create_engine_status =
CreateMaceEngine(FLAGS_model_name.c_str(),
FLAGS_model_data_file.c_str(),
input_names,
output_names,
device_type,
&engine);
}
if (create_engine_status != MaceStatus::MACE_SUCCESS) {
LOG(FATAL) << "Create engine error, please check the arguments";
}
const size_t input_count = input_names.size(); const size_t input_count = input_names.size();
const size_t output_count = output_names.size(); const size_t output_count = output_names.size();
......
...@@ -28,7 +28,7 @@ namespace mace { ...@@ -28,7 +28,7 @@ namespace mace {
const char *MaceVersion(); const char *MaceVersion();
enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, AUTO = 4 }; enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3 };
enum MaceStatus { MACE_SUCCESS = 0, MACE_INVALID_ARGS = 1 }; enum MaceStatus { MACE_SUCCESS = 0, MACE_INVALID_ARGS = 1 };
...@@ -82,13 +82,6 @@ class MaceEngine { ...@@ -82,13 +82,6 @@ class MaceEngine {
MaceEngine &operator=(const MaceEngine &) = delete; MaceEngine &operator=(const MaceEngine &) = delete;
}; };
std::unique_ptr<MaceEngine> CreateMaceEngine(
const std::string &model_tag,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const char *model_data_file = nullptr,
const DeviceType device_type = DeviceType::AUTO);
} // namespace mace } // namespace mace
#endif // MACE_PUBLIC_MACE_H_ #endif // MACE_PUBLIC_MACE_H_
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "mace/public/mace.h" #include "mace/public/mace.h"
#include "mace/public/mace_runtime.h" #include "mace/public/mace_runtime.h"
#include "mace/utils/logging.h"
namespace mace { namespace mace {
{% for tag in model_tags %} {% for tag in model_tags %}
...@@ -42,30 +41,33 @@ extern const std::string ModelBuildOptions(); ...@@ -42,30 +41,33 @@ extern const std::string ModelBuildOptions();
{% endfor %} {% endfor %}
namespace { namespace {
std::map<std::string, int> model_tag_map { std::map<std::string, int> model_name_map {
{% for i in range(model_tags |length) %} {% for i in range(model_tags |length) %}
std::make_pair({{ model_tags[i]|tojson }}, {{ i }}), std::make_pair({{ model_tags[i]|tojson }}, {{ i }}),
{% endfor %} {% endfor %}
}; };
} // namespace } // namespace
std::unique_ptr<MaceEngine> CreateMaceEngine( MaceStatus CreateMaceEngine(
const std::string &model_tag, const char *model_name,
const char *model_data_file,
const std::vector<std::string> &input_nodes, const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes, const std::vector<std::string> &output_nodes,
const char *model_data_file, const DeviceType device_type,
const DeviceType device_type) { std::shared_ptr<MaceEngine> *engine) {
// load model // load model
std::unique_ptr<MaceEngine> engine; if (engine == nullptr) {
return MaceStatus::MACE_INVALID_ARGS;
}
const unsigned char * model_data = nullptr; const unsigned char * model_data = nullptr;
NetDef net_def; NetDef net_def;
switch (model_tag_map[model_tag]) { switch (model_name_map[model_name]) {
{% for i in range(model_tags |length) %} {% for i in range(model_tags |length) %}
case {{ i }}: case {{ i }}:
model_data = model_data =
mace::{{model_tags[i]}}::LoadModelData(model_data_file); mace::{{model_tags[i]}}::LoadModelData(model_data_file);
net_def = mace::{{model_tags[i]}}::CreateNet(model_data); net_def = mace::{{model_tags[i]}}::CreateNet(model_data);
engine.reset( engine->reset(
new mace::MaceEngine(&net_def, device_type, input_nodes, output_nodes)); new mace::MaceEngine(&net_def, device_type, input_nodes, output_nodes));
if (device_type == DeviceType::GPU || device_type == DeviceType::HEXAGON) { if (device_type == DeviceType::GPU || device_type == DeviceType::HEXAGON) {
mace::{{model_tags[i]}}::UnloadModelData(model_data); mace::{{model_tags[i]}}::UnloadModelData(model_data);
...@@ -73,10 +75,10 @@ std::unique_ptr<MaceEngine> CreateMaceEngine( ...@@ -73,10 +75,10 @@ std::unique_ptr<MaceEngine> CreateMaceEngine(
break; break;
{% endfor %} {% endfor %}
default: default:
LOG(FATAL) << "There is no model named " << model_tag; return MaceStatus::MACE_INVALID_ARGS;
} }
return engine; return MaceStatus::MACE_SUCCESS;
} }
} // namespace mace } // namespace mace
...@@ -20,17 +20,17 @@ from jinja2 import Environment, FileSystemLoader ...@@ -20,17 +20,17 @@ from jinja2 import Environment, FileSystemLoader
FLAGS = None FLAGS = None
def gen_mace_engine_creator(model_tags, template_dir, output_dir): def gen_mace_engine_factory(model_tags, template_dir, output_dir):
# Create the jinja2 environment. # Create the jinja2 environment.
j2_env = Environment( j2_env = Environment(
loader=FileSystemLoader(template_dir), trim_blocks=True) loader=FileSystemLoader(template_dir), trim_blocks=True)
# generate mace_run BUILD file # generate mace_run BUILD file
print model_tags print model_tags
template_name = 'mace_engine_creator.jinja2' template_name = 'mace_engine_factory.h.jinja2'
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
model_tags=model_tags, model_tags=model_tags,
) )
with open(output_dir + '/mace_engine_creator.cc', "wb") as f: with open(output_dir + '/mace_engine_factory.h', "wb") as f:
f.write(source) f.write(source)
......
...@@ -10,7 +10,7 @@ cc_binary( ...@@ -10,7 +10,7 @@ cc_binary(
deps = [ deps = [
"//external:gflags_nothreads", "//external:gflags_nothreads",
"//mace/codegen:generated_models", "//mace/codegen:generated_models",
"//mace/codegen:generated_mace_engine_creator", "//mace/codegen:generated_mace_engine_factory",
"//mace/core:core", "//mace/core:core",
], ],
) )
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
#include "mace/codegen/engine/mace_engine_factory.h"
namespace mace { namespace mace {
namespace tools { namespace tools {
...@@ -162,9 +163,9 @@ struct mallinfo LogMallinfoChange(struct mallinfo prev) { ...@@ -162,9 +163,9 @@ struct mallinfo LogMallinfoChange(struct mallinfo prev) {
return curr; return curr;
} }
DEFINE_string(model_tag, DEFINE_string(model_name,
"", "",
"model tag in yaml"); "model name in yaml");
DEFINE_string(input_node, DEFINE_string(input_node,
"input_node0,input_node1", "input_node0,input_node1",
"input nodes, separated by comma"); "input nodes, separated by comma");
...@@ -196,7 +197,7 @@ DEFINE_int32(omp_num_threads, -1, "num of openmp threads"); ...@@ -196,7 +197,7 @@ DEFINE_int32(omp_num_threads, -1, "num of openmp threads");
DEFINE_int32(cpu_affinity_policy, 1, DEFINE_int32(cpu_affinity_policy, 1,
"0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY"); "0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY");
bool RunModel(const std::string &model_tag, bool RunModel(const std::string &model_name,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::vector<int64_t>> &input_shapes, const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names,
...@@ -214,24 +215,42 @@ bool RunModel(const std::string &model_tag, ...@@ -214,24 +215,42 @@ bool RunModel(const std::string &model_tag,
} }
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
const char *kernel_path = getenv("MACE_CL_PROGRAM_PATH"); const char *kernel_path = getenv("MACE_INTERNAL_STORAGE_PATH");
const std::string kernel_file_path = const std::string kernel_file_path =
std::string(kernel_path == nullptr ? std::string(kernel_path == nullptr ?
"/data/local/tmp/mace_run/cl_program" : kernel_path); "/data/local/tmp/mace_run/interior" : kernel_path);
std::shared_ptr<KVStorageFactory> storage_factory( std::shared_ptr<KVStorageFactory> storage_factory(
new FileStorageFactory(kernel_file_path)); new FileStorageFactory(kernel_file_path));
SetKVStorageFactory(storage_factory); SetKVStorageFactory(storage_factory);
std::shared_ptr<mace::MaceEngine> engine;
MaceStatus create_engine_status;
// Create Engine // Create Engine
int64_t t0 = NowMicros(); int64_t t0 = NowMicros();
std::unique_ptr<mace::MaceEngine> engine = if (FLAGS_model_data_file.empty()) {
CreateMaceEngine(model_tag, create_engine_status =
input_names, CreateMaceEngine(model_name.c_str(),
output_names, nullptr,
FLAGS_model_data_file.c_str(), input_names,
device_type); output_names,
device_type,
&engine);
} else {
create_engine_status =
CreateMaceEngine(model_name.c_str(),
FLAGS_model_data_file.c_str(),
input_names,
output_names,
device_type,
&engine);
}
int64_t t1 = NowMicros(); int64_t t1 = NowMicros();
if (create_engine_status != MaceStatus::MACE_SUCCESS) {
LOG(FATAL) << "Create engine error, please check the arguments";
}
double init_millis = (t1 - t0) / 1000.0; double init_millis = (t1 - t0) / 1000.0;
LOG(INFO) << "Total init latency: " << init_millis << " ms"; LOG(INFO) << "Total init latency: " << init_millis << " ms";
...@@ -330,6 +349,7 @@ int Main(int argc, char **argv) { ...@@ -330,6 +349,7 @@ int Main(int argc, char **argv) {
gflags::SetUsageMessage("some usage message"); gflags::SetUsageMessage("some usage message");
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
LOG(INFO) << "model name: " << FLAGS_model_name;
LOG(INFO) << "mace version: " << MaceVersion(); LOG(INFO) << "mace version: " << MaceVersion();
LOG(INFO) << "input node: " << FLAGS_input_node; LOG(INFO) << "input node: " << FLAGS_input_node;
LOG(INFO) << "input shape: " << FLAGS_input_shape; LOG(INFO) << "input shape: " << FLAGS_input_shape;
...@@ -370,7 +390,7 @@ int Main(int argc, char **argv) { ...@@ -370,7 +390,7 @@ int Main(int argc, char **argv) {
for (int i = 0; i < FLAGS_restart_round; ++i) { for (int i = 0; i < FLAGS_restart_round; ++i) {
VLOG(0) << "restart round " << i; VLOG(0) << "restart round " << i;
ret = ret =
RunModel(FLAGS_model_tag, input_names, input_shape_vec, RunModel(FLAGS_model_name, input_names, input_shape_vec,
output_names, output_shape_vec); output_names, output_shape_vec);
} }
if (ret) { if (ret) {
......
...@@ -115,16 +115,14 @@ def model_benchmark_stdout_processor(stdout, ...@@ -115,16 +115,14 @@ def model_benchmark_stdout_processor(stdout,
serialno, serialno,
model_name, model_name,
runtime): runtime):
metrics = [0] * 5 metrics = [0] * 3
for line in stdout.split('\n'): for line in stdout.split('\n'):
line = line.strip() line = line.strip()
parts = line.split() parts = line.split()
if len(parts) == 6 and parts[0].startswith("time"): if len(parts) == 4 and parts[0].startswith("time"):
metrics[0] = str(float(parts[1])) metrics[0] = str(float(parts[1]))
metrics[1] = str(float(parts[2])) metrics[1] = str(float(parts[2]))
metrics[2] = str(float(parts[3])) metrics[2] = str(float(parts[3]))
metrics[3] = str(float(parts[4]))
metrics[4] = str(float(parts[5]))
break break
device_name = "" device_name = ""
...@@ -137,22 +135,20 @@ def model_benchmark_stdout_processor(stdout, ...@@ -137,22 +135,20 @@ def model_benchmark_stdout_processor(stdout,
report_filename = FLAGS.output_dir + "/report.csv" report_filename = FLAGS.output_dir + "/report.csv"
if not os.path.exists(report_filename): if not os.path.exists(report_filename):
with open(report_filename, 'w') as f: with open(report_filename, 'w') as f:
f.write("model_name,device_name,soc,abi,runtime,create_net," f.write("model_name,device_name,soc,abi,runtime,"
"engine_ctor,init,warmup,run_avg\n") "init,warmup,run_avg\n")
data_str = "{model_name},{device_name},{soc},{abi},{runtime}," \ data_str = "{model_name},{device_name},{soc},{abi},{runtime}," \
"{create_net},{engine_ctor},{init},{warmup},{run_avg}\n" \ "{init},{warmup},{run_avg}\n" \
.format( .format(
model_name=model_name, model_name=model_name,
device_name=device_name, device_name=device_name,
soc=target_soc, soc=target_soc,
abi=abi, abi=abi,
runtime=runtime, runtime=runtime,
create_net=metrics[0], init=metrics[0],
engine_ctor=metrics[1], warmup=metrics[1],
init=metrics[2], run_avg=metrics[2]
warmup=metrics[3],
run_avg=metrics[4]
) )
with open(report_filename, 'a') as f: with open(report_filename, 'a') as f:
f.write(data_str) f.write(data_str)
...@@ -300,22 +296,36 @@ def merge_libs_and_tuning_results(target_soc, ...@@ -300,22 +296,36 @@ def merge_libs_and_tuning_results(target_soc,
embed_model_data) embed_model_data)
def get_model_files(model_file_path, def download_model_files(model_file_path,
model_output_dir, model_output_dir,
weight_file_path=""): weight_file_path=""):
model_file = "" model_file = ""
weight_file = "" weight_file = ""
if model_file_path.startswith("http://") or \ if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"): model_file_path.startswith("https://"):
model_file = model_output_dir + "/model.pb" model_file = model_output_dir + "/model.pb"
urllib.urlretrieve(model_file_path, model_file) urllib.urlretrieve(model_file_path, model_file)
if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"):
weight_file = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(weight_file_path, weight_file)
def get_model_files_path(model_file_path,
model_output_dir,
weight_file_path=""):
model_file = ""
weight_file = ""
if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"):
model_file = model_output_dir + "/model.pb"
else: else:
model_file = model_file_path model_file = model_file_path
if weight_file_path.startswith("http://") or \ if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"): weight_file_path.startswith("https://"):
weight_file = model_output_dir + "/model.caffemodel" weight_file = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(weight_file_path, weight_file)
else: else:
weight_file = weight_file_path weight_file = weight_file_path
...@@ -547,6 +557,8 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -547,6 +557,8 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_output_dir = "%s/%s_%s/%s" % ( model_output_dir = "%s/%s_%s/%s" % (
model_output_base_dir, device_name.replace(' ', ''), model_output_base_dir, device_name.replace(' ', ''),
target_soc, target_abi) target_soc, target_abi)
sh_commands.clear_phone_data_dir(serialno, phone_data_dir)
model_output_dirs.append(model_output_dir) model_output_dirs.append(model_output_dir)
if FLAGS.mode == "build" or FLAGS.mode == "all": if FLAGS.mode == "build" or FLAGS.mode == "all":
...@@ -554,14 +566,11 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -554,14 +566,11 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
sh.rm("-rf", model_output_dir) sh.rm("-rf", model_output_dir)
os.makedirs(model_output_dir) os.makedirs(model_output_dir)
model_file_path, weight_file_path = get_model_files( model_file_path, weight_file_path = get_model_files_path(
model_config["model_file_path"], model_config["model_file_path"],
model_output_base_dir, model_output_base_dir,
model_config["weight_file_path"]) model_config["weight_file_path"])
sh_commands.clear_phone_data_dir(
target_abi, serialno, phone_data_dir)
if FLAGS.mode == "build" or FLAGS.mode == "run" or \ if FLAGS.mode == "build" or FLAGS.mode == "run" or \
FLAGS.mode == "validate" or \ FLAGS.mode == "validate" or \
FLAGS.mode == "benchmark" or FLAGS.mode == "all": FLAGS.mode == "benchmark" or FLAGS.mode == "all":
...@@ -723,7 +732,7 @@ def main(unused_args): ...@@ -723,7 +732,7 @@ def main(unused_args):
# generate source # generate source
sh_commands.gen_mace_version() sh_commands.gen_mace_version()
sh_commands.gen_encrypted_opencl_source() sh_commands.gen_encrypted_opencl_source()
sh_commands.gen_mace_engine_creator_source(configs['models'].keys()) sh_commands.gen_mace_engine_factory_source(configs['models'].keys())
embed_model_data = configs["embed_model_data"] embed_model_data = configs["embed_model_data"]
target_socs = get_target_socs(configs) target_socs = get_target_socs(configs)
...@@ -751,7 +760,12 @@ def main(unused_args): ...@@ -751,7 +760,12 @@ def main(unused_args):
sh.rm("-rf", model_output_base_dir) sh.rm("-rf", model_output_base_dir)
os.makedirs(model_output_base_dir) os.makedirs(model_output_base_dir)
model_file_path, weight_file_path = get_model_files( download_model_files(
model_config["model_file_path"],
model_output_base_dir,
model_config["weight_file_path"])
model_file_path, weight_file_path = get_model_files_path(
model_config["model_file_path"], model_config["model_file_path"],
model_output_base_dir, model_output_base_dir,
model_config["weight_file_path"]) model_config["weight_file_path"])
......
...@@ -33,7 +33,7 @@ try: ...@@ -33,7 +33,7 @@ try:
from binary_codegen import tuning_param_codegen from binary_codegen import tuning_param_codegen
from generate_data import generate_input_data from generate_data import generate_input_data
from validate import validate from validate import validate
from mace_engine_generator import gen_mace_engine_creator from mace_engine_factory_codegen import gen_mace_engine_factory
except Exception as e: except Exception as e:
print("Import error:\n%s" % e) print("Import error:\n%s" % e)
exit(1) exit(1)
...@@ -75,12 +75,11 @@ def is_device_locked(serialno): ...@@ -75,12 +75,11 @@ def is_device_locked(serialno):
################################ ################################
# clear data # clear data
################################ ################################
def clear_phone_data_dir(abi, serialno, phone_data_dir): def clear_phone_data_dir(serialno, phone_data_dir):
if abi != "host": sh.adb("-s",
sh.adb("-s", serialno,
serialno, "shell",
"shell", "rm -rf %s" % phone_data_dir)
"rm -rf %s" % phone_data_dir)
def clear_model_codegen(model_codegen_dir="mace/codegen/models"): def clear_model_codegen(model_codegen_dir="mace/codegen/models"):
...@@ -369,12 +368,12 @@ def gen_encrypted_opencl_source(codegen_path="mace/codegen"): ...@@ -369,12 +368,12 @@ def gen_encrypted_opencl_source(codegen_path="mace/codegen"):
"mace/codegen/opencl/opencl_encrypt_program.cc") "mace/codegen/opencl/opencl_encrypt_program.cc")
def gen_mace_engine_creator_source(model_tags, codegen_path="mace/codegen"): def gen_mace_engine_factory_source(model_tags, codegen_path="mace/codegen"):
print("* Genearte mace engine creator source") print("* Genearte mace engine creator source")
codegen_tools_dir = "%s/engine" % codegen_path codegen_tools_dir = "%s/engine" % codegen_path
sh.rm("-rf", codegen_tools_dir) sh.rm("-rf", codegen_tools_dir)
sh.mkdir("-p", codegen_tools_dir) sh.mkdir("-p", codegen_tools_dir)
gen_mace_engine_creator( gen_mace_engine_factory(
model_tags, model_tags,
"mace/python/tools", "mace/python/tools",
codegen_tools_dir) codegen_tools_dir)
...@@ -384,7 +383,7 @@ def gen_mace_engine_creator_source(model_tags, codegen_path="mace/codegen"): ...@@ -384,7 +383,7 @@ def gen_mace_engine_creator_source(model_tags, codegen_path="mace/codegen"):
def pull_binaries(abi, serialno, model_output_dirs, def pull_binaries(abi, serialno, model_output_dirs,
cl_built_kernel_file_name, cl_built_kernel_file_name,
cl_platform_info_file_name): cl_platform_info_file_name):
compiled_opencl_dir = "/data/local/tmp/mace_run/cl_program/" compiled_opencl_dir = "/data/local/tmp/mace_run/interior/"
mace_run_param_file = "mace_run.config" mace_run_param_file = "mace_run.config"
cl_bin_dirs = [] cl_bin_dirs = []
...@@ -558,9 +557,10 @@ def update_mace_run_lib(model_output_dir, ...@@ -558,9 +557,10 @@ def update_mace_run_lib(model_output_dir,
model_output_dir) model_output_dir)
def create_compiled_opencl_dir(serialno): def create_internal_storage_dir(serialno, phone_data_dir):
compiled_opencl_dir = "/data/local/tmp/mace_run/cl_program/" internal_storage_dir = "%s/interior/" % phone_data_dir
sh.adb("-s", serialno, "shell", "mkdir", "-p", compiled_opencl_dir) sh.adb("-s", serialno, "shell", "mkdir", "-p", internal_storage_dir)
return internal_storage_dir
def tuning_run(abi, def tuning_run(abi,
...@@ -601,7 +601,7 @@ def tuning_run(abi, ...@@ -601,7 +601,7 @@ def tuning_run(abi,
"env", "env",
"MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level, "MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level,
"%s/mace_run" % model_output_dir, "%s/mace_run" % model_output_dir,
"--model_tag=%s" % model_tag, "--model_name=%s" % model_tag,
"--input_node=%s" % ",".join(input_nodes), "--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
...@@ -626,7 +626,8 @@ def tuning_run(abi, ...@@ -626,7 +626,8 @@ def tuning_run(abi,
return stdout return stdout
else: else:
sh.adb("-s", serialno, "shell", "mkdir", "-p", phone_data_dir) sh.adb("-s", serialno, "shell", "mkdir", "-p", phone_data_dir)
create_compiled_opencl_dir(serialno) internal_storage_dir = create_internal_storage_dir(
serialno, phone_data_dir)
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = common.formatted_file_name(input_file_name, formatted_name = common.formatted_file_name(input_file_name,
...@@ -649,7 +650,7 @@ def tuning_run(abi, ...@@ -649,7 +650,7 @@ def tuning_run(abi,
"MACE_OUT_OF_RANGE_CHECK=%s" % int(out_of_range_check), "MACE_OUT_OF_RANGE_CHECK=%s" % int(out_of_range_check),
"MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level, "MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level,
"MACE_RUN_PARAMETER_PATH=%s/mace_run.config" % phone_data_dir, "MACE_RUN_PARAMETER_PATH=%s/mace_run.config" % phone_data_dir,
"MACE_CL_PROGRAM_PATH=%s/cl_program" % phone_data_dir, "MACE_INTERNAL_STORAGE_PATH=%s" % internal_storage_dir,
"MACE_LIMIT_OPENCL_KERNEL_TIME=%s" % limit_opencl_kernel_time, "MACE_LIMIT_OPENCL_KERNEL_TIME=%s" % limit_opencl_kernel_time,
] ]
if valgrind: if valgrind:
...@@ -660,7 +661,7 @@ def tuning_run(abi, ...@@ -660,7 +661,7 @@ def tuning_run(abi,
]) ])
adb_cmd.extend([ adb_cmd.extend([
"%s/mace_run" % phone_data_dir, "%s/mace_run" % phone_data_dir,
"--model_tag=%s" % model_tag, "--model_name=%s" % model_tag,
"--input_node=%s" % ",".join(input_nodes), "--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
...@@ -840,16 +841,18 @@ def merge_libs(target_soc, ...@@ -840,16 +841,18 @@ def merge_libs(target_soc,
if hexagon_mode: if hexagon_mode:
sh.cp("-f", hexagon_lib_file, model_bin_dir) sh.cp("-f", hexagon_lib_file, model_bin_dir)
sh.cp("-f", glob.glob("mace/codegen/engine/*.h"), model_header_dir)
mri_stream = "" mri_stream = ""
if abi == "host": if abi == "host":
mri_stream += "create %s/libmace_%s.a\n" % \ mri_stream += "create %s/libmace_%s.a\n" % \
(model_bin_dir, project_name) (model_bin_dir, project_name)
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/codegen/libgenerated_opencl.pic.a\n") "bazel-bin/mace/codegen/libgenerated_opencl.pic.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/codegen/libgenerated_tuning_params.pic.a\n") "bazel-bin/mace/codegen/libgenerated_tuning_params.pic.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/codegen/libgenerated_models.pic.a\n") "bazel-bin/mace/codegen/libgenerated_models.pic.a\n")
...@@ -860,35 +863,35 @@ def merge_libs(target_soc, ...@@ -860,35 +863,35 @@ def merge_libs(target_soc,
mri_stream += "create %s/libmace_%s.%s.a\n" % \ mri_stream += "create %s/libmace_%s.%s.a\n" % \
(model_bin_dir, project_name, target_soc) (model_bin_dir, project_name, target_soc)
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/codegen/libgenerated_opencl.a\n") "bazel-bin/mace/codegen/libgenerated_opencl.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/codegen/libgenerated_tuning_params.a\n") "bazel-bin/mace/codegen/libgenerated_tuning_params.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/codegen/libgenerated_version.a\n") "bazel-bin/mace/codegen/libgenerated_version.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/codegen/libgenerated_models.a\n") "bazel-bin/mace/codegen/libgenerated_models.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/codegen/libgenerated_mace_engine_creator.a\n") "bazel-bin/mace/codegen/libgenerated_mace_engine_creator.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/core/libcore.a\n") "bazel-bin/mace/core/libcore.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/kernels/libkernels.a\n") "bazel-bin/mace/kernels/libkernels.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/utils/libutils.a\n") "bazel-bin/mace/utils/libutils.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/utils/libutils_prod.a\n") "bazel-bin/mace/utils/libutils_prod.a\n")
mri_stream += ( mri_stream += (
"addlib " "addlib "
"bazel-bin/mace/ops/libops.lo\n") "bazel-bin/mace/ops/libops.lo\n")
for model_output_dir in model_output_dirs: for model_output_dir in model_output_dirs:
if not embed_model_data: if not embed_model_data:
...@@ -984,7 +987,7 @@ def benchmark_model(abi, ...@@ -984,7 +987,7 @@ def benchmark_model(abi,
"env", "env",
"MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level, "MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level,
"%s/benchmark_model" % model_output_dir, "%s/benchmark_model" % model_output_dir,
"--model_tag=%s" % model_tag, "--model_name=%s" % model_tag,
"--input_node=%s" % ",".join(input_nodes), "--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
...@@ -1000,7 +1003,8 @@ def benchmark_model(abi, ...@@ -1000,7 +1003,8 @@ def benchmark_model(abi,
p.wait() p.wait()
else: else:
sh.adb("-s", serialno, "shell", "mkdir", "-p", phone_data_dir) sh.adb("-s", serialno, "shell", "mkdir", "-p", phone_data_dir)
create_compiled_opencl_dir(serialno) internal_storage_dir = create_internal_storage_dir(
serialno, phone_data_dir)
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = common.formatted_file_name(input_file_name, formatted_name = common.formatted_file_name(input_file_name,
...@@ -1020,9 +1024,10 @@ def benchmark_model(abi, ...@@ -1020,9 +1024,10 @@ def benchmark_model(abi,
"MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level, "MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level,
"MACE_RUN_PARAMETER_PATH=%s/mace_run.config" % "MACE_RUN_PARAMETER_PATH=%s/mace_run.config" %
phone_data_dir, phone_data_dir,
"MACE_INTERNAL_STORAGE_PATH=%s" % internal_storage_dir,
"MACE_OPENCL_PROFILING=1", "MACE_OPENCL_PROFILING=1",
"%s/benchmark_model" % phone_data_dir, "%s/benchmark_model" % phone_data_dir,
"--model_tag=%s" % model_tag, "--model_name=%s" % model_tag,
"--input_node=%s" % ",".join(input_nodes), "--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册