提交 53a79f58 编写于 作者: L Liangliang He

Reduce model tuning time

上级 cb7d1396
......@@ -168,7 +168,6 @@ target_abis: [armeabi-v7a, arm64-v8a]
# 具体机型的soc编号,可以使用`adb shell getprop | grep ro.board.platform | cut -d [ -f3 | cut -d ] -f1`获取
target_socs: [msm8998]
embed_model_data: 1
vlog_level: 0
models: # 一个配置文件可以包含多个模型的配置信息,最终生成的库中包含多个模型
first_net: # 模型的标签,在调度模型的时候,会用这个变量
platform: tensorflow
......@@ -218,7 +217,6 @@ models: # 一个配置文件可以包含多个模型的配置信息,最终生
| ---------- |:--------------:|
| target_abis | 运行的ABI,可选包括安卓设备的armeabi-v7a,arm64-v8a等,以及开发人员的电脑终端(电脑终端使用‘host’表示)。可以同时指定多个ABI |
| embed_model_data | 是否将模型里的数据嵌入到代码中,默认为1 |
| vlog_level | 设置log打印的级别 |
| platform | 模型对应的框架名称 [tensorflow | caffe] |
| model_file_path | 模型的路径,可以是一个http或https的下载链接 |
| weight_file_path | 权重文件的路径,可以是一个http或https的下载链接(caffe model)|
......
......@@ -141,13 +141,18 @@ class Tuner {
double *time_us,
std::vector<param_type> *tuning_result) {
RetType res;
int iter = 0;
int64_t total_time_us = 0;
for (int i = 0; i < num_runs; ++i) {
for (iter = 0; iter < num_runs; ++iter) {
res = func(params, timer, tuning_result);
total_time_us += timer->AccumulatedMicros();
if (iter >= 1 && total_time_us > 100000 || total_time_us > 200000) {
++iter;
break;
}
}
*time_us = total_time_us * 1.0 / num_runs;
*time_us = total_time_us * 1.0 / iter;
return res;
}
......@@ -167,7 +172,7 @@ class Tuner {
for (auto param : params) {
double tmp_time = 0.0;
// warm up
Run<RetType>(func, param, timer, 2, &tmp_time, &tuning_result);
Run<RetType>(func, param, timer, 1, &tmp_time, &tuning_result);
// run
RetType tmp_res =
......
......@@ -5,7 +5,6 @@
target_abis: [armeabi-v7a, arm64-v8a]
target_socs: [MSM8953]
embed_model_data: 1
vlog_level: 0
models:
preview_net:
platform: tensorflow
......
......@@ -349,6 +349,11 @@ def parse_args():
type="bool",
default="false",
help="Collect report.")
parser.add_argument(
"--vlog_level",
type=int,
default=0,
help="VLOG level.")
return parser.parse_known_args()
......@@ -567,7 +572,7 @@ def main(unused_args):
target_socs = get_target_socs(configs)
embed_model_data = configs.get("embed_model_data", 1)
vlog_level = configs.get("vlog_level", 0)
vlog_level = FLAGS.vlog_level
phone_data_dir = "/data/local/tmp/mace_run/"
for target_abi in configs["target_abis"]:
for target_soc in target_socs:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册