diff --git a/mace/python/tools/binary_codegen.py b/mace/python/tools/binary_codegen.py index c36941f0ea526631fb916797efd2f2e1ceb95df1..3c78e889b1c5f270fdf8f9853b2512b2e1562e5a 100644 --- a/mace/python/tools/binary_codegen.py +++ b/mace/python/tools/binary_codegen.py @@ -29,10 +29,10 @@ import numpy as np FLAGS = None -def generate_cpp_source(): +def generate_cpp_source(binary_dirs, binary_file_name, variable_name): data_map = {} - for binary_dir in FLAGS.binary_dirs.split(","): - binary_path = os.path.join(binary_dir, FLAGS.binary_file_name) + for binary_dir in binary_dirs.split(","): + binary_path = os.path.join(binary_dir, binary_file_name) if not os.path.exists(binary_path): continue @@ -63,14 +63,18 @@ def generate_cpp_source(): return env.get_template('str2vec_maps.cc.jinja2').render( maps=data_map, data_type='unsigned int', - variable_name=FLAGS.variable_name) - - -def main(unused_args): - cpp_binary_source = generate_cpp_source() - if os.path.isfile(FLAGS.output_path): - os.remove(FLAGS.output_path) - w_file = open(FLAGS.output_path, "w") + variable_name=variable_name) + + +def tuning_param_codegen(binary_dirs, + binary_file_name, + output_path, + variable_name): + cpp_binary_source = generate_cpp_source( + binary_dirs, binary_file_name, variable_name) + if os.path.isfile(output_path): + os.remove(output_path) + w_file = open(output_path, "w") w_file.write(cpp_binary_source) w_file.close() @@ -101,4 +105,7 @@ def parse_args(): if __name__ == '__main__': FLAGS, unparsed = parse_args() - main(unused_args=[sys.argv[0]] + unparsed) + tuning_param_codegen(FLAGS.binary_dirs, + FLAGS.binary_file_name, + FLAGS.output_path, + FLAGS.variable_name) diff --git a/mace/python/tools/encrypt_opencl_codegen.py b/mace/python/tools/encrypt_opencl_codegen.py index 6edf48f377b94480d3ef59482b65a12584be811c..6292b8349e60154824807d070fe0f4802092db07 100644 --- a/mace/python/tools/encrypt_opencl_codegen.py +++ b/mace/python/tools/encrypt_opencl_codegen.py @@ -36,20 +36,20 @@ def encrypt_code(code_str): return encrypted_arr -def main(unused_args): - if not os.path.exists(FLAGS.cl_kernel_dir): - print("Input cl_kernel_dir " + FLAGS.cl_kernel_dir + " doesn't exist!") +def encrypt_opencl_codegen(cl_kernel_dir, output_path): + if not os.path.exists(cl_kernel_dir): + print("Input cl_kernel_dir " + cl_kernel_dir + " doesn't exist!") header_code = "" - for file_name in os.listdir(FLAGS.cl_kernel_dir): - file_path = os.path.join(FLAGS.cl_kernel_dir, file_name) + for file_name in os.listdir(cl_kernel_dir): + file_path = os.path.join(cl_kernel_dir, file_name) if file_path[-2:] == ".h": f = open(file_path, "r") header_code += f.read() encrypted_code_maps = {} - for file_name in os.listdir(FLAGS.cl_kernel_dir): - file_path = os.path.join(FLAGS.cl_kernel_dir, file_name) + for file_name in os.listdir(cl_kernel_dir): + file_path = os.path.join(cl_kernel_dir, file_name) if file_path[-3:] == ".cl": f = open(file_path, "r") code_str = "" @@ -68,9 +68,9 @@ def main(unused_args): data_type='unsigned char', variable_name='kEncryptedProgramMap') - if os.path.isfile(FLAGS.output_path): - os.remove(FLAGS.output_path) - w_file = open(FLAGS.output_path, "w") + if os.path.isfile(output_path): + os.remove(output_path) + w_file = open(output_path, "w") w_file.write(cpp_cl_encrypted_kernel) w_file.close() @@ -95,4 +95,4 @@ def parse_args(): if __name__ == '__main__': FLAGS, unparsed = parse_args() - main(unused_args=[sys.argv[0]] + unparsed) + encrypt_opencl_codegen(FLAGS.cl_kernel_dir, FLAGS.output_path) diff --git a/mace/python/tools/opencl_codegen.py b/mace/python/tools/opencl_codegen.py index 3d9307680dcbba7322bdb9133c69641917438745..cfb12f744b0b25d84099b8b82a4eeab75f9ef2f0 100644 --- a/mace/python/tools/opencl_codegen.py +++ b/mace/python/tools/opencl_codegen.py @@ -27,12 +27,14 @@ import jinja2 FLAGS = None -def generate_cpp_source(): +def generate_cpp_source(cl_binary_dirs, + built_kernel_file_name, + platform_info_file_name): maps = {} platform_info = '' - binary_dirs = FLAGS.cl_binary_dirs.strip().split(",") + binary_dirs = cl_binary_dirs.strip().split(",") for binary_dir in binary_dirs: - binary_path = os.path.join(binary_dir, FLAGS.built_kernel_file_name) + binary_path = os.path.join(binary_dir, built_kernel_file_name) if not os.path.exists(binary_path): continue @@ -59,7 +61,7 @@ def generate_cpp_source(): maps[key].append(hex(ele)) cl_platform_info_path = os.path.join(binary_dir, - FLAGS.platform_info_file_name) + platform_info_file_name) with open(cl_platform_info_path, 'r') as f: curr_platform_info = f.read() if platform_info != "": @@ -75,12 +77,16 @@ def generate_cpp_source(): ) -def main(unused_args): - - cpp_cl_binary_source = generate_cpp_source() - if os.path.isfile(FLAGS.output_path): - os.remove(FLAGS.output_path) - w_file = open(FLAGS.output_path, "w") +def opencl_codegen(output_path, + cl_binary_dirs="", + built_kernel_file_name="", + platform_info_file_name=""): + cpp_cl_binary_source = generate_cpp_source(cl_binary_dirs, + built_kernel_file_name, + platform_info_file_name) + if os.path.isfile(output_path): + os.remove(output_path) + w_file = open(output_path, "w") w_file.write(cpp_cl_binary_source) w_file.close() @@ -113,4 +119,7 @@ def parse_args(): if __name__ == '__main__': FLAGS, unparsed = parse_args() - main(unused_args=[sys.argv[0]] + unparsed) + opencl_codegen(FLAGS.output_path, + FLAGS.cl_binary_dirs, + FLAGS.built_kernel_file_name, + FLAGS.platform_info_file_name) diff --git a/tools/generate_data.py b/tools/generate_data.py index 2069459c65414986473a17050202740e5b296cbd..668838bcde7d965d5a6fd91969c8447568b0731d 100644 --- a/tools/generate_data.py +++ b/tools/generate_data.py @@ -26,22 +26,22 @@ import re # -def generate_data(name, shape): +def generate_data(name, shape, input_file): np.random.seed() data = np.random.random(shape) * 2 - 1 - input_file_name = FLAGS.input_file + "_" + re.sub('[^0-9a-zA-Z]+', '_', - name) + input_file_name = input_file + "_" + re.sub('[^0-9a-zA-Z]+', '_', + name) print 'Generate input file: ', input_file_name data.astype(np.float32).tofile(input_file_name) -def main(unused_args): - input_names = [name for name in FLAGS.input_node.split(',')] - input_shapes = [shape for shape in FLAGS.input_shape.split(':')] +def generate_input_data(input_file, input_node, input_shape): + input_names = [name for name in input_node.split(',')] + input_shapes = [shape for shape in input_shape.split(':')] assert len(input_names) == len(input_shapes) for i in range(len(input_names)): shape = [int(x) for x in input_shapes[i].split(',')] - generate_data(input_names[i], shape) + generate_data(input_names[i], shape, input_file) print "Generate input file done." @@ -61,4 +61,4 @@ def parse_args(): if __name__ == '__main__': FLAGS, unparsed = parse_args() - main(unused_args=[sys.argv[0]] + unparsed) + generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape) diff --git a/tools/mace_tools.py b/tools/mace_tools.py index 6c18783dd921373a264746735275f2a05aa38498..39b0bba674744d3148c432bc7a8bd382de3d2c49 100644 --- a/tools/mace_tools.py +++ b/tools/mace_tools.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Must run at root dir of libmace project. # python tools/mace_tools.py \ # --config=tools/example.yaml \ # --round=100 \ @@ -89,14 +88,21 @@ def get_hexagon_mode(configs): return False -def generate_code(target_soc, target_abi, model_output_dirs, pull_or_not): +def gen_opencl_and_tuning_code(target_soc, + target_abi, + model_output_dirs, + pull_or_not): if pull_or_not: sh_commands.pull_binaries( target_soc, target_abi, model_output_dirs) - sh_commands.gen_opencl_binary_code( - target_soc, target_abi, model_output_dirs) + + codegen_path = "mace/codegen" + + # generate opencl binary code + sh_commands.gen_opencl_binary_code(target_soc, model_output_dirs) + sh_commands.gen_tuning_param_code( - target_soc, target_abi, model_output_dirs) + target_soc, model_output_dirs) def model_benchmark_stdout_processor(stdout, @@ -170,11 +176,11 @@ def tuning_run(runtime, phone_data_dir, option_args) model_benchmark_stdout_processor(stdout, - target_soc, - target_abi, - runtime, - running_round, - tuning) + target_soc, + target_abi, + runtime, + running_round, + tuning) def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi, @@ -182,7 +188,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi, input_nodes, output_nodes, input_shapes, output_shapes, model_name, device_type, running_round, restart_round, tuning, limit_opencl_kernel_time, phone_data_dir): - generate_code(target_soc, target_abi, [], False) + gen_opencl_and_tuning_code(target_soc, target_abi, [], False) production_or_not = False mace_run_target = "//mace/tools/validation:mace_run" sh_commands.bazel_build( @@ -207,7 +213,8 @@ def build_mace_run_prod(hexagon_mode, runtime, target_soc, target_abi, phone_data_dir=phone_data_dir, tuning=tuning, limit_opencl_kernel_time=limit_opencl_kernel_time) - generate_code(target_soc, target_abi, [model_output_dir], True) + gen_opencl_and_tuning_code( + target_soc, target_abi, [model_output_dir], True) production_or_not = True sh_commands.bazel_build( mace_run_target, @@ -226,7 +233,8 @@ def merge_libs_and_tuning_results(target_soc, model_output_dirs, hexagon_mode, embed_model_data): - generate_code(target_soc, target_abi, model_output_dirs, False) + gen_opencl_and_tuning_code( + target_soc, target_abi, model_output_dirs, False) sh_commands.build_production_code(target_abi) sh_commands.merge_libs(target_soc, diff --git a/tools/sh_commands.py b/tools/sh_commands.py index 3aa613be1778bca62d1df88beae6191c8fae96a7..f7d67d542ff43898b11af3a16d399dd51c356a9c 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -19,9 +19,23 @@ import os import re import sh import subprocess +import sys import time +sys.path.insert(0, "mace/python/tools") +try: + from encrypt_opencl_codegen import encrypt_opencl_codegen + from opencl_codegen import opencl_codegen + from binary_codegen import tuning_param_codegen + from generate_data import generate_input_data + from validate import validate +except Exception: + print("Error: import error.") + print("Does the script run at the root dir of mace project?") + exit(1) + + ################################ # common ################################ @@ -283,19 +297,16 @@ def bazel_target_to_bin(target): ################################ # mace commands ################################ -# TODO this should be refactored def gen_encrypted_opencl_source(codegen_path="mace/codegen"): sh.mkdir("-p", "%s/opencl" % codegen_path) - sh.python( - "mace/python/tools/encrypt_opencl_codegen.py", - "--cl_kernel_dir=./mace/kernels/opencl/cl/", - "--output_path=%s/opencl/opencl_encrypt_program.cc" % codegen_path) + encrypt_opencl_codegen("./mace/kernels/opencl/cl/", + "mace/codegen/opencl/opencl_encrypt_program.cc") def pull_binaries(target_soc, abi, model_output_dirs): serialno = adb_devices([target_soc]).pop() compiled_opencl_dir = "/data/local/tmp/mace_run/cl_program/" - mace_run_config_file = "mace_run.config" + mace_run_param_file = "mace_run.config" cl_bin_dirs = [] for d in model_output_dirs: @@ -308,46 +319,33 @@ def pull_binaries(target_soc, abi, model_output_dirs): sh.mkdir("-p", cl_bin_dir) if abi != "host": adb_pull(compiled_opencl_dir, cl_bin_dir, serialno) - adb_pull("/data/local/tmp/mace_run/%s" % mace_run_config_file, + adb_pull("/data/local/tmp/mace_run/%s" % mace_run_param_file, cl_bin_dir, serialno) def gen_opencl_binary_code(target_soc, - abi, model_output_dirs, codegen_path="mace/codegen"): cl_built_kernel_file_name = "mace_cl_compiled_program.bin" cl_platform_info_file_name = "mace_cl_platform_info.txt" + opencl_codegen_file = "%s/opencl/opencl_compiled_program.cc" % codegen_path serialno = adb_devices([target_soc]).pop() - compiled_opencl_dir = "/data/local/tmp/mace_run/cl_program/" cl_bin_dirs = [] for d in model_output_dirs: cl_bin_dirs.append(os.path.join(d, "opencl_bin")) cl_bin_dirs_str = ",".join(cl_bin_dirs) - if not cl_bin_dirs: - sh.python( - "mace/python/tools/opencl_codegen.py", - "--built_kernel_file_name=%s" % cl_built_kernel_file_name, - "--platform_info_file_name=%s" % cl_platform_info_file_name, - "--output_path=%s/opencl/opencl_compiled_program.cc" % - codegen_path) - else: - sh.python( - "mace/python/tools/opencl_codegen.py", - "--built_kernel_file_name=%s" % cl_built_kernel_file_name, - "--platform_info_file_name=%s" % cl_platform_info_file_name, - "--cl_binary_dirs=%s" % cl_bin_dirs_str, - "--output_path=%s/opencl/opencl_compiled_program.cc" % - codegen_path) + opencl_codegen(opencl_codegen_file, + cl_bin_dirs_str, + cl_built_kernel_file_name, + cl_platform_info_file_name) def gen_tuning_param_code(target_soc, - abi, model_output_dirs, codegen_path="mace/codegen"): - mace_run_config_file = "mace_run.config" + mace_run_param_file = "mace_run.config" cl_bin_dirs = [] for d in model_output_dirs: cl_bin_dirs.append(os.path.join(d, "opencl_bin")) @@ -357,11 +355,11 @@ def gen_tuning_param_code(target_soc, if not os.path.exists(tuning_codegen_dir): sh.mkdir("-p", tuning_codegen_dir) - sh.python( - "mace/python/tools/binary_codegen.py", - "--binary_dirs=%s" % cl_bin_dirs_str, - "--binary_file_name=%s" % mace_run_config_file, - "--output_path=%s/tuning_params.cc" % tuning_codegen_dir) + tuning_param_variable_name = "kTuningParamsData" + tuning_param_codegen(cl_bin_dirs_str, + mace_run_param_file, + "%s/tuning_params.cc" % tuning_codegen_dir, + tuning_param_variable_name) def gen_mace_version(codegen_path="mace/codegen"): @@ -371,10 +369,9 @@ def gen_mace_version(codegen_path="mace/codegen"): def gen_compiled_opencl_source(codegen_path="mace/codegen"): + opencl_codegen_file = "%s/opencl/opencl_compiled_program.cc" % codegen_path sh.mkdir("-p", "%s/opencl" % codegen_path) - sh.python( - "mace/python/tools/opencl_codegen.py", - "--output_path=%s/opencl/opencl_compiled_program.cc" % codegen_path) + opencl_codegen(opencl_codegen_file) def gen_model_code(model_codegen_dir, @@ -430,11 +427,9 @@ def gen_random_input(model_output_dir, sh.rm(formatted_name) input_nodes_str = ",".join(input_nodes) input_shapes_str = ":".join(input_shapes) - sh.python("-u", - "tools/generate_data.py", - "--input_node=%s" % input_nodes_str, - "--input_file=%s" % model_output_dir + "/" + input_file_name, - "--input_shape=%s" % input_shapes_str) + generate_input_data("%s/%s" % (model_output_dir, input_file_name), + input_nodes_str, + input_shapes_str) input_file_list = [] if isinstance(input_files, list): @@ -605,8 +600,6 @@ def validate_model(target_soc, output_file_name="model_out"): print("* Validate with %s" % platform) serialno = adb_devices([target_soc]).pop() - stdout_buff = [] - process_output = make_output_processor(stdout_buff) if platform == "tensorflow": if abi != "host": @@ -617,23 +610,11 @@ def validate_model(target_soc, sh.rm(formatted_name) adb_pull("%s/%s" % (phone_data_dir, formatted_name), model_output_dir, serialno) - p = sh.python( - "-u", - "tools/validate.py", - "--platform=%s" % platform, - "--model_file=%s" % model_file_path, - "--input_file=%s" % model_output_dir + "/" + input_file_name, - "--mace_out_file=%s" % model_output_dir + "/" + - output_file_name, - "--mace_runtime=%s" % runtime, - "--input_node=%s" % ",".join(input_nodes), - "--output_node=%s" % ",".join(output_nodes), - "--input_shape=%s" % ":".join(input_shapes), - "--output_shape=%s" % ":".join(output_shapes), - _out=process_output, - _bg=True, - _err_to_out=True) - p.wait() + validate(platform, model_file_path, "", + "%s/%s" % (model_output_dir, input_file_name), + "%s/%s" % (model_output_dir, output_file_name), runtime, + ":".join(input_shapes), ":".join(output_shapes), + ",".join(input_nodes), ",".join(output_nodes)) elif platform == "caffe": image_name = "mace-caffe:latest" container_name = "mace_caffe_validator" @@ -715,7 +696,6 @@ def validate_model(target_soc, p.wait() print("Validation done!\n") - return "".join(stdout_buff) def build_production_code(abi): diff --git a/tools/validate.py b/tools/validate.py index c9595c59690cbf3b7e7ccc8e27c5776246d9b20f..18e54faf6b661746445ec7c47ffb063ece7314e1 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -46,18 +46,19 @@ def format_output_name(name): return re.sub('[^0-9a-zA-Z]+', '_', name) -def compare_output(output_name, mace_out_value, out_value): +def compare_output(platform, mace_runtime, output_name, mace_out_value, + out_value): if mace_out_value.size != 0: out_value = out_value.reshape(-1) mace_out_value = mace_out_value.reshape(-1) assert len(out_value) == len(mace_out_value) similarity = (1 - spatial.distance.cosine(out_value, mace_out_value)) - print output_name, 'MACE VS', FLAGS.platform.upper( + print output_name, 'MACE VS', platform.upper( ), 'similarity: ', similarity - if (FLAGS.mace_runtime == "cpu" and similarity > 0.999) or \ - (FLAGS.mace_runtime == "neon" and similarity > 0.999) or \ - (FLAGS.mace_runtime == "gpu" and similarity > 0.995) or \ - (FLAGS.mace_runtime == "dsp" and similarity > 0.930): + if (mace_runtime == "cpu" and similarity > 0.999) or \ + (mace_runtime == "neon" and similarity > 0.999) or \ + (mace_runtime == "gpu" and similarity > 0.995) or \ + (mace_runtime == "dsp" and similarity > 0.930): print '===================Similarity Test Passed==================' else: print '===================Similarity Test Failed==================' @@ -67,14 +68,15 @@ def compare_output(output_name, mace_out_value, out_value): sys.exit(-1) -def validate_tf_model(input_names, input_shapes, output_names): +def validate_tf_model(platform, mace_runtime, model_file, input_file, + mace_out_file, input_names, input_shapes, output_names): import tensorflow as tf - if not os.path.isfile(FLAGS.model_file): - print("Input graph file '" + FLAGS.model_file + "' does not exist!") + if not os.path.isfile(model_file): + print("Input graph file '" + model_file + "' does not exist!") sys.exit(-1) input_graph_def = tf.GraphDef() - with open(FLAGS.model_file, "rb") as f: + with open(model_file, "rb") as f: data = f.read() input_graph_def.ParseFromString(data) tf.import_graph_def(input_graph_def, name="") @@ -85,7 +87,7 @@ def validate_tf_model(input_names, input_shapes, output_names): input_dict = {} for i in range(len(input_names)): input_value = load_data( - FLAGS.input_file + "_" + input_names[i]) + input_file + "_" + input_names[i]) input_value = input_value.reshape(input_shapes[i]) input_node = graph.get_tensor_by_name( input_names[i] + ':0') @@ -97,30 +99,31 @@ def validate_tf_model(input_names, input_shapes, output_names): [graph.get_tensor_by_name(name + ':0')]) output_values = session.run(output_nodes, feed_dict=input_dict) for i in range(len(output_names)): - output_file_name = FLAGS.mace_out_file + "_" + \ + output_file_name = mace_out_file + "_" + \ format_output_name(output_names[i]) mace_out_value = load_data(output_file_name) - compare_output(output_names[i], mace_out_value, - output_values[i]) + compare_output(platform, mace_runtime, output_names[i], + mace_out_value, output_values[i]) -def validate_caffe_model(input_names, input_shapes, output_names, - output_shapes): +def validate_caffe_model(platform, mace_runtime, model_file, input_file, + mace_out_file, weight_file, input_names, input_shapes, + output_names, output_shapes): os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints import caffe - if not os.path.isfile(FLAGS.model_file): - print("Input graph file '" + FLAGS.model_file + "' does not exist!") + if not os.path.isfile(model_file): + print("Input graph file '" + model_file + "' does not exist!") sys.exit(-1) - if not os.path.isfile(FLAGS.weight_file): - print("Input weight file '" + FLAGS.weight_file + "' does not exist!") + if not os.path.isfile(weight_file): + print("Input weight file '" + weight_file + "' does not exist!") sys.exit(-1) caffe.set_mode_cpu() - net = caffe.Net(FLAGS.model_file, caffe.TEST, weights=FLAGS.weight_file) + net = caffe.Net(model_file, caffe.TEST, weights=weight_file) for i in range(len(input_names)): - input_value = load_data(FLAGS.input_file + "_" + input_names[i]) + input_value = load_data(input_file + "_" + input_names[i]) input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, 2)) input_blob_name = input_names[i] @@ -139,28 +142,33 @@ def validate_caffe_model(input_names, input_shapes, output_names, out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[ 1], out_shape[2] value = value.reshape(out_shape).transpose((0, 2, 3, 1)) - output_file_name = FLAGS.mace_out_file + "_" + format_output_name( + output_file_name = mace_out_file + "_" + format_output_name( output_names[i]) mace_out_value = load_data(output_file_name) - compare_output(output_names[i], mace_out_value, value) + compare_output(platform, mace_runtime, output_names[i], mace_out_value, + value) -def main(unused_args): - input_names = [name for name in FLAGS.input_node.split(',')] - input_shape_strs = [shape for shape in FLAGS.input_shape.split(':')] +def validate(platform, model_file, weight_file, input_file, mace_out_file, + mace_runtime, input_shape, output_shape, input_node, output_node): + input_names = [name for name in input_node.split(',')] + input_shape_strs = [shape for shape in input_shape.split(':')] input_shapes = [[int(x) for x in shape.split(',')] for shape in input_shape_strs] - output_names = [name for name in FLAGS.output_node.split(',')] + output_names = [name for name in output_node.split(',')] assert len(input_names) == len(input_shapes) - if FLAGS.platform == 'tensorflow': - validate_tf_model(input_names, input_shapes, output_names) - elif FLAGS.platform == 'caffe': - output_shape_strs = [shape for shape in FLAGS.output_shape.split(':')] + if platform == 'tensorflow': + validate_tf_model(platform, mace_runtime, model_file, input_file, + mace_out_file, input_names, input_shapes, + output_names) + elif platform == 'caffe': + output_shape_strs = [shape for shape in output_shape.split(':')] output_shapes = [[int(x) for x in shape.split(',')] for shape in output_shape_strs] - validate_caffe_model(input_names, input_shapes, output_names, - output_shapes) + validate_caffe_model(platform, mace_runtime, model_file, input_file, + mace_out_file, weight_file, input_names, + input_shapes, output_names, output_shapes) def parse_args(): @@ -202,4 +210,13 @@ def parse_args(): if __name__ == '__main__': FLAGS, unparsed = parse_args() - main(unused_args=[sys.argv[0]] + unparsed) + validate(FLAGS.platform, + FLAGS.model_file, + FLAGS.weight_file, + FLAGS.input_file, + FLAGS.mace_out_file, + FLAGS.mace_runtime, + FLAGS.input_shape, + FLAGS.output_shape, + FLAGS.input_node, + FLAGS.output_node)