From dd5d9ca4e528c65c78916424e0da11f7dbf5f3ad Mon Sep 17 00:00:00 2001 From: jiazhenwei Date: Thu, 15 Nov 2018 16:37:18 +0800 Subject: [PATCH] fix python3 compatible when compile model to lib --- mace/python/tools/mace_engine_factory_codegen.py | 3 ++- mace/python/tools/model_saver.py | 10 +++++----- mace/python/tools/operator.jinja2 | 3 +-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mace/python/tools/mace_engine_factory_codegen.py b/mace/python/tools/mace_engine_factory_codegen.py index ce910065..5518deb3 100644 --- a/mace/python/tools/mace_engine_factory_codegen.py +++ b/mace/python/tools/mace_engine_factory_codegen.py @@ -27,9 +27,10 @@ def gen_mace_engine_factory(model_tags, template_dir, loader=FileSystemLoader(template_dir), trim_blocks=True) # generate mace_run BUILD file template_name = 'mace_engine_factory.h.jinja2' + model_tags = list(model_tags) source = j2_env.get_template(template_name).render( model_tags=model_tags, embed_model_data=embed_model_data, ) - with open(output_dir + '/mace_engine_factory.h', "wb") as f: + with open(output_dir + '/mace_engine_factory.h', "w") as f: f.write(source) diff --git a/mace/python/tools/model_saver.py b/mace/python/tools/model_saver.py index 56709ef6..117bb8ef 100644 --- a/mace/python/tools/model_saver.py +++ b/mace/python/tools/model_saver.py @@ -216,7 +216,7 @@ def save_model_to_code(net_def, model_tag, runtime, tensor=tensor, tag=model_tag, ) - with open(output_dir + 'tensor' + str(counter) + '.cc', "wb") as f: + with open(output_dir + 'tensor' + str(counter) + '.cc', "w") as f: f.write(source) counter += 1 @@ -228,7 +228,7 @@ def save_model_to_code(net_def, model_tag, runtime, tag=model_tag, model_data_size=len(model_data), model_data=model_data) - with open(output_dir + 'tensor_data' + '.cc', "wb") as f: + with open(output_dir + 'tensor_data' + '.cc', "w") as f: f.write(source) # generate op source files @@ -243,7 +243,7 @@ def save_model_to_code(net_def, model_tag, runtime, tag=model_tag, runtime=runtime, ) - with open(output_dir + 'op' + str(counter) + '.cc', "wb") as f: + with open(output_dir + 'op' + str(counter) + '.cc', "w") as f: f.write(source) counter += 1 @@ -262,13 +262,13 @@ def save_model_to_code(net_def, model_tag, runtime, winograd_conv=winograd_conv, checksum=checksum, build_time=build_time) - with open(output_dir + 'model.cc', "wb") as f: + with open(output_dir + 'model.cc', "w") as f: f.write(source) # generate model header file template_name = 'model_header.jinja2' source = j2_env.get_template(template_name).render(tag=model_tag, ) - with open(output_dir + model_tag + '.h', "wb") as f: + with open(output_dir + model_tag + '.h', "w") as f: f.write(source) diff --git a/mace/python/tools/operator.jinja2 b/mace/python/tools/operator.jinja2 index fc77b1e6..e3492ddf 100644 --- a/mace/python/tools/operator.jinja2 +++ b/mace/python/tools/operator.jinja2 @@ -80,7 +80,7 @@ void CreateOperator{{i}}(mace::OperatorDef *op) { arg->set_i({{ arg.i }}); {%- endif %} {%- if arg.HasField('s') %} - arg->set_s({{ arg.s|tojson }}); + arg->set_s({{ arg.s.decode('utf-8')|tojson }}); {%- endif %} arg->mutable_floats()->Reserve({{ arg.floats|length }}); @@ -161,4 +161,3 @@ void CreateOperator{{i}}(mace::OperatorDef *op) { } // namespace {{tag}} } // namespace mace - -- GitLab