From 9d2f7d762c4f91d415c9e2d5a3668dc7b80471a0 Mon Sep 17 00:00:00 2001 From: lujun Date: Fri, 19 Apr 2019 19:55:08 +0800 Subject: [PATCH] fix dy-load bug, test=develop --- python/paddle/fluid/dygraph/checkpoint.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 3d797d5049e..f96b53e8c0b 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -134,14 +134,19 @@ def _save_var_to_file(stat_dict, file_dir, file_name): def _load_var_from_file(file_dir): def walk_filename(file_dir): + base_path = os.path.join(file_dir) var_name_list = [] - if os.path.exists(file_dir) and os.path.exists(os.path.join(file_dir)): - base_path = os.path.join(file_dir) - for dirpath, dirnames, filenames in os.walk(os.path.join(file_dir)): - pt = dirpath.replace(base_path, "", 1)[1:] + if os.path.exists(base_path): + for dirpath, dirnames, filenames in os.walk(base_path): + pt = dirpath.replace(base_path, "", 1) + if pt.startswith("/") or pt.startswith("\\"): + pt = pt[1:] for fth_name in filenames: if fth_name[0] != '.': - var_name_list.append(os.path.join(pt, fth_name)) + name_path = os.path.join(pt, fth_name) + if "\\" in name_path: + name_path = name_path.replace("\\", "/") + var_name_list.append(name_path) return var_name_list -- GitLab