提交 cf38c676 编写于 作者: X Xin Pan

fix

上级 7ba55aa2
......@@ -618,10 +618,6 @@ def save_inference_model(dirname,
if main_program is None:
main_program = default_main_program()
if params_filename is not None:
params_filename = os.path.basename(params_filename)
save_persistables(executor, dirname, main_program, params_filename)
# if there is lookup table, the trainer 0 will notify all pserver to save.
if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
......@@ -665,6 +661,10 @@ def save_inference_model(dirname,
with open(model_basename + ".main_program", "wb") as f:
f.write(main_program.desc.serialize_to_string())
if params_filename is not None:
params_filename = os.path.basename(params_filename)
save_persistables(executor, dirname, main_program, params_filename)
def load_inference_model(dirname,
executor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册