未验证 提交 a88a1faa 编写于 作者: L lujun 提交者: GitHub

Format file path (#17280)

The parameter dirpath will be passed directly to c++ operater. The file address format will be different under win and UNIX.
上级 5d6a1fcf
...@@ -17,8 +17,6 @@ from __future__ import print_function ...@@ -17,8 +17,6 @@ from __future__ import print_function
import os import os
import errno import errno
import warnings import warnings
import time
import shutil
import six import six
import logging import logging
from functools import reduce from functools import reduce
...@@ -168,6 +166,7 @@ def save_vars(executor, ...@@ -168,6 +166,7 @@ def save_vars(executor,
# var_a, var_b and var_c will be saved. And they are going to be # var_a, var_b and var_c will be saved. And they are going to be
# saved in the same file named 'var_file' in the path "./my_paddle_model". # saved in the same file named 'var_file' in the path "./my_paddle_model".
""" """
save_dirname = os.path.normpath(dirname)
if vars is None: if vars is None:
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
...@@ -177,7 +176,7 @@ def save_vars(executor, ...@@ -177,7 +176,7 @@ def save_vars(executor,
save_vars( save_vars(
executor, executor,
main_program=main_program, main_program=main_program,
dirname=dirname, dirname=save_dirname,
vars=list(filter(predicate, main_program.list_vars())), vars=list(filter(predicate, main_program.list_vars())),
filename=filename) filename=filename)
else: else:
...@@ -200,7 +199,9 @@ def save_vars(executor, ...@@ -200,7 +199,9 @@ def save_vars(executor,
type='save', type='save',
inputs={'X': [new_var]}, inputs={'X': [new_var]},
outputs={}, outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)}) attrs={
'file_path': os.path.join(save_dirname, new_var.name)
})
else: else:
save_var_map[new_var.name] = new_var save_var_map[new_var.name] = new_var
...@@ -213,7 +214,7 @@ def save_vars(executor, ...@@ -213,7 +214,7 @@ def save_vars(executor,
type='save_combine', type='save_combine',
inputs={'X': save_var_list}, inputs={'X': save_var_list},
outputs={}, outputs={},
attrs={'file_path': os.path.join(dirname, filename)}) attrs={'file_path': os.path.join(save_dirname, filename)})
executor.run(save_program) executor.run(save_program)
...@@ -567,6 +568,7 @@ def load_vars(executor, ...@@ -567,6 +568,7 @@ def load_vars(executor,
# var_a, var_b and var_c will be loaded. And they are supposed to haven # var_a, var_b and var_c will be loaded. And they are supposed to haven
# been saved in the same file named 'var_file' in the path "./my_paddle_model". # been saved in the same file named 'var_file' in the path "./my_paddle_model".
""" """
load_dirname = os.path.normpath(dirname)
if vars is None: if vars is None:
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
...@@ -575,7 +577,7 @@ def load_vars(executor, ...@@ -575,7 +577,7 @@ def load_vars(executor,
load_vars( load_vars(
executor, executor,
dirname=dirname, dirname=load_dirname,
main_program=main_program, main_program=main_program,
vars=list(filter(predicate, main_program.list_vars())), vars=list(filter(predicate, main_program.list_vars())),
filename=filename) filename=filename)
...@@ -599,7 +601,9 @@ def load_vars(executor, ...@@ -599,7 +601,9 @@ def load_vars(executor,
type='load', type='load',
inputs={}, inputs={},
outputs={'Out': [new_var]}, outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)}) attrs={
'file_path': os.path.join(load_dirname, new_var.name)
})
else: else:
load_var_map[new_var.name] = new_var load_var_map[new_var.name] = new_var
...@@ -612,7 +616,7 @@ def load_vars(executor, ...@@ -612,7 +616,7 @@ def load_vars(executor,
type='load_combine', type='load_combine',
inputs={}, inputs={},
outputs={"Out": load_var_list}, outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, filename)}) attrs={'file_path': os.path.join(load_dirname, filename)})
executor.run(load_prog) executor.run(load_prog)
...@@ -985,8 +989,10 @@ def save_inference_model(dirname, ...@@ -985,8 +989,10 @@ def save_inference_model(dirname,
target_var_name_list = [var.name for var in target_vars] target_var_name_list = [var.name for var in target_vars]
# when a pserver and a trainer running on the same machine, mkdir may conflict # when a pserver and a trainer running on the same machine, mkdir may conflict
save_dirname = dirname
try: try:
os.makedirs(dirname) save_dirname = os.path.normpath(dirname)
os.makedirs(save_dirname)
except OSError as e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
...@@ -995,7 +1001,7 @@ def save_inference_model(dirname, ...@@ -995,7 +1001,7 @@ def save_inference_model(dirname,
model_basename = os.path.basename(model_filename) model_basename = os.path.basename(model_filename)
else: else:
model_basename = "__model__" model_basename = "__model__"
model_basename = os.path.join(dirname, model_basename) model_basename = os.path.join(save_dirname, model_basename)
# When export_for_deployment is true, we modify the program online so that # When export_for_deployment is true, we modify the program online so that
# it can only be loaded for inference directly. If it's false, the whole # it can only be loaded for inference directly. If it's false, the whole
...@@ -1038,7 +1044,7 @@ def save_inference_model(dirname, ...@@ -1038,7 +1044,7 @@ def save_inference_model(dirname,
if params_filename is not None: if params_filename is not None:
params_filename = os.path.basename(params_filename) params_filename = os.path.basename(params_filename)
save_persistables(executor, dirname, main_program, params_filename) save_persistables(executor, save_dirname, main_program, params_filename)
return target_var_name_list return target_var_name_list
...@@ -1102,14 +1108,15 @@ def load_inference_model(dirname, ...@@ -1102,14 +1108,15 @@ def load_inference_model(dirname,
# program to get the inference result. # program to get the inference result.
""" """
if not os.path.isdir(dirname): load_dirname = os.path.normpath(dirname)
if not os.path.isdir(load_dirname):
raise ValueError("There is no directory named '%s'", dirname) raise ValueError("There is no directory named '%s'", dirname)
if model_filename is not None: if model_filename is not None:
model_filename = os.path.basename(model_filename) model_filename = os.path.basename(model_filename)
else: else:
model_filename = "__model__" model_filename = "__model__"
model_filename = os.path.join(dirname, model_filename) model_filename = os.path.join(load_dirname, model_filename)
if params_filename is not None: if params_filename is not None:
params_filename = os.path.basename(params_filename) params_filename = os.path.basename(params_filename)
...@@ -1122,7 +1129,7 @@ def load_inference_model(dirname, ...@@ -1122,7 +1129,7 @@ def load_inference_model(dirname,
raise ValueError("Unsupported program version: %d\n" % raise ValueError("Unsupported program version: %d\n" %
program._version()) program._version())
# Binary data also need versioning. # Binary data also need versioning.
load_persistables(executor, dirname, program, params_filename) load_persistables(executor, load_dirname, program, params_filename)
if pserver_endpoints: if pserver_endpoints:
program = _endpoints_replacement(program, pserver_endpoints) program = _endpoints_replacement(program, pserver_endpoints)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册