diff --git a/python/tools/binary_codegen.py b/python/tools/binary_codegen.py index a7cd756b74d3a2b67bae1974bddba85cd3deca99..aea06a0a1da060051cdf4b97ac93058e8241f3a5 100644 --- a/python/tools/binary_codegen.py +++ b/python/tools/binary_codegen.py @@ -8,41 +8,40 @@ import jinja2 import numpy as np # python mace/python/tools/binary_codegen.py \ -# --binary_file=${BIN_FILE} --output_path=${CODE_GEN_PATH} --variable_name=kTuningParamsData +# --binary_dirs=${BIN_FILE} \ +# --binary_file_name=mace_run.config \ +# --output_path=${CODE_GEN_PATH} --variable_name=kTuningParamsData FLAGS = None def generate_cpp_source(): data_map = {} - if not os.path.exists(FLAGS.binary_file): - env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0])) - return env.get_template('str2vec_maps.cc.tmpl').render( - maps=data_map, - data_type='unsigned int', - variable_name=FLAGS.variable_name - ) + for binary_dir in FLAGS.binary_dirs.split(","): + binary_path = os.path.join(binary_dir, FLAGS.binary_file_name) + if not os.path.exists(binary_path): + continue - with open(FLAGS.binary_file, "rb") as binary_file: - binary_array = np.fromfile(binary_file, dtype=np.uint8) + with open(binary_path, "rb") as f: + binary_array = np.fromfile(f, dtype=np.uint8) - idx = 0 - size, = struct.unpack("Q", binary_array[idx:idx+8]) - print size - idx += 8 - for _ in xrange(size): - key_size, = struct.unpack("i", binary_array[idx:idx+4]) - idx += 4 - key, = struct.unpack(str(key_size) + "s", binary_array[idx:idx+key_size]) - idx += key_size - params_size, = struct.unpack("i", binary_array[idx:idx+4]) - idx += 4 - data_map[key] = [] - count = params_size / 4 - params = struct.unpack(str(count) + "i", binary_array[idx:idx+params_size]) - for i in params: - data_map[key].append(i) - idx += params_size + idx = 0 + size, = struct.unpack("Q", binary_array[idx:idx+8]) + print size + idx += 8 + for _ in xrange(size): + key_size, = struct.unpack("i", binary_array[idx:idx+4]) + idx += 4 + key, = struct.unpack(str(key_size) + "s", binary_array[idx:idx+key_size]) + idx += key_size + params_size, = struct.unpack("i", binary_array[idx:idx+4]) + idx += 4 + data_map[key] = [] + count = params_size / 4 + params = struct.unpack(str(count) + "i", binary_array[idx:idx+params_size]) + for i in params: + data_map[key].append(i) + idx += params_size env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0])) return env.get_template('str2vec_maps.cc.tmpl').render( @@ -63,10 +62,15 @@ def parse_args(): """Parses command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument( - "--binary_file", + "--binary_dirs", type=str, - default="", + default="cl_bin0/,cl_bin1/", help="The binaries file path.") + parser.add_argument( + "--binary_file_name", + type=str, + default="mace_run.config", + help="The binary file name.") parser.add_argument( "--output_path", type=str, diff --git a/python/tools/opencl_codegen.py b/python/tools/opencl_codegen.py index a9d73c1223ef2ab7b5de1a9f641d7766f99f2819..d510932633e5413976973b8cec697634bf4a7a05 100644 --- a/python/tools/opencl_codegen.py +++ b/python/tools/opencl_codegen.py @@ -7,24 +7,28 @@ import numpy as np import jinja2 # python mace/python/tools/opencl_codegen.py \ -# --cl_binary_dir=${CL_BIN_DIR} --output_path=${CL_HEADER_PATH} +# --cl_binary_dirs=${CL_BIN_DIR} --output_path=${CL_HEADER_PATH} FLAGS = None def generate_cpp_source(): maps = {} - for file_name in os.listdir(FLAGS.cl_binary_dir): - file_path = os.path.join(FLAGS.cl_binary_dir, file_name) - if file_path[-4:] == ".bin": - # read binary - f = open(file_path, "rb") - binary_array = np.fromfile(f, dtype=np.uint8) - f.close() + cl_binary_dir_arr = FLAGS.cl_binary_dirs.split(",") + for cl_binary_dir in cl_binary_dir_arr: + if not os.path.exists(cl_binary_dir): + print("Input cl_binary_dir " + cl_binary_dir + " doesn't exist!") + for file_name in os.listdir(cl_binary_dir): + file_path = os.path.join(cl_binary_dir, file_name) + if file_path[-4:] == ".bin": + # read binary + f = open(file_path, "rb") + binary_array = np.fromfile(f, dtype=np.uint8) + f.close() - maps[file_name[:-4]] = [] - for ele in binary_array: - maps[file_name[:-4]].append(hex(ele)) + maps[file_name[:-4]] = [] + for ele in binary_array: + maps[file_name[:-4]].append(hex(ele)) env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0])) return env.get_template('str2vec_maps.cc.tmpl').render( @@ -35,8 +39,6 @@ def generate_cpp_source(): def main(unused_args): - if not os.path.exists(FLAGS.cl_binary_dir): - print("Input cl_binary_dir " + FLAGS.cl_binary_dir + " doesn't exist!") cpp_cl_binary_source = generate_cpp_source() if os.path.isfile(FLAGS.output_path): @@ -50,10 +52,10 @@ def parse_args(): """Parses command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument( - "--cl_binary_dir", + "--cl_binary_dirs", type=str, - default="./cl_bin/", - help="The cl binaries directory.") + default="cl_bin0/,cl_bin1/,cl_bin2/", + help="The cl binaries directories.") parser.add_argument( "--output_path", type=str,