提交 e9bd2dd4 编写于 作者: 李寅

Add transformer config

上级 db4e94e3
...@@ -489,7 +489,8 @@ def gen_model_code(model_codegen_dir, ...@@ -489,7 +489,8 @@ def gen_model_code(model_codegen_dir,
fast_conv, fast_conv,
obfuscate, obfuscate,
model_build_type, model_build_type,
data_type): data_type,
transformers):
bazel_build_common("//mace/python/tools:converter") bazel_build_common("//mace/python/tools:converter")
if os.path.exists(model_codegen_dir): if os.path.exists(model_codegen_dir):
...@@ -516,6 +517,7 @@ def gen_model_code(model_codegen_dir, ...@@ -516,6 +517,7 @@ def gen_model_code(model_codegen_dir,
"--output_dir=%s" % model_codegen_dir, "--output_dir=%s" % model_codegen_dir,
"--model_build_type=%s" % model_build_type, "--model_build_type=%s" % model_build_type,
"--data_type=%s" % data_type, "--data_type=%s" % data_type,
"--transformers=%s" % transformers,
_fg=True) _fg=True)
...@@ -523,7 +525,8 @@ def gen_random_input(model_output_dir, ...@@ -523,7 +525,8 @@ def gen_random_input(model_output_dir,
input_nodes, input_nodes,
input_shapes, input_shapes,
input_files, input_files,
input_file_name="model_input"): input_file_name="model_input",
input_ranges=None):
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = common.formatted_file_name( formatted_name = common.formatted_file_name(
input_file_name, input_name) input_file_name, input_name)
...@@ -531,9 +534,14 @@ def gen_random_input(model_output_dir, ...@@ -531,9 +534,14 @@ def gen_random_input(model_output_dir,
sh.rm("%s/%s" % (model_output_dir, formatted_name)) sh.rm("%s/%s" % (model_output_dir, formatted_name))
input_nodes_str = ",".join(input_nodes) input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes) input_shapes_str = ":".join(input_shapes)
if input_ranges:
input_ranges_str = ":".join(input_ranges)
else:
input_ranges_str = None
generate_input_data("%s/%s" % (model_output_dir, input_file_name), generate_input_data("%s/%s" % (model_output_dir, input_file_name),
input_nodes_str, input_nodes_str,
input_shapes_str) input_shapes_str,
input_ranges_str)
input_file_list = [] input_file_list = []
if isinstance(input_files, list): if isinstance(input_files, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册