提交 bed79b44 编写于 作者: 李寅

Make config file param name friendly

上级 8f2eeef1
...@@ -108,8 +108,9 @@ def main(unused_args): ...@@ -108,8 +108,9 @@ def main(unused_args):
print("%s does not support dsp runtime yet." % FLAGS.platform) print("%s does not support dsp runtime yet." % FLAGS.platform)
sys.exit(-1) sys.exit(-1)
else: else:
if FLAGS.transformers: if FLAGS.graph_optimize_options:
option = cvt.ConverterOption(FLAGS.transformers.split(',')) option = cvt.ConverterOption(
FLAGS.graph_optimize_options.split(','))
else: else:
option = cvt.ConverterOption() option = cvt.ConverterOption()
option.winograd = FLAGS.winograd option.winograd = FLAGS.winograd
...@@ -287,10 +288,10 @@ def parse_args(): ...@@ -287,10 +288,10 @@ def parse_args():
default="fp16_fp32", default="fp16_fp32",
help="fp16_fp32/fp32_fp32") help="fp16_fp32/fp32_fp32")
parser.add_argument( parser.add_argument(
"--transformers", "--graph_optimize_options",
type=str, type=str,
default="", default="",
help="model transformers") help="graph optimize options")
return parser.parse_known_args() return parser.parse_known_args()
......
...@@ -148,7 +148,7 @@ class YAMLKeyword(object): ...@@ -148,7 +148,7 @@ class YAMLKeyword(object):
obfuscate = 'obfuscate' obfuscate = 'obfuscate'
winograd = 'winograd' winograd = 'winograd'
validation_inputs_data = 'validation_inputs_data' validation_inputs_data = 'validation_inputs_data'
transformers = 'transformers' # keep it private for now graph_optimize_options = 'graph_optimize_options' # internal use for now
class ModuleName(object): class ModuleName(object):
...@@ -657,7 +657,7 @@ def convert_model(configs): ...@@ -657,7 +657,7 @@ def convert_model(configs):
model_config[YAMLKeyword.obfuscate], model_config[YAMLKeyword.obfuscate],
configs[YAMLKeyword.build_type], configs[YAMLKeyword.build_type],
data_type, data_type,
",".join(model_config.get(YAMLKeyword.transformers, []))) ",".join(model_config.get(YAMLKeyword.graph_optimize_options, [])))
if configs[YAMLKeyword.build_type] == BuildType.proto: if configs[YAMLKeyword.build_type] == BuildType.proto:
sh.mv("-f", sh.mv("-f",
......
...@@ -490,7 +490,7 @@ def gen_model_code(model_codegen_dir, ...@@ -490,7 +490,7 @@ def gen_model_code(model_codegen_dir,
obfuscate, obfuscate,
model_build_type, model_build_type,
data_type, data_type,
transformers): graph_optimize_options):
bazel_build_common("//mace/python/tools:converter") bazel_build_common("//mace/python/tools:converter")
if os.path.exists(model_codegen_dir): if os.path.exists(model_codegen_dir):
...@@ -517,7 +517,7 @@ def gen_model_code(model_codegen_dir, ...@@ -517,7 +517,7 @@ def gen_model_code(model_codegen_dir,
"--output_dir=%s" % model_codegen_dir, "--output_dir=%s" % model_codegen_dir,
"--model_build_type=%s" % model_build_type, "--model_build_type=%s" % model_build_type,
"--data_type=%s" % data_type, "--data_type=%s" % data_type,
"--transformers=%s" % transformers, "--graph_optimize_options=%s" % graph_optimize_options,
_fg=True) _fg=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册