From e9bd2dd49b5435cf2218bcdbfe437f77f39238c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Mon, 25 Jun 2018 10:17:43 +0800 Subject: [PATCH] Add transformer config --- tools/sh_commands.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tools/sh_commands.py b/tools/sh_commands.py index 2d0153a5..f96ad7de 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -489,7 +489,8 @@ def gen_model_code(model_codegen_dir, fast_conv, obfuscate, model_build_type, - data_type): + data_type, + transformers): bazel_build_common("//mace/python/tools:converter") if os.path.exists(model_codegen_dir): @@ -516,6 +517,7 @@ def gen_model_code(model_codegen_dir, "--output_dir=%s" % model_codegen_dir, "--model_build_type=%s" % model_build_type, "--data_type=%s" % data_type, + "--transformers=%s" % transformers, _fg=True) @@ -523,7 +525,8 @@ def gen_random_input(model_output_dir, input_nodes, input_shapes, input_files, - input_file_name="model_input"): + input_file_name="model_input", + input_ranges=None): for input_name in input_nodes: formatted_name = common.formatted_file_name( input_file_name, input_name) @@ -531,9 +534,14 @@ def gen_random_input(model_output_dir, sh.rm("%s/%s" % (model_output_dir, formatted_name)) input_nodes_str = ",".join(input_nodes) input_shapes_str = ":".join(input_shapes) + if input_ranges: + input_ranges_str = ":".join(input_ranges) + else: + input_ranges_str = None generate_input_data("%s/%s" % (model_output_dir, input_file_name), input_nodes_str, - input_shapes_str) + input_shapes_str, + input_ranges_str) input_file_list = [] if isinstance(input_files, list): -- GitLab