diff --git a/tools/sh_commands.py b/tools/sh_commands.py index 2d0153a5b4691dd9d2788bb7e31ffb27dfde5e85..f96ad7de560016b81628e17d9c3e432d6578011d 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -489,7 +489,8 @@ def gen_model_code(model_codegen_dir, fast_conv, obfuscate, model_build_type, - data_type): + data_type, + transformers): bazel_build_common("//mace/python/tools:converter") if os.path.exists(model_codegen_dir): @@ -516,6 +517,7 @@ def gen_model_code(model_codegen_dir, "--output_dir=%s" % model_codegen_dir, "--model_build_type=%s" % model_build_type, "--data_type=%s" % data_type, + "--transformers=%s" % transformers, _fg=True) @@ -523,7 +525,8 @@ def gen_random_input(model_output_dir, input_nodes, input_shapes, input_files, - input_file_name="model_input"): + input_file_name="model_input", + input_ranges=None): for input_name in input_nodes: formatted_name = common.formatted_file_name( input_file_name, input_name) @@ -531,9 +534,14 @@ def gen_random_input(model_output_dir, sh.rm("%s/%s" % (model_output_dir, formatted_name)) input_nodes_str = ",".join(input_nodes) input_shapes_str = ":".join(input_shapes) + if input_ranges: + input_ranges_str = ":".join(input_ranges) + else: + input_ranges_str = None generate_input_data("%s/%s" % (model_output_dir, input_file_name), input_nodes_str, - input_shapes_str) + input_shapes_str, + input_ranges_str) input_file_list = [] if isinstance(input_files, list):