未验证 提交 a88a1faa 编写于 作者: L lujun 提交者: GitHub

Format file path (#17280)

The parameter dirpath will be passed directly to c++ operater. The file address format will be different under win and UNIX.
上级 5d6a1fcf
......@@ -17,8 +17,6 @@ from __future__ import print_function
import os
import errno
import warnings
import time
import shutil
import six
import logging
from functools import reduce
......@@ -168,6 +166,7 @@ def save_vars(executor,
# var_a, var_b and var_c will be saved. And they are going to be
# saved in the same file named 'var_file' in the path "./my_paddle_model".
"""
save_dirname = os.path.normpath(dirname)
if vars is None:
if main_program is None:
main_program = default_main_program()
......@@ -177,7 +176,7 @@ def save_vars(executor,
save_vars(
executor,
main_program=main_program,
dirname=dirname,
dirname=save_dirname,
vars=list(filter(predicate, main_program.list_vars())),
filename=filename)
else:
......@@ -200,7 +199,9 @@ def save_vars(executor,
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)})
attrs={
'file_path': os.path.join(save_dirname, new_var.name)
})
else:
save_var_map[new_var.name] = new_var
......@@ -213,7 +214,7 @@ def save_vars(executor,
type='save_combine',
inputs={'X': save_var_list},
outputs={},
attrs={'file_path': os.path.join(dirname, filename)})
attrs={'file_path': os.path.join(save_dirname, filename)})
executor.run(save_program)
......@@ -567,6 +568,7 @@ def load_vars(executor,
# var_a, var_b and var_c will be loaded. And they are supposed to haven
# been saved in the same file named 'var_file' in the path "./my_paddle_model".
"""
load_dirname = os.path.normpath(dirname)
if vars is None:
if main_program is None:
main_program = default_main_program()
......@@ -575,7 +577,7 @@ def load_vars(executor,
load_vars(
executor,
dirname=dirname,
dirname=load_dirname,
main_program=main_program,
vars=list(filter(predicate, main_program.list_vars())),
filename=filename)
......@@ -599,7 +601,9 @@ def load_vars(executor,
type='load',
inputs={},
outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
attrs={
'file_path': os.path.join(load_dirname, new_var.name)
})
else:
load_var_map[new_var.name] = new_var
......@@ -612,7 +616,7 @@ def load_vars(executor,
type='load_combine',
inputs={},
outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, filename)})
attrs={'file_path': os.path.join(load_dirname, filename)})
executor.run(load_prog)
......@@ -985,8 +989,10 @@ def save_inference_model(dirname,
target_var_name_list = [var.name for var in target_vars]
# when a pserver and a trainer running on the same machine, mkdir may conflict
save_dirname = dirname
try:
os.makedirs(dirname)
save_dirname = os.path.normpath(dirname)
os.makedirs(save_dirname)
except OSError as e:
if e.errno != errno.EEXIST:
raise
......@@ -995,7 +1001,7 @@ def save_inference_model(dirname,
model_basename = os.path.basename(model_filename)
else:
model_basename = "__model__"
model_basename = os.path.join(dirname, model_basename)
model_basename = os.path.join(save_dirname, model_basename)
# When export_for_deployment is true, we modify the program online so that
# it can only be loaded for inference directly. If it's false, the whole
......@@ -1038,7 +1044,7 @@ def save_inference_model(dirname,
if params_filename is not None:
params_filename = os.path.basename(params_filename)
save_persistables(executor, dirname, main_program, params_filename)
save_persistables(executor, save_dirname, main_program, params_filename)
return target_var_name_list
......@@ -1102,14 +1108,15 @@ def load_inference_model(dirname,
# program to get the inference result.
"""
if not os.path.isdir(dirname):
load_dirname = os.path.normpath(dirname)
if not os.path.isdir(load_dirname):
raise ValueError("There is no directory named '%s'", dirname)
if model_filename is not None:
model_filename = os.path.basename(model_filename)
else:
model_filename = "__model__"
model_filename = os.path.join(dirname, model_filename)
model_filename = os.path.join(load_dirname, model_filename)
if params_filename is not None:
params_filename = os.path.basename(params_filename)
......@@ -1122,7 +1129,7 @@ def load_inference_model(dirname,
raise ValueError("Unsupported program version: %d\n" %
program._version())
# Binary data also need versioning.
load_persistables(executor, dirname, program, params_filename)
load_persistables(executor, load_dirname, program, params_filename)
if pserver_endpoints:
program = _endpoints_replacement(program, pserver_endpoints)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册