提交 c8f90e44 编写于 作者: L liuqi

Add caffe_env argument for mace_tool.

上级 9e58f334
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import re
################################
# log
################################
def init_logging():
logger = logging.getLogger('MACE')
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s [%(name)s] [%(levelname)s]: %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
################################
# Argument types
################################
class CaffeEnvType(enum.Enum):
DOCKER = 0,
LOCAL = 1,
################################
# common functions
################################
def formatted_file_name(input_file_name, input_name):
return input_file_name + '_' + \
re.sub('[^0-9a-zA-Z]+', '_', input_name)
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
# --mode=all # --mode=all
import argparse import argparse
import enum
import filelock import filelock
import hashlib import hashlib
import os import os
...@@ -28,6 +29,7 @@ import urllib ...@@ -28,6 +29,7 @@ import urllib
import yaml import yaml
import re import re
import common
import sh_commands import sh_commands
from ConfigParser import ConfigParser from ConfigParser import ConfigParser
...@@ -298,6 +300,27 @@ def md5sum(str): ...@@ -298,6 +300,27 @@ def md5sum(str):
return md5.hexdigest() return md5.hexdigest()
################################
# Parsing arguments
################################
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def str_to_caffe_env_type(v):
if v.lower() == 'docker':
return common.CaffeEnvType.DOCKER
elif v.lower() == 'local':
return common.CaffeEnvType.LOCAL
else:
raise argparse.ArgumentTypeError('[docker | local] expected.')
def parse_model_configs(): def parse_model_configs():
with open(FLAGS.config) as f: with open(FLAGS.config) as f:
configs = yaml.load(f) configs = yaml.load(f)
...@@ -307,11 +330,11 @@ def parse_model_configs(): ...@@ -307,11 +330,11 @@ def parse_model_configs():
def parse_args(): def parse_args():
"""Parses command line arguments.""" """Parses command line arguments."""
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument( parser.add_argument(
"--config", "--config",
type=str, type=str,
default="./tool/config", default="./tool/config",
required=True,
help="The global config file of models.") help="The global config file of models.")
parser.add_argument( parser.add_argument(
"--output_dir", type=str, default="build", help="The output dir.") "--output_dir", type=str, default="build", help="The output dir.")
...@@ -328,12 +351,15 @@ def parse_args(): ...@@ -328,12 +351,15 @@ def parse_args():
default=1, default=1,
help="The model restart round.") help="The model restart round.")
parser.add_argument( parser.add_argument(
"--tuning", type="bool", default="true", help="Tune opencl params.") "--tuning",
type=str2bool,
default=True,
help="Tune opencl params.")
parser.add_argument( parser.add_argument(
"--mode", "--mode",
type=str, type=str,
default="all", default="all",
help="[build|run|validate|merge|all|throughput_test].") help="[build|run|validate|benchmark|merge|all|throughput_test].")
parser.add_argument( parser.add_argument(
"--target_socs", "--target_socs",
type=str, type=str,
...@@ -341,19 +367,24 @@ def parse_args(): ...@@ -341,19 +367,24 @@ def parse_args():
help="SoCs to build, comma seperated list (getprop ro.board.platform)") help="SoCs to build, comma seperated list (getprop ro.board.platform)")
parser.add_argument( parser.add_argument(
"--out_of_range_check", "--out_of_range_check",
type="bool", type=str2bool,
default="false", default=False,
help="Enable out of range check for opencl.") help="Enable out of range check for opencl.")
parser.add_argument( parser.add_argument(
"--collect_report", "--collect_report",
type="bool", type=str2bool,
default="false", default=False,
help="Collect report.") help="Collect report.")
parser.add_argument( parser.add_argument(
"--vlog_level", "--vlog_level",
type=int, type=int,
default=0, default=0,
help="VLOG level.") help="VLOG level.")
parser.add_argument(
"--caffe_env",
type=str_to_caffe_env_type,
default='docker',
help="[docker | local] caffe environment.")
return parser.parse_known_args() return parser.parse_known_args()
...@@ -501,7 +532,8 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -501,7 +532,8 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_config["input_shapes"], model_config["input_shapes"],
model_config["output_shapes"], model_config["output_shapes"],
model_output_dir, model_output_dir,
phone_data_dir) phone_data_dir,
FLAGS.caffe_env)
if FLAGS.mode == "build" or FLAGS.mode == "merge" or \ if FLAGS.mode == "build" or FLAGS.mode == "merge" or \
FLAGS.mode == "all": FLAGS.mode == "all":
...@@ -554,6 +586,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level, ...@@ -554,6 +586,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
def main(unused_args): def main(unused_args):
common.init_logging()
configs = parse_model_configs() configs = parse_model_configs()
if FLAGS.mode == "validate": if FLAGS.mode == "validate":
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import falcon_cli import falcon_cli
import filelock import filelock
import glob import glob
import logging
import os import os
import re import re
import sh import sh
...@@ -23,6 +24,7 @@ import sys ...@@ -23,6 +24,7 @@ import sys
import time import time
import urllib import urllib
import common
sys.path.insert(0, "mace/python/tools") sys.path.insert(0, "mace/python/tools")
try: try:
...@@ -35,10 +37,12 @@ except Exception as e: ...@@ -35,10 +37,12 @@ except Exception as e:
print("Import error:\n%s" % e) print("Import error:\n%s" % e)
exit(1) exit(1)
################################ ################################
# common # common
################################ ################################
logger = logging.getLogger('MACE')
def strip_invalid_utf8(str): def strip_invalid_utf8(str):
return sh.iconv(str, "-c", "-t", "UTF-8") return sh.iconv(str, "-c", "-t", "UTF-8")
...@@ -67,11 +71,6 @@ def is_device_locked(serialno): ...@@ -67,11 +71,6 @@ def is_device_locked(serialno):
return True return True
def formatted_file_name(input_name, input_file_name):
return input_file_name + '_' + \
re.sub('[^0-9a-zA-Z]+', '_', input_name)
################################ ################################
# clear data # clear data
################################ ################################
...@@ -491,7 +490,8 @@ def gen_random_input(model_output_dir, ...@@ -491,7 +490,8 @@ def gen_random_input(model_output_dir,
input_files, input_files,
input_file_name="model_input"): input_file_name="model_input"):
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = formatted_file_name(input_name, input_file_name) formatted_name = common.formatted_file_name(
input_file_name, input_name)
if os.path.exists("%s/%s" % (model_output_dir, formatted_name)): if os.path.exists("%s/%s" % (model_output_dir, formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name)) sh.rm("%s/%s" % (model_output_dir, formatted_name))
input_nodes_str = ",".join(input_nodes) input_nodes_str = ",".join(input_nodes)
...@@ -517,8 +517,8 @@ def gen_random_input(model_output_dir, ...@@ -517,8 +517,8 @@ def gen_random_input(model_output_dir,
for i in range(len(input_file_list)): for i in range(len(input_file_list)):
if input_file_list[i] is not None: if input_file_list[i] is not None:
dst_input_file = model_output_dir + '/' + \ dst_input_file = model_output_dir + '/' + \
formatted_file_name(input_name_list[i], common.formatted_file_name(input_file_name,
input_file_name) input_name_list[i])
if input_file_list[i].startswith("http://") or \ if input_file_list[i].startswith("http://") or \
input_file_list[i].startswith("https://"): input_file_list[i].startswith("https://"):
urllib.urlretrieve(input_file_list[i], dst_input_file) urllib.urlretrieve(input_file_list[i], dst_input_file)
...@@ -612,8 +612,8 @@ def tuning_run(abi, ...@@ -612,8 +612,8 @@ def tuning_run(abi,
sh.adb("-s", serialno, "shell", "mkdir", "-p", compiled_opencl_dir) sh.adb("-s", serialno, "shell", "mkdir", "-p", compiled_opencl_dir)
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = formatted_file_name(input_name, formatted_name = common.formatted_file_name(input_file_name,
input_file_name) input_name)
adb_push("%s/%s" % (model_output_dir, formatted_name), adb_push("%s/%s" % (model_output_dir, formatted_name),
phone_data_dir, serialno) phone_data_dir, serialno)
adb_push("%s/mace_run" % model_output_dir, phone_data_dir, adb_push("%s/mace_run" % model_output_dir, phone_data_dir,
...@@ -671,20 +671,21 @@ def validate_model(abi, ...@@ -671,20 +671,21 @@ def validate_model(abi,
output_shapes, output_shapes,
model_output_dir, model_output_dir,
phone_data_dir, phone_data_dir,
caffe_env,
input_file_name="model_input", input_file_name="model_input",
output_file_name="model_out"): output_file_name="model_out"):
print("* Validate with %s" % platform) print("* Validate with %s" % platform)
if platform == "tensorflow":
if abi != "host": if abi != "host":
for output_name in output_nodes: for output_name in output_nodes:
formatted_name = formatted_file_name( formatted_name = common.formatted_file_name(
output_name, output_file_name) output_file_name, output_name)
if os.path.exists("%s/%s" % (model_output_dir, if os.path.exists("%s/%s" % (model_output_dir,
formatted_name)): formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name)) sh.rm("-rf", "%s/%s" % (model_output_dir, formatted_name))
adb_pull("%s/%s" % (phone_data_dir, formatted_name), adb_pull("%s/%s" % (phone_data_dir, formatted_name),
model_output_dir, serialno) model_output_dir, serialno)
if platform == "tensorflow":
validate(platform, model_file_path, "", validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name), "%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), runtime, "%s/%s" % (model_output_dir, output_file_name), runtime,
...@@ -695,12 +696,26 @@ def validate_model(abi, ...@@ -695,12 +696,26 @@ def validate_model(abi,
container_name = "mace_caffe_validator" container_name = "mace_caffe_validator"
res_file = "validation.result" res_file = "validation.result"
if caffe_env == common.CaffeEnvType.LOCAL:
import imp
try:
imp.find_module('caffe')
except ImportError:
logger.error('There is no caffe python module.')
validate(platform, model_file_path, weight_file_path,
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), runtime,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes))
elif caffe_env == common.CaffeEnvType.DOCKER:
docker_image_id = sh.docker("images", "-q", image_name) docker_image_id = sh.docker("images", "-q", image_name)
if not docker_image_id: if not docker_image_id:
print("Build caffe docker") print("Build caffe docker")
sh.docker("build", "-t", image_name, "docker/caffe") sh.docker("build", "-t", image_name,
"mace/third_party/caffe")
container_id = sh.docker("ps", "-qa", "-f", "name=%s" % container_name) container_id = sh.docker("ps", "-qa", "-f",
"name=%s" % container_name)
if container_id and not sh.docker("ps", "-qa", "--filter", if container_id and not sh.docker("ps", "-qa", "--filter",
"status=running", "-f", "status=running", "-f",
"name=%s" % container_name): "name=%s" % container_name):
...@@ -718,25 +733,16 @@ def validate_model(abi, ...@@ -718,25 +733,16 @@ def validate_model(abi,
"/bin/bash") "/bin/bash")
for input_name in input_nodes: for input_name in input_nodes:
formatted_input_name = formatted_file_name( formatted_input_name = common.formatted_file_name(
input_name, input_file_name) input_file_name, input_name)
sh.docker( sh.docker(
"cp", "cp",
"%s/%s" % (model_output_dir, formatted_input_name), "%s/%s" % (model_output_dir, formatted_input_name),
"%s:/mace" % container_name) "%s:/mace" % container_name)
if abi != "host":
for output_name in output_nodes:
formatted_output_name = formatted_file_name(
output_name, output_file_name)
sh.rm("-rf",
"%s/%s" % (model_output_dir, formatted_output_name))
adb_pull("%s/%s" % (phone_data_dir, formatted_output_name),
model_output_dir, serialno)
for output_name in output_nodes: for output_name in output_nodes:
formatted_output_name = formatted_file_name( formatted_output_name = common.formatted_file_name(
output_name, output_file_name) output_file_name, output_name)
sh.docker( sh.docker(
"cp", "cp",
"%s/%s" % (model_output_dir, formatted_output_name), "%s/%s" % (model_output_dir, formatted_output_name),
...@@ -941,8 +947,8 @@ def benchmark_model(abi, ...@@ -941,8 +947,8 @@ def benchmark_model(abi,
sh.adb("-s", serialno, "shell", "mkdir", "-p", phone_data_dir) sh.adb("-s", serialno, "shell", "mkdir", "-p", phone_data_dir)
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = formatted_file_name(input_name, formatted_name = common.formatted_file_name(input_file_name,
input_file_name) input_name)
adb_push("%s/%s" % (model_output_dir, formatted_name), adb_push("%s/%s" % (model_output_dir, formatted_name),
phone_data_dir, serialno) phone_data_dir, serialno)
adb_push("%s/benchmark_model" % model_output_dir, phone_data_dir, adb_push("%s/benchmark_model" % model_output_dir, phone_data_dir,
......
...@@ -21,6 +21,8 @@ import re ...@@ -21,6 +21,8 @@ import re
from scipy import spatial from scipy import spatial
from scipy import stats from scipy import stats
import common
# Validation Flow: # Validation Flow:
# 1. Generate input data # 1. Generate input data
# 2. Use mace_run to run model on phone. # 2. Use mace_run to run model on phone.
...@@ -42,10 +44,6 @@ def load_data(file): ...@@ -42,10 +44,6 @@ def load_data(file):
return np.empty([0]) return np.empty([0])
def format_name(name):
return re.sub('[^0-9a-zA-Z]+', '_', name)
def compare_output(platform, mace_runtime, output_name, mace_out_value, def compare_output(platform, mace_runtime, output_name, mace_out_value,
out_value): out_value):
if mace_out_value.size != 0: if mace_out_value.size != 0:
...@@ -87,7 +85,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file, ...@@ -87,7 +85,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
input_dict = {} input_dict = {}
for i in range(len(input_names)): for i in range(len(input_names)):
input_value = load_data( input_value = load_data(
input_file + "_" + format_name(input_names[i])) common.formatted_file_name(input_file, input_names[i]))
input_value = input_value.reshape(input_shapes[i]) input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name( input_node = graph.get_tensor_by_name(
input_names[i] + ':0') input_names[i] + ':0')
...@@ -99,8 +97,8 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file, ...@@ -99,8 +97,8 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
[graph.get_tensor_by_name(name + ':0')]) [graph.get_tensor_by_name(name + ':0')])
output_values = session.run(output_nodes, feed_dict=input_dict) output_values = session.run(output_nodes, feed_dict=input_dict)
for i in range(len(output_names)): for i in range(len(output_names)):
output_file_name = mace_out_file + "_" + \ output_file_name = common.formatted_file_name(
format_name(output_names[i]) mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i], compare_output(platform, mace_runtime, output_names[i],
mace_out_value, output_values[i]) mace_out_value, output_values[i])
...@@ -123,7 +121,8 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file, ...@@ -123,7 +121,8 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
net = caffe.Net(model_file, caffe.TEST, weights=weight_file) net = caffe.Net(model_file, caffe.TEST, weights=weight_file)
for i in range(len(input_names)): for i in range(len(input_names)):
input_value = load_data(input_file + "_" + format_name(input_names[i])) input_value = load_data(
common.formatted_file_name(input_file, input_names[i]))
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1,
2)) 2))
input_blob_name = input_names[i] input_blob_name = input_names[i]
...@@ -142,8 +141,8 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file, ...@@ -142,8 +141,8 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[ out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[
1], out_shape[2] 1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1)) value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = mace_out_file + "_" + format_name( output_file_name = common.formatted_file_name(
output_names[i]) mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i], mace_out_value, compare_output(platform, mace_runtime, output_names[i], mace_out_value,
value) value)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册