提交 c8f90e44 编写于 作者: L liuqi

Add caffe_env argument for mace_tool.

上级 9e58f334
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import re
################################
# log
################################
def init_logging():
logger = logging.getLogger('MACE')
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s [%(name)s] [%(levelname)s]: %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
################################
# Argument types
################################
class CaffeEnvType(enum.Enum):
DOCKER = 0,
LOCAL = 1,
################################
# common functions
################################
def formatted_file_name(input_file_name, input_name):
return input_file_name + '_' + \
re.sub('[^0-9a-zA-Z]+', '_', input_name)
......@@ -18,6 +18,7 @@
# --mode=all
import argparse
import enum
import filelock
import hashlib
import os
......@@ -28,6 +29,7 @@ import urllib
import yaml
import re
import common
import sh_commands
from ConfigParser import ConfigParser
......@@ -298,6 +300,27 @@ def md5sum(str):
return md5.hexdigest()
################################
# Parsing arguments
################################
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def str_to_caffe_env_type(v):
if v.lower() == 'docker':
return common.CaffeEnvType.DOCKER
elif v.lower() == 'local':
return common.CaffeEnvType.LOCAL
else:
raise argparse.ArgumentTypeError('[docker | local] expected.')
def parse_model_configs():
with open(FLAGS.config) as f:
configs = yaml.load(f)
......@@ -307,11 +330,11 @@ def parse_model_configs():
def parse_args():
"""Parses command line arguments."""
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--config",
type=str,
default="./tool/config",
required=True,
help="The global config file of models.")
parser.add_argument(
"--output_dir", type=str, default="build", help="The output dir.")
......@@ -328,12 +351,15 @@ def parse_args():
default=1,
help="The model restart round.")
parser.add_argument(
"--tuning", type="bool", default="true", help="Tune opencl params.")
"--tuning",
type=str2bool,
default=True,
help="Tune opencl params.")
parser.add_argument(
"--mode",
type=str,
default="all",
help="[build|run|validate|merge|all|throughput_test].")
help="[build|run|validate|benchmark|merge|all|throughput_test].")
parser.add_argument(
"--target_socs",
type=str,
......@@ -341,19 +367,24 @@ def parse_args():
help="SoCs to build, comma seperated list (getprop ro.board.platform)")
parser.add_argument(
"--out_of_range_check",
type="bool",
default="false",
type=str2bool,
default=False,
help="Enable out of range check for opencl.")
parser.add_argument(
"--collect_report",
type="bool",
default="false",
type=str2bool,
default=False,
help="Collect report.")
parser.add_argument(
"--vlog_level",
type=int,
default=0,
help="VLOG level.")
parser.add_argument(
"--caffe_env",
type=str_to_caffe_env_type,
default='docker',
help="[docker | local] caffe environment.")
return parser.parse_known_args()
......@@ -501,7 +532,8 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_config["input_shapes"],
model_config["output_shapes"],
model_output_dir,
phone_data_dir)
phone_data_dir,
FLAGS.caffe_env)
if FLAGS.mode == "build" or FLAGS.mode == "merge" or \
FLAGS.mode == "all":
......@@ -554,6 +586,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
def main(unused_args):
common.init_logging()
configs = parse_model_configs()
if FLAGS.mode == "validate":
......
......@@ -15,6 +15,7 @@
import falcon_cli
import filelock
import glob
import logging
import os
import re
import sh
......@@ -23,6 +24,7 @@ import sys
import time
import urllib
import common
sys.path.insert(0, "mace/python/tools")
try:
......@@ -35,10 +37,12 @@ except Exception as e:
print("Import error:\n%s" % e)
exit(1)
################################
# common
################################
logger = logging.getLogger('MACE')
def strip_invalid_utf8(str):
return sh.iconv(str, "-c", "-t", "UTF-8")
......@@ -67,11 +71,6 @@ def is_device_locked(serialno):
return True
def formatted_file_name(input_name, input_file_name):
return input_file_name + '_' + \
re.sub('[^0-9a-zA-Z]+', '_', input_name)
################################
# clear data
################################
......@@ -491,7 +490,8 @@ def gen_random_input(model_output_dir,
input_files,
input_file_name="model_input"):
for input_name in input_nodes:
formatted_name = formatted_file_name(input_name, input_file_name)
formatted_name = common.formatted_file_name(
input_file_name, input_name)
if os.path.exists("%s/%s" % (model_output_dir, formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name))
input_nodes_str = ",".join(input_nodes)
......@@ -517,8 +517,8 @@ def gen_random_input(model_output_dir,
for i in range(len(input_file_list)):
if input_file_list[i] is not None:
dst_input_file = model_output_dir + '/' + \
formatted_file_name(input_name_list[i],
input_file_name)
common.formatted_file_name(input_file_name,
input_name_list[i])
if input_file_list[i].startswith("http://") or \
input_file_list[i].startswith("https://"):
urllib.urlretrieve(input_file_list[i], dst_input_file)
......@@ -612,8 +612,8 @@ def tuning_run(abi,
sh.adb("-s", serialno, "shell", "mkdir", "-p", compiled_opencl_dir)
for input_name in input_nodes:
formatted_name = formatted_file_name(input_name,
input_file_name)
formatted_name = common.formatted_file_name(input_file_name,
input_name)
adb_push("%s/%s" % (model_output_dir, formatted_name),
phone_data_dir, serialno)
adb_push("%s/mace_run" % model_output_dir, phone_data_dir,
......@@ -671,20 +671,21 @@ def validate_model(abi,
output_shapes,
model_output_dir,
phone_data_dir,
caffe_env,
input_file_name="model_input",
output_file_name="model_out"):
print("* Validate with %s" % platform)
if platform == "tensorflow":
if abi != "host":
for output_name in output_nodes:
formatted_name = formatted_file_name(
output_name, output_file_name)
formatted_name = common.formatted_file_name(
output_file_name, output_name)
if os.path.exists("%s/%s" % (model_output_dir,
formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name))
sh.rm("-rf", "%s/%s" % (model_output_dir, formatted_name))
adb_pull("%s/%s" % (phone_data_dir, formatted_name),
model_output_dir, serialno)
if platform == "tensorflow":
validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), runtime,
......@@ -695,12 +696,26 @@ def validate_model(abi,
container_name = "mace_caffe_validator"
res_file = "validation.result"
if caffe_env == common.CaffeEnvType.LOCAL:
import imp
try:
imp.find_module('caffe')
except ImportError:
logger.error('There is no caffe python module.')
validate(platform, model_file_path, weight_file_path,
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), runtime,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes))
elif caffe_env == common.CaffeEnvType.DOCKER:
docker_image_id = sh.docker("images", "-q", image_name)
if not docker_image_id:
print("Build caffe docker")
sh.docker("build", "-t", image_name, "docker/caffe")
sh.docker("build", "-t", image_name,
"mace/third_party/caffe")
container_id = sh.docker("ps", "-qa", "-f", "name=%s" % container_name)
container_id = sh.docker("ps", "-qa", "-f",
"name=%s" % container_name)
if container_id and not sh.docker("ps", "-qa", "--filter",
"status=running", "-f",
"name=%s" % container_name):
......@@ -718,25 +733,16 @@ def validate_model(abi,
"/bin/bash")
for input_name in input_nodes:
formatted_input_name = formatted_file_name(
input_name, input_file_name)
formatted_input_name = common.formatted_file_name(
input_file_name, input_name)
sh.docker(
"cp",
"%s/%s" % (model_output_dir, formatted_input_name),
"%s:/mace" % container_name)
if abi != "host":
for output_name in output_nodes:
formatted_output_name = formatted_file_name(
output_name, output_file_name)
sh.rm("-rf",
"%s/%s" % (model_output_dir, formatted_output_name))
adb_pull("%s/%s" % (phone_data_dir, formatted_output_name),
model_output_dir, serialno)
for output_name in output_nodes:
formatted_output_name = formatted_file_name(
output_name, output_file_name)
formatted_output_name = common.formatted_file_name(
output_file_name, output_name)
sh.docker(
"cp",
"%s/%s" % (model_output_dir, formatted_output_name),
......@@ -941,8 +947,8 @@ def benchmark_model(abi,
sh.adb("-s", serialno, "shell", "mkdir", "-p", phone_data_dir)
for input_name in input_nodes:
formatted_name = formatted_file_name(input_name,
input_file_name)
formatted_name = common.formatted_file_name(input_file_name,
input_name)
adb_push("%s/%s" % (model_output_dir, formatted_name),
phone_data_dir, serialno)
adb_push("%s/benchmark_model" % model_output_dir, phone_data_dir,
......
......@@ -21,6 +21,8 @@ import re
from scipy import spatial
from scipy import stats
import common
# Validation Flow:
# 1. Generate input data
# 2. Use mace_run to run model on phone.
......@@ -42,10 +44,6 @@ def load_data(file):
return np.empty([0])
def format_name(name):
return re.sub('[^0-9a-zA-Z]+', '_', name)
def compare_output(platform, mace_runtime, output_name, mace_out_value,
out_value):
if mace_out_value.size != 0:
......@@ -87,7 +85,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
input_dict = {}
for i in range(len(input_names)):
input_value = load_data(
input_file + "_" + format_name(input_names[i]))
common.formatted_file_name(input_file, input_names[i]))
input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name(
input_names[i] + ':0')
......@@ -99,8 +97,8 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
[graph.get_tensor_by_name(name + ':0')])
output_values = session.run(output_nodes, feed_dict=input_dict)
for i in range(len(output_names)):
output_file_name = mace_out_file + "_" + \
format_name(output_names[i])
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i],
mace_out_value, output_values[i])
......@@ -123,7 +121,8 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
net = caffe.Net(model_file, caffe.TEST, weights=weight_file)
for i in range(len(input_names)):
input_value = load_data(input_file + "_" + format_name(input_names[i]))
input_value = load_data(
common.formatted_file_name(input_file, input_names[i]))
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1,
2))
input_blob_name = input_names[i]
......@@ -142,8 +141,8 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[
1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = mace_out_file + "_" + format_name(
output_names[i])
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i], mace_out_value,
value)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册