提交 dd5d9ca4 编写于 作者: J jiazhenwei

fix python3 compatible when compile model to lib

上级 162399b6
......@@ -27,9 +27,10 @@ def gen_mace_engine_factory(model_tags, template_dir,
loader=FileSystemLoader(template_dir), trim_blocks=True)
# generate mace_run BUILD file
template_name = 'mace_engine_factory.h.jinja2'
model_tags = list(model_tags)
source = j2_env.get_template(template_name).render(
model_tags=model_tags,
embed_model_data=embed_model_data,
)
with open(output_dir + '/mace_engine_factory.h', "wb") as f:
with open(output_dir + '/mace_engine_factory.h', "w") as f:
f.write(source)
......@@ -216,7 +216,7 @@ def save_model_to_code(net_def, model_tag, runtime,
tensor=tensor,
tag=model_tag,
)
with open(output_dir + 'tensor' + str(counter) + '.cc', "wb") as f:
with open(output_dir + 'tensor' + str(counter) + '.cc', "w") as f:
f.write(source)
counter += 1
......@@ -228,7 +228,7 @@ def save_model_to_code(net_def, model_tag, runtime,
tag=model_tag,
model_data_size=len(model_data),
model_data=model_data)
with open(output_dir + 'tensor_data' + '.cc', "wb") as f:
with open(output_dir + 'tensor_data' + '.cc', "w") as f:
f.write(source)
# generate op source files
......@@ -243,7 +243,7 @@ def save_model_to_code(net_def, model_tag, runtime,
tag=model_tag,
runtime=runtime,
)
with open(output_dir + 'op' + str(counter) + '.cc', "wb") as f:
with open(output_dir + 'op' + str(counter) + '.cc', "w") as f:
f.write(source)
counter += 1
......@@ -262,13 +262,13 @@ def save_model_to_code(net_def, model_tag, runtime,
winograd_conv=winograd_conv,
checksum=checksum,
build_time=build_time)
with open(output_dir + 'model.cc', "wb") as f:
with open(output_dir + 'model.cc', "w") as f:
f.write(source)
# generate model header file
template_name = 'model_header.jinja2'
source = j2_env.get_template(template_name).render(tag=model_tag, )
with open(output_dir + model_tag + '.h', "wb") as f:
with open(output_dir + model_tag + '.h', "w") as f:
f.write(source)
......
......@@ -80,7 +80,7 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
arg->set_i({{ arg.i }});
{%- endif %}
{%- if arg.HasField('s') %}
arg->set_s({{ arg.s|tojson }});
arg->set_s({{ arg.s.decode('utf-8')|tojson }});
{%- endif %}
arg->mutable_floats()->Reserve({{ arg.floats|length }});
......@@ -161,4 +161,3 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
} // namespace {{tag}}
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册