提交 2ac3ca52 编写于 作者: Y yejianwu

refactor validate_model.sh with python

上级 2fb96625
......@@ -8,41 +8,40 @@ import jinja2
import numpy as np
# python mace/python/tools/binary_codegen.py \
# --binary_file=${BIN_FILE} --output_path=${CODE_GEN_PATH} --variable_name=kTuningParamsData
# --binary_dirs=${BIN_FILE} \
# --binary_file_name=mace_run.config \
# --output_path=${CODE_GEN_PATH} --variable_name=kTuningParamsData
FLAGS = None
def generate_cpp_source():
data_map = {}
if not os.path.exists(FLAGS.binary_file):
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
return env.get_template('str2vec_maps.cc.tmpl').render(
maps=data_map,
data_type='unsigned int',
variable_name=FLAGS.variable_name
)
for binary_dir in FLAGS.binary_dirs.split(","):
binary_path = os.path.join(binary_dir, FLAGS.binary_file_name)
if not os.path.exists(binary_path):
continue
with open(FLAGS.binary_file, "rb") as binary_file:
binary_array = np.fromfile(binary_file, dtype=np.uint8)
with open(binary_path, "rb") as f:
binary_array = np.fromfile(f, dtype=np.uint8)
idx = 0
size, = struct.unpack("Q", binary_array[idx:idx+8])
print size
idx += 8
for _ in xrange(size):
key_size, = struct.unpack("i", binary_array[idx:idx+4])
idx += 4
key, = struct.unpack(str(key_size) + "s", binary_array[idx:idx+key_size])
idx += key_size
params_size, = struct.unpack("i", binary_array[idx:idx+4])
idx += 4
data_map[key] = []
count = params_size / 4
params = struct.unpack(str(count) + "i", binary_array[idx:idx+params_size])
for i in params:
data_map[key].append(i)
idx += params_size
idx = 0
size, = struct.unpack("Q", binary_array[idx:idx+8])
print size
idx += 8
for _ in xrange(size):
key_size, = struct.unpack("i", binary_array[idx:idx+4])
idx += 4
key, = struct.unpack(str(key_size) + "s", binary_array[idx:idx+key_size])
idx += key_size
params_size, = struct.unpack("i", binary_array[idx:idx+4])
idx += 4
data_map[key] = []
count = params_size / 4
params = struct.unpack(str(count) + "i", binary_array[idx:idx+params_size])
for i in params:
data_map[key].append(i)
idx += params_size
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
return env.get_template('str2vec_maps.cc.tmpl').render(
......@@ -63,10 +62,15 @@ def parse_args():
"""Parses command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--binary_file",
"--binary_dirs",
type=str,
default="",
default="cl_bin0/,cl_bin1/",
help="The binaries file path.")
parser.add_argument(
"--binary_file_name",
type=str,
default="mace_run.config",
help="The binary file name.")
parser.add_argument(
"--output_path",
type=str,
......
......@@ -7,24 +7,28 @@ import numpy as np
import jinja2
# python mace/python/tools/opencl_codegen.py \
# --cl_binary_dir=${CL_BIN_DIR} --output_path=${CL_HEADER_PATH}
# --cl_binary_dirs=${CL_BIN_DIR} --output_path=${CL_HEADER_PATH}
FLAGS = None
def generate_cpp_source():
maps = {}
for file_name in os.listdir(FLAGS.cl_binary_dir):
file_path = os.path.join(FLAGS.cl_binary_dir, file_name)
if file_path[-4:] == ".bin":
# read binary
f = open(file_path, "rb")
binary_array = np.fromfile(f, dtype=np.uint8)
f.close()
cl_binary_dir_arr = FLAGS.cl_binary_dirs.split(",")
for cl_binary_dir in cl_binary_dir_arr:
if not os.path.exists(cl_binary_dir):
print("Input cl_binary_dir " + cl_binary_dir + " doesn't exist!")
for file_name in os.listdir(cl_binary_dir):
file_path = os.path.join(cl_binary_dir, file_name)
if file_path[-4:] == ".bin":
# read binary
f = open(file_path, "rb")
binary_array = np.fromfile(f, dtype=np.uint8)
f.close()
maps[file_name[:-4]] = []
for ele in binary_array:
maps[file_name[:-4]].append(hex(ele))
maps[file_name[:-4]] = []
for ele in binary_array:
maps[file_name[:-4]].append(hex(ele))
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
return env.get_template('str2vec_maps.cc.tmpl').render(
......@@ -35,8 +39,6 @@ def generate_cpp_source():
def main(unused_args):
if not os.path.exists(FLAGS.cl_binary_dir):
print("Input cl_binary_dir " + FLAGS.cl_binary_dir + " doesn't exist!")
cpp_cl_binary_source = generate_cpp_source()
if os.path.isfile(FLAGS.output_path):
......@@ -50,10 +52,10 @@ def parse_args():
"""Parses command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--cl_binary_dir",
"--cl_binary_dirs",
type=str,
default="./cl_bin/",
help="The cl binaries directory.")
default="cl_bin0/,cl_bin1/,cl_bin2/",
help="The cl binaries directories.")
parser.add_argument(
"--output_path",
type=str,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册