diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 3d797d5049ea35e4af7da68ab47d6131d8e6e3f5..f96b53e8c0b1e6ee93a14ecc811cd32a01bc7702 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -134,14 +134,19 @@ def _save_var_to_file(stat_dict, file_dir, file_name): def _load_var_from_file(file_dir): def walk_filename(file_dir): + base_path = os.path.join(file_dir) var_name_list = [] - if os.path.exists(file_dir) and os.path.exists(os.path.join(file_dir)): - base_path = os.path.join(file_dir) - for dirpath, dirnames, filenames in os.walk(os.path.join(file_dir)): - pt = dirpath.replace(base_path, "", 1)[1:] + if os.path.exists(base_path): + for dirpath, dirnames, filenames in os.walk(base_path): + pt = dirpath.replace(base_path, "", 1) + if pt.startswith("/") or pt.startswith("\\"): + pt = pt[1:] for fth_name in filenames: if fth_name[0] != '.': - var_name_list.append(os.path.join(pt, fth_name)) + name_path = os.path.join(pt, fth_name) + if "\\" in name_path: + name_path = name_path.replace("\\", "/") + var_name_list.append(name_path) return var_name_list