提交 c887cfa8 编写于 作者: Y yejianwu

update python run from sh.python(...) to python function

上级 d84c1a5a
...@@ -29,10 +29,10 @@ import numpy as np ...@@ -29,10 +29,10 @@ import numpy as np
FLAGS = None FLAGS = None
def generate_cpp_source(): def generate_cpp_source(binary_dirs, binary_file_name, variable_name):
data_map = {} data_map = {}
for binary_dir in FLAGS.binary_dirs.split(","): for binary_dir in binary_dirs.split(","):
binary_path = os.path.join(binary_dir, FLAGS.binary_file_name) binary_path = os.path.join(binary_dir, binary_file_name)
if not os.path.exists(binary_path): if not os.path.exists(binary_path):
continue continue
...@@ -63,14 +63,18 @@ def generate_cpp_source(): ...@@ -63,14 +63,18 @@ def generate_cpp_source():
return env.get_template('str2vec_maps.cc.jinja2').render( return env.get_template('str2vec_maps.cc.jinja2').render(
maps=data_map, maps=data_map,
data_type='unsigned int', data_type='unsigned int',
variable_name=FLAGS.variable_name) variable_name=variable_name)
def main(unused_args): def tuning_param_codegen(binary_dirs,
cpp_binary_source = generate_cpp_source() binary_file_name,
if os.path.isfile(FLAGS.output_path): output_path,
os.remove(FLAGS.output_path) variable_name):
w_file = open(FLAGS.output_path, "w") cpp_binary_source = generate_cpp_source(
binary_dirs, binary_file_name, variable_name)
if os.path.isfile(output_path):
os.remove(output_path)
w_file = open(output_path, "w")
w_file.write(cpp_binary_source) w_file.write(cpp_binary_source)
w_file.close() w_file.close()
...@@ -101,4 +105,7 @@ def parse_args(): ...@@ -101,4 +105,7 @@ def parse_args():
if __name__ == '__main__': if __name__ == '__main__':
FLAGS, unparsed = parse_args() FLAGS, unparsed = parse_args()
main(unused_args=[sys.argv[0]] + unparsed) tuning_param_codegen(FLAGS.binary_dirs,
FLAGS.binary_file_name,
FLAGS.output_path,
FLAGS.variable_name)
...@@ -36,20 +36,20 @@ def encrypt_code(code_str): ...@@ -36,20 +36,20 @@ def encrypt_code(code_str):
return encrypted_arr return encrypted_arr
def main(unused_args): def encrypt_opencl_codegen(cl_kernel_dir, output_path):
if not os.path.exists(FLAGS.cl_kernel_dir): if not os.path.exists(cl_kernel_dir):
print("Input cl_kernel_dir " + FLAGS.cl_kernel_dir + " doesn't exist!") print("Input cl_kernel_dir " + cl_kernel_dir + " doesn't exist!")
header_code = "" header_code = ""
for file_name in os.listdir(FLAGS.cl_kernel_dir): for file_name in os.listdir(cl_kernel_dir):
file_path = os.path.join(FLAGS.cl_kernel_dir, file_name) file_path = os.path.join(cl_kernel_dir, file_name)
if file_path[-2:] == ".h": if file_path[-2:] == ".h":
f = open(file_path, "r") f = open(file_path, "r")
header_code += f.read() header_code += f.read()
encrypted_code_maps = {} encrypted_code_maps = {}
for file_name in os.listdir(FLAGS.cl_kernel_dir): for file_name in os.listdir(cl_kernel_dir):
file_path = os.path.join(FLAGS.cl_kernel_dir, file_name) file_path = os.path.join(cl_kernel_dir, file_name)
if file_path[-3:] == ".cl": if file_path[-3:] == ".cl":
f = open(file_path, "r") f = open(file_path, "r")
code_str = "" code_str = ""
...@@ -68,9 +68,9 @@ def main(unused_args): ...@@ -68,9 +68,9 @@ def main(unused_args):
data_type='unsigned char', data_type='unsigned char',
variable_name='kEncryptedProgramMap') variable_name='kEncryptedProgramMap')
if os.path.isfile(FLAGS.output_path): if os.path.isfile(output_path):
os.remove(FLAGS.output_path) os.remove(output_path)
w_file = open(FLAGS.output_path, "w") w_file = open(output_path, "w")
w_file.write(cpp_cl_encrypted_kernel) w_file.write(cpp_cl_encrypted_kernel)
w_file.close() w_file.close()
...@@ -95,4 +95,4 @@ def parse_args(): ...@@ -95,4 +95,4 @@ def parse_args():
if __name__ == '__main__': if __name__ == '__main__':
FLAGS, unparsed = parse_args() FLAGS, unparsed = parse_args()
main(unused_args=[sys.argv[0]] + unparsed) encrypt_opencl_codegen(FLAGS.cl_kernel_dir, FLAGS.output_path)
...@@ -27,12 +27,14 @@ import jinja2 ...@@ -27,12 +27,14 @@ import jinja2
FLAGS = None FLAGS = None
def generate_cpp_source(): def generate_cpp_source(cl_binary_dirs,
built_kernel_file_name,
platform_info_file_name):
maps = {} maps = {}
platform_info = '' platform_info = ''
binary_dirs = FLAGS.cl_binary_dirs.strip().split(",") binary_dirs = cl_binary_dirs.strip().split(",")
for binary_dir in binary_dirs: for binary_dir in binary_dirs:
binary_path = os.path.join(binary_dir, FLAGS.built_kernel_file_name) binary_path = os.path.join(binary_dir, built_kernel_file_name)
if not os.path.exists(binary_path): if not os.path.exists(binary_path):
continue continue
...@@ -59,7 +61,7 @@ def generate_cpp_source(): ...@@ -59,7 +61,7 @@ def generate_cpp_source():
maps[key].append(hex(ele)) maps[key].append(hex(ele))
cl_platform_info_path = os.path.join(binary_dir, cl_platform_info_path = os.path.join(binary_dir,
FLAGS.platform_info_file_name) platform_info_file_name)
with open(cl_platform_info_path, 'r') as f: with open(cl_platform_info_path, 'r') as f:
curr_platform_info = f.read() curr_platform_info = f.read()
if platform_info != "": if platform_info != "":
...@@ -75,12 +77,16 @@ def generate_cpp_source(): ...@@ -75,12 +77,16 @@ def generate_cpp_source():
) )
def main(unused_args): def opencl_codegen(output_path,
cl_binary_dirs="",
cpp_cl_binary_source = generate_cpp_source() built_kernel_file_name="",
if os.path.isfile(FLAGS.output_path): platform_info_file_name=""):
os.remove(FLAGS.output_path) cpp_cl_binary_source = generate_cpp_source(cl_binary_dirs,
w_file = open(FLAGS.output_path, "w") built_kernel_file_name,
platform_info_file_name)
if os.path.isfile(output_path):
os.remove(output_path)
w_file = open(output_path, "w")
w_file.write(cpp_cl_binary_source) w_file.write(cpp_cl_binary_source)
w_file.close() w_file.close()
...@@ -113,4 +119,7 @@ def parse_args(): ...@@ -113,4 +119,7 @@ def parse_args():
if __name__ == '__main__': if __name__ == '__main__':
FLAGS, unparsed = parse_args() FLAGS, unparsed = parse_args()
main(unused_args=[sys.argv[0]] + unparsed) opencl_codegen(FLAGS.output_path,
FLAGS.cl_binary_dirs,
FLAGS.built_kernel_file_name,
FLAGS.platform_info_file_name)
...@@ -26,22 +26,22 @@ import re ...@@ -26,22 +26,22 @@ import re
# #
def generate_data(name, shape): def generate_data(name, shape, input_file):
np.random.seed() np.random.seed()
data = np.random.random(shape) * 2 - 1 data = np.random.random(shape) * 2 - 1
input_file_name = FLAGS.input_file + "_" + re.sub('[^0-9a-zA-Z]+', '_', input_file_name = input_file + "_" + re.sub('[^0-9a-zA-Z]+', '_',
name) name)
print 'Generate input file: ', input_file_name print 'Generate input file: ', input_file_name
data.astype(np.float32).tofile(input_file_name) data.astype(np.float32).tofile(input_file_name)
def main(unused_args): def generate_input_data(input_file, input_node, input_shape):
input_names = [name for name in FLAGS.input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shapes = [shape for shape in FLAGS.input_shape.split(':')] input_shapes = [shape for shape in input_shape.split(':')]
assert len(input_names) == len(input_shapes) assert len(input_names) == len(input_shapes)
for i in range(len(input_names)): for i in range(len(input_names)):
shape = [int(x) for x in input_shapes[i].split(',')] shape = [int(x) for x in input_shapes[i].split(',')]
generate_data(input_names[i], shape) generate_data(input_names[i], shape, input_file)
print "Generate input file done." print "Generate input file done."
...@@ -61,4 +61,4 @@ def parse_args(): ...@@ -61,4 +61,4 @@ def parse_args():
if __name__ == '__main__': if __name__ == '__main__':
FLAGS, unparsed = parse_args() FLAGS, unparsed = parse_args()
main(unused_args=[sys.argv[0]] + unparsed) generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape)
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Must run at root dir of libmace project.
# python tools/mace_tools.py \ # python tools/mace_tools.py \
# --config=tools/example.yaml \ # --config=tools/example.yaml \
# --round=100 \ # --round=100 \
...@@ -89,14 +88,21 @@ def get_hexagon_mode(configs): ...@@ -89,14 +88,21 @@ def get_hexagon_mode(configs):
return False return False
def generate_code(target_soc, target_abi, model_output_dirs, pull_or_not): def gen_opencl_and_tuning_code(target_soc,
target_abi,
model_output_dirs,
pull_or_not):
if pull_or_not: if pull_or_not:
sh_commands.pull_binaries( sh_commands.pull_binaries(
target_soc, target_abi, model_output_dirs) target_soc, target_abi, model_output_dirs)
sh_commands.gen_opencl_binary_code(
target_soc, target_abi, model_output_dirs) codegen_path = "mace/codegen"
# generate opencl binary code
sh_commands.gen_opencl_binary_code(target_soc, model_output_dirs)
sh_commands.gen_tuning_param_code( sh_commands.gen_tuning_param_code(
target_soc, target_abi, model_output_dirs) target_soc, model_output_dirs)
def model_benchmark_stdout_processor(stdout, def model_benchmark_stdout_processor(stdout,
...@@ -170,11 +176,11 @@ def tuning_run(runtime, ...@@ -170,11 +176,11 @@ def tuning_run(runtime,
phone_data_dir, phone_data_dir,
option_args) option_args)
model_benchmark_stdout_processor(stdout, model_benchmark_stdout_processor(stdout,
target_soc, target_soc,
target_abi, target_abi,
runtime, runtime,
running_round, running_round,
tuning) tuning)
def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi, def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi,
...@@ -182,7 +188,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi, ...@@ -182,7 +188,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi,
input_nodes, output_nodes, input_shapes, output_shapes, input_nodes, output_nodes, input_shapes, output_shapes,
model_name, device_type, running_round, restart_round, model_name, device_type, running_round, restart_round,
tuning, limit_opencl_kernel_time, phone_data_dir): tuning, limit_opencl_kernel_time, phone_data_dir):
generate_code(target_soc, target_abi, [], False) gen_opencl_and_tuning_code(target_soc, target_abi, [], False)
production_or_not = False production_or_not = False
mace_run_target = "//mace/tools/validation:mace_run" mace_run_target = "//mace/tools/validation:mace_run"
sh_commands.bazel_build( sh_commands.bazel_build(
...@@ -207,7 +213,8 @@ def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi, ...@@ -207,7 +213,8 @@ def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi,
phone_data_dir=phone_data_dir, tuning=tuning, phone_data_dir=phone_data_dir, tuning=tuning,
limit_opencl_kernel_time=limit_opencl_kernel_time) limit_opencl_kernel_time=limit_opencl_kernel_time)
generate_code(target_soc, target_abi, [model_output_dir], True) gen_opencl_and_tuning_code(
target_soc, target_abi, [model_output_dir], True)
production_or_not = True production_or_not = True
sh_commands.bazel_build( sh_commands.bazel_build(
mace_run_target, mace_run_target,
...@@ -226,7 +233,8 @@ def merge_libs_and_tuning_results(target_soc, ...@@ -226,7 +233,8 @@ def merge_libs_and_tuning_results(target_soc,
model_output_dirs, model_output_dirs,
hexagon_mode, hexagon_mode,
embed_model_data): embed_model_data):
generate_code(target_soc, target_abi, model_output_dirs, False) gen_opencl_and_tuning_code(
target_soc, target_abi, model_output_dirs, False)
sh_commands.build_production_code(target_abi) sh_commands.build_production_code(target_abi)
sh_commands.merge_libs(target_soc, sh_commands.merge_libs(target_soc,
......
...@@ -19,9 +19,23 @@ import os ...@@ -19,9 +19,23 @@ import os
import re import re
import sh import sh
import subprocess import subprocess
import sys
import time import time
sys.path.insert(0, "mace/python/tools")
try:
from encrypt_opencl_codegen import encrypt_opencl_codegen
from opencl_codegen import opencl_codegen
from binary_codegen import tuning_param_codegen
from generate_data import generate_input_data
from validate import validate
except Exception:
print("Error: import error.")
print("Does the script run at the root dir of mace project?")
exit(1)
################################ ################################
# common # common
################################ ################################
...@@ -283,19 +297,16 @@ def bazel_target_to_bin(target): ...@@ -283,19 +297,16 @@ def bazel_target_to_bin(target):
################################ ################################
# mace commands # mace commands
################################ ################################
# TODO this should be refactored
def gen_encrypted_opencl_source(codegen_path="mace/codegen"): def gen_encrypted_opencl_source(codegen_path="mace/codegen"):
sh.mkdir("-p", "%s/opencl" % codegen_path) sh.mkdir("-p", "%s/opencl" % codegen_path)
sh.python( encrypt_opencl_codegen("./mace/kernels/opencl/cl/",
"mace/python/tools/encrypt_opencl_codegen.py", "mace/codegen/opencl/opencl_encrypt_program.cc")
"--cl_kernel_dir=./mace/kernels/opencl/cl/",
"--output_path=%s/opencl/opencl_encrypt_program.cc" % codegen_path)
def pull_binaries(target_soc, abi, model_output_dirs): def pull_binaries(target_soc, abi, model_output_dirs):
serialno = adb_devices([target_soc]).pop() serialno = adb_devices([target_soc]).pop()
compiled_opencl_dir = "/data/local/tmp/mace_run/cl_program/" compiled_opencl_dir = "/data/local/tmp/mace_run/cl_program/"
mace_run_config_file = "mace_run.config" mace_run_param_file = "mace_run.config"
cl_bin_dirs = [] cl_bin_dirs = []
for d in model_output_dirs: for d in model_output_dirs:
...@@ -308,46 +319,33 @@ def pull_binaries(target_soc, abi, model_output_dirs): ...@@ -308,46 +319,33 @@ def pull_binaries(target_soc, abi, model_output_dirs):
sh.mkdir("-p", cl_bin_dir) sh.mkdir("-p", cl_bin_dir)
if abi != "host": if abi != "host":
adb_pull(compiled_opencl_dir, cl_bin_dir, serialno) adb_pull(compiled_opencl_dir, cl_bin_dir, serialno)
adb_pull("/data/local/tmp/mace_run/%s" % mace_run_config_file, adb_pull("/data/local/tmp/mace_run/%s" % mace_run_param_file,
cl_bin_dir, serialno) cl_bin_dir, serialno)
def gen_opencl_binary_code(target_soc, def gen_opencl_binary_code(target_soc,
abi,
model_output_dirs, model_output_dirs,
codegen_path="mace/codegen"): codegen_path="mace/codegen"):
cl_built_kernel_file_name = "mace_cl_compiled_program.bin" cl_built_kernel_file_name = "mace_cl_compiled_program.bin"
cl_platform_info_file_name = "mace_cl_platform_info.txt" cl_platform_info_file_name = "mace_cl_platform_info.txt"
opencl_codegen_file = "%s/opencl/opencl_compiled_program.cc" % codegen_path
serialno = adb_devices([target_soc]).pop() serialno = adb_devices([target_soc]).pop()
compiled_opencl_dir = "/data/local/tmp/mace_run/cl_program/"
cl_bin_dirs = [] cl_bin_dirs = []
for d in model_output_dirs: for d in model_output_dirs:
cl_bin_dirs.append(os.path.join(d, "opencl_bin")) cl_bin_dirs.append(os.path.join(d, "opencl_bin"))
cl_bin_dirs_str = ",".join(cl_bin_dirs) cl_bin_dirs_str = ",".join(cl_bin_dirs)
if not cl_bin_dirs: opencl_codegen(opencl_codegen_file,
sh.python( cl_bin_dirs_str,
"mace/python/tools/opencl_codegen.py", cl_built_kernel_file_name,
"--built_kernel_file_name=%s" % cl_built_kernel_file_name, cl_platform_info_file_name)
"--platform_info_file_name=%s" % cl_platform_info_file_name,
"--output_path=%s/opencl/opencl_compiled_program.cc" %
codegen_path)
else:
sh.python(
"mace/python/tools/opencl_codegen.py",
"--built_kernel_file_name=%s" % cl_built_kernel_file_name,
"--platform_info_file_name=%s" % cl_platform_info_file_name,
"--cl_binary_dirs=%s" % cl_bin_dirs_str,
"--output_path=%s/opencl/opencl_compiled_program.cc" %
codegen_path)
def gen_tuning_param_code(target_soc, def gen_tuning_param_code(target_soc,
abi,
model_output_dirs, model_output_dirs,
codegen_path="mace/codegen"): codegen_path="mace/codegen"):
mace_run_config_file = "mace_run.config" mace_run_param_file = "mace_run.config"
cl_bin_dirs = [] cl_bin_dirs = []
for d in model_output_dirs: for d in model_output_dirs:
cl_bin_dirs.append(os.path.join(d, "opencl_bin")) cl_bin_dirs.append(os.path.join(d, "opencl_bin"))
...@@ -357,11 +355,11 @@ def gen_tuning_param_code(target_soc, ...@@ -357,11 +355,11 @@ def gen_tuning_param_code(target_soc,
if not os.path.exists(tuning_codegen_dir): if not os.path.exists(tuning_codegen_dir):
sh.mkdir("-p", tuning_codegen_dir) sh.mkdir("-p", tuning_codegen_dir)
sh.python( tuning_param_variable_name = "kTuningParamsData"
"mace/python/tools/binary_codegen.py", tuning_param_codegen(cl_bin_dirs_str,
"--binary_dirs=%s" % cl_bin_dirs_str, mace_run_param_file,
"--binary_file_name=%s" % mace_run_config_file, "%s/tuning_params.cc" % tuning_codegen_dir,
"--output_path=%s/tuning_params.cc" % tuning_codegen_dir) tuning_param_variable_name)
def gen_mace_version(codegen_path="mace/codegen"): def gen_mace_version(codegen_path="mace/codegen"):
...@@ -371,10 +369,9 @@ def gen_mace_version(codegen_path="mace/codegen"): ...@@ -371,10 +369,9 @@ def gen_mace_version(codegen_path="mace/codegen"):
def gen_compiled_opencl_source(codegen_path="mace/codegen"): def gen_compiled_opencl_source(codegen_path="mace/codegen"):
opencl_codegen_file = "%s/opencl/opencl_compiled_program.cc" % codegen_path
sh.mkdir("-p", "%s/opencl" % codegen_path) sh.mkdir("-p", "%s/opencl" % codegen_path)
sh.python( opencl_codegen(opencl_codegen_file)
"mace/python/tools/opencl_codegen.py",
"--output_path=%s/opencl/opencl_compiled_program.cc" % codegen_path)
def gen_model_code(model_codegen_dir, def gen_model_code(model_codegen_dir,
...@@ -430,11 +427,9 @@ def gen_random_input(model_output_dir, ...@@ -430,11 +427,9 @@ def gen_random_input(model_output_dir,
sh.rm(formatted_name) sh.rm(formatted_name)
input_nodes_str = ",".join(input_nodes) input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes) input_shapes_str = ":".join(input_shapes)
sh.python("-u", generate_input_data("%s/%s" % (model_output_dir, input_file_name),
"tools/generate_data.py", input_nodes_str,
"--input_node=%s" % input_nodes_str, input_shapes_str)
"--input_file=%s" % model_output_dir + "/" + input_file_name,
"--input_shape=%s" % input_shapes_str)
input_file_list = [] input_file_list = []
if isinstance(input_files, list): if isinstance(input_files, list):
...@@ -605,8 +600,6 @@ def validate_model(target_soc, ...@@ -605,8 +600,6 @@ def validate_model(target_soc,
output_file_name="model_out"): output_file_name="model_out"):
print("* Validate with %s" % platform) print("* Validate with %s" % platform)
serialno = adb_devices([target_soc]).pop() serialno = adb_devices([target_soc]).pop()
stdout_buff = []
process_output = make_output_processor(stdout_buff)
if platform == "tensorflow": if platform == "tensorflow":
if abi != "host": if abi != "host":
...@@ -617,23 +610,11 @@ def validate_model(target_soc, ...@@ -617,23 +610,11 @@ def validate_model(target_soc,
sh.rm(formatted_name) sh.rm(formatted_name)
adb_pull("%s/%s" % (phone_data_dir, formatted_name), adb_pull("%s/%s" % (phone_data_dir, formatted_name),
model_output_dir, serialno) model_output_dir, serialno)
p = sh.python( validate(platform, model_file_path, "",
"-u", "%s/%s" % (model_output_dir, input_file_name),
"tools/validate.py", "%s/%s" % (model_output_dir, output_file_name), runtime,
"--platform=%s" % platform, ":".join(input_shapes), ":".join(output_shapes),
"--model_file=%s" % model_file_path, ",".join(input_nodes), ",".join(output_nodes))
"--input_file=%s" % model_output_dir + "/" + input_file_name,
"--mace_out_file=%s" % model_output_dir + "/" +
output_file_name,
"--mace_runtime=%s" % runtime,
"--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes),
_out=process_output,
_bg=True,
_err_to_out=True)
p.wait()
elif platform == "caffe": elif platform == "caffe":
image_name = "mace-caffe:latest" image_name = "mace-caffe:latest"
container_name = "mace_caffe_validator" container_name = "mace_caffe_validator"
...@@ -715,7 +696,6 @@ def validate_model(target_soc, ...@@ -715,7 +696,6 @@ def validate_model(target_soc,
p.wait() p.wait()
print("Validation done!\n") print("Validation done!\n")
return "".join(stdout_buff)
def build_production_code(abi): def build_production_code(abi):
......
...@@ -46,18 +46,19 @@ def format_output_name(name): ...@@ -46,18 +46,19 @@ def format_output_name(name):
return re.sub('[^0-9a-zA-Z]+', '_', name) return re.sub('[^0-9a-zA-Z]+', '_', name)
def compare_output(output_name, mace_out_value, out_value): def compare_output(platform, mace_runtime, output_name, mace_out_value,
out_value):
if mace_out_value.size != 0: if mace_out_value.size != 0:
out_value = out_value.reshape(-1) out_value = out_value.reshape(-1)
mace_out_value = mace_out_value.reshape(-1) mace_out_value = mace_out_value.reshape(-1)
assert len(out_value) == len(mace_out_value) assert len(out_value) == len(mace_out_value)
similarity = (1 - spatial.distance.cosine(out_value, mace_out_value)) similarity = (1 - spatial.distance.cosine(out_value, mace_out_value))
print output_name, 'MACE VS', FLAGS.platform.upper( print output_name, 'MACE VS', platform.upper(
), 'similarity: ', similarity ), 'similarity: ', similarity
if (FLAGS.mace_runtime == "cpu" and similarity > 0.999) or \ if (mace_runtime == "cpu" and similarity > 0.999) or \
(FLAGS.mace_runtime == "neon" and similarity > 0.999) or \ (mace_runtime == "neon" and similarity > 0.999) or \
(FLAGS.mace_runtime == "gpu" and similarity > 0.995) or \ (mace_runtime == "gpu" and similarity > 0.995) or \
(FLAGS.mace_runtime == "dsp" and similarity > 0.930): (mace_runtime == "dsp" and similarity > 0.930):
print '===================Similarity Test Passed==================' print '===================Similarity Test Passed=================='
else: else:
print '===================Similarity Test Failed==================' print '===================Similarity Test Failed=================='
...@@ -67,14 +68,15 @@ def compare_output(output_name, mace_out_value, out_value): ...@@ -67,14 +68,15 @@ def compare_output(output_name, mace_out_value, out_value):
sys.exit(-1) sys.exit(-1)
def validate_tf_model(input_names, input_shapes, output_names): def validate_tf_model(platform, mace_runtime, model_file, input_file,
mace_out_file, input_names, input_shapes, output_names):
import tensorflow as tf import tensorflow as tf
if not os.path.isfile(FLAGS.model_file): if not os.path.isfile(model_file):
print("Input graph file '" + FLAGS.model_file + "' does not exist!") print("Input graph file '" + model_file + "' does not exist!")
sys.exit(-1) sys.exit(-1)
input_graph_def = tf.GraphDef() input_graph_def = tf.GraphDef()
with open(FLAGS.model_file, "rb") as f: with open(model_file, "rb") as f:
data = f.read() data = f.read()
input_graph_def.ParseFromString(data) input_graph_def.ParseFromString(data)
tf.import_graph_def(input_graph_def, name="") tf.import_graph_def(input_graph_def, name="")
...@@ -85,7 +87,7 @@ def validate_tf_model(input_names, input_shapes, output_names): ...@@ -85,7 +87,7 @@ def validate_tf_model(input_names, input_shapes, output_names):
input_dict = {} input_dict = {}
for i in range(len(input_names)): for i in range(len(input_names)):
input_value = load_data( input_value = load_data(
FLAGS.input_file + "_" + input_names[i]) input_file + "_" + input_names[i])
input_value = input_value.reshape(input_shapes[i]) input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name( input_node = graph.get_tensor_by_name(
input_names[i] + ':0') input_names[i] + ':0')
...@@ -97,30 +99,31 @@ def validate_tf_model(input_names, input_shapes, output_names): ...@@ -97,30 +99,31 @@ def validate_tf_model(input_names, input_shapes, output_names):
[graph.get_tensor_by_name(name + ':0')]) [graph.get_tensor_by_name(name + ':0')])
output_values = session.run(output_nodes, feed_dict=input_dict) output_values = session.run(output_nodes, feed_dict=input_dict)
for i in range(len(output_names)): for i in range(len(output_names)):
output_file_name = FLAGS.mace_out_file + "_" + \ output_file_name = mace_out_file + "_" + \
format_output_name(output_names[i]) format_output_name(output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
compare_output(output_names[i], mace_out_value, compare_output(platform, mace_runtime, output_names[i],
output_values[i]) mace_out_value, output_values[i])
def validate_caffe_model(input_names, input_shapes, output_names, def validate_caffe_model(platform, mace_runtime, model_file, input_file,
output_shapes): mace_out_file, weight_file, input_names, input_shapes,
output_names, output_shapes):
os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints
import caffe import caffe
if not os.path.isfile(FLAGS.model_file): if not os.path.isfile(model_file):
print("Input graph file '" + FLAGS.model_file + "' does not exist!") print("Input graph file '" + model_file + "' does not exist!")
sys.exit(-1) sys.exit(-1)
if not os.path.isfile(FLAGS.weight_file): if not os.path.isfile(weight_file):
print("Input weight file '" + FLAGS.weight_file + "' does not exist!") print("Input weight file '" + weight_file + "' does not exist!")
sys.exit(-1) sys.exit(-1)
caffe.set_mode_cpu() caffe.set_mode_cpu()
net = caffe.Net(FLAGS.model_file, caffe.TEST, weights=FLAGS.weight_file) net = caffe.Net(model_file, caffe.TEST, weights=weight_file)
for i in range(len(input_names)): for i in range(len(input_names)):
input_value = load_data(FLAGS.input_file + "_" + input_names[i]) input_value = load_data(input_file + "_" + input_names[i])
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1,
2)) 2))
input_blob_name = input_names[i] input_blob_name = input_names[i]
...@@ -139,28 +142,33 @@ def validate_caffe_model(input_names, input_shapes, output_names, ...@@ -139,28 +142,33 @@ def validate_caffe_model(input_names, input_shapes, output_names,
out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[ out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[
1], out_shape[2] 1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1)) value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = FLAGS.mace_out_file + "_" + format_output_name( output_file_name = mace_out_file + "_" + format_output_name(
output_names[i]) output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
compare_output(output_names[i], mace_out_value, value) compare_output(platform, mace_runtime, output_names[i], mace_out_value,
value)
def main(unused_args): def validate(platform, model_file, weight_file, input_file, mace_out_file,
input_names = [name for name in FLAGS.input_node.split(',')] mace_runtime, input_shape, output_shape, input_node, output_node):
input_shape_strs = [shape for shape in FLAGS.input_shape.split(':')] input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')] input_shapes = [[int(x) for x in shape.split(',')]
for shape in input_shape_strs] for shape in input_shape_strs]
output_names = [name for name in FLAGS.output_node.split(',')] output_names = [name for name in output_node.split(',')]
assert len(input_names) == len(input_shapes) assert len(input_names) == len(input_shapes)
if FLAGS.platform == 'tensorflow': if platform == 'tensorflow':
validate_tf_model(input_names, input_shapes, output_names) validate_tf_model(platform, mace_runtime, model_file, input_file,
elif FLAGS.platform == 'caffe': mace_out_file, input_names, input_shapes,
output_shape_strs = [shape for shape in FLAGS.output_shape.split(':')] output_names)
elif platform == 'caffe':
output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')] output_shapes = [[int(x) for x in shape.split(',')]
for shape in output_shape_strs] for shape in output_shape_strs]
validate_caffe_model(input_names, input_shapes, output_names, validate_caffe_model(platform, mace_runtime, model_file, input_file,
output_shapes) mace_out_file, weight_file, input_names,
input_shapes, output_names, output_shapes)
def parse_args(): def parse_args():
...@@ -202,4 +210,13 @@ def parse_args(): ...@@ -202,4 +210,13 @@ def parse_args():
if __name__ == '__main__': if __name__ == '__main__':
FLAGS, unparsed = parse_args() FLAGS, unparsed = parse_args()
main(unused_args=[sys.argv[0]] + unparsed) validate(FLAGS.platform,
FLAGS.model_file,
FLAGS.weight_file,
FLAGS.input_file,
FLAGS.mace_out_file,
FLAGS.mace_runtime,
FLAGS.input_shape,
FLAGS.output_shape,
FLAGS.input_node,
FLAGS.output_node)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册