From a88a1faa48a42a8c3737deb0f05da968d200a7d3 Mon Sep 17 00:00:00 2001 From: lujun Date: Thu, 9 May 2019 10:03:02 +0800 Subject: [PATCH] Format file path (#17280) The parameter dirpath will be passed directly to c++ operater. The file address format will be different under win and UNIX. --- python/paddle/fluid/io.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e15497b633..b573093c30 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -17,8 +17,6 @@ from __future__ import print_function import os import errno import warnings -import time -import shutil import six import logging from functools import reduce @@ -168,6 +166,7 @@ def save_vars(executor, # var_a, var_b and var_c will be saved. And they are going to be # saved in the same file named 'var_file' in the path "./my_paddle_model". """ + save_dirname = os.path.normpath(dirname) if vars is None: if main_program is None: main_program = default_main_program() @@ -177,7 +176,7 @@ def save_vars(executor, save_vars( executor, main_program=main_program, - dirname=dirname, + dirname=save_dirname, vars=list(filter(predicate, main_program.list_vars())), filename=filename) else: @@ -200,7 +199,9 @@ def save_vars(executor, type='save', inputs={'X': [new_var]}, outputs={}, - attrs={'file_path': os.path.join(dirname, new_var.name)}) + attrs={ + 'file_path': os.path.join(save_dirname, new_var.name) + }) else: save_var_map[new_var.name] = new_var @@ -213,7 +214,7 @@ def save_vars(executor, type='save_combine', inputs={'X': save_var_list}, outputs={}, - attrs={'file_path': os.path.join(dirname, filename)}) + attrs={'file_path': os.path.join(save_dirname, filename)}) executor.run(save_program) @@ -567,6 +568,7 @@ def load_vars(executor, # var_a, var_b and var_c will be loaded. And they are supposed to haven # been saved in the same file named 'var_file' in the path "./my_paddle_model". """ + load_dirname = os.path.normpath(dirname) if vars is None: if main_program is None: main_program = default_main_program() @@ -575,7 +577,7 @@ def load_vars(executor, load_vars( executor, - dirname=dirname, + dirname=load_dirname, main_program=main_program, vars=list(filter(predicate, main_program.list_vars())), filename=filename) @@ -599,7 +601,9 @@ def load_vars(executor, type='load', inputs={}, outputs={'Out': [new_var]}, - attrs={'file_path': os.path.join(dirname, new_var.name)}) + attrs={ + 'file_path': os.path.join(load_dirname, new_var.name) + }) else: load_var_map[new_var.name] = new_var @@ -612,7 +616,7 @@ def load_vars(executor, type='load_combine', inputs={}, outputs={"Out": load_var_list}, - attrs={'file_path': os.path.join(dirname, filename)}) + attrs={'file_path': os.path.join(load_dirname, filename)}) executor.run(load_prog) @@ -985,8 +989,10 @@ def save_inference_model(dirname, target_var_name_list = [var.name for var in target_vars] # when a pserver and a trainer running on the same machine, mkdir may conflict + save_dirname = dirname try: - os.makedirs(dirname) + save_dirname = os.path.normpath(dirname) + os.makedirs(save_dirname) except OSError as e: if e.errno != errno.EEXIST: raise @@ -995,7 +1001,7 @@ def save_inference_model(dirname, model_basename = os.path.basename(model_filename) else: model_basename = "__model__" - model_basename = os.path.join(dirname, model_basename) + model_basename = os.path.join(save_dirname, model_basename) # When export_for_deployment is true, we modify the program online so that # it can only be loaded for inference directly. If it's false, the whole @@ -1038,7 +1044,7 @@ def save_inference_model(dirname, if params_filename is not None: params_filename = os.path.basename(params_filename) - save_persistables(executor, dirname, main_program, params_filename) + save_persistables(executor, save_dirname, main_program, params_filename) return target_var_name_list @@ -1102,14 +1108,15 @@ def load_inference_model(dirname, # program to get the inference result. """ - if not os.path.isdir(dirname): + load_dirname = os.path.normpath(dirname) + if not os.path.isdir(load_dirname): raise ValueError("There is no directory named '%s'", dirname) if model_filename is not None: model_filename = os.path.basename(model_filename) else: model_filename = "__model__" - model_filename = os.path.join(dirname, model_filename) + model_filename = os.path.join(load_dirname, model_filename) if params_filename is not None: params_filename = os.path.basename(params_filename) @@ -1122,7 +1129,7 @@ def load_inference_model(dirname, raise ValueError("Unsupported program version: %d\n" % program._version()) # Binary data also need versioning. - load_persistables(executor, dirname, program, params_filename) + load_persistables(executor, load_dirname, program, params_filename) if pserver_endpoints: program = _endpoints_replacement(program, pserver_endpoints) -- GitLab