From 4bf2ccaa52d0a4370318a03028282626caf61bab Mon Sep 17 00:00:00 2001 From: songyouwei Date: Wed, 15 Jan 2020 13:04:56 +0800 Subject: [PATCH] fix save_dygraph & save with nonexistent dir (#22266) * fix save_dygraph with nonexistent dir test=develop * minor fix test=develop * fix unittest test=develop * fix static save test=develop --- python/paddle/fluid/dygraph/checkpoint.py | 9 +++++++-- python/paddle/fluid/io.py | 4 ++++ .../fluid/tests/unittests/test_imperative_save_load.py | 6 ++++-- .../fluid/tests/unittests/test_static_save_load.py | 5 +++-- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 5ed4e2d412e..27658ba3d46 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -82,7 +82,12 @@ def save_dygraph(state_dict, model_path): name_table[k] = v.name model_dict["StructuredToParameterName@@"] = name_table - with open(model_path + suffix, 'wb') as f: + file_name = model_path + suffix + dir_name = os.path.dirname(file_name) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name) + + with open(file_name, 'wb') as f: pickle.dump(model_dict, f) @@ -113,7 +118,7 @@ def load_dygraph(model_path, keep_name_table=False): adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000), parameter_list = emb.parameters() ) state_dict = adam.state_dict() - fluid.save_dygraph( state_dict, "padle_dy") + fluid.save_dygraph( state_dict, "paddle_dy") para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy") diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 0436137a80c..73273acf7a2 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1518,6 +1518,10 @@ def save(program, model_path): assert base_name != "", \ "model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str" + dir_name = os.path.dirname(model_path) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name) + def get_tensor(var): t = global_scope().find_var(var.name).get_tensor() return np.array(t) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py index 01327ac647f..6a621b8c75c 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py @@ -14,6 +14,7 @@ from __future__ import print_function +import os import unittest import paddle.fluid as fluid import paddle.fluid.core as core @@ -879,9 +880,10 @@ class TestDygraphPtbRnn(unittest.TestCase): with fluid.dygraph.guard(): emb = fluid.dygraph.Embedding([10, 10]) state_dict = emb.state_dict() - fluid.save_dygraph(state_dict, "emb_dy") + fluid.save_dygraph(state_dict, os.path.join('saved_dy', 'emb_dy')) - para_state_dict, opti_state_dict = fluid.load_dygraph("emb_dy") + para_state_dict, opti_state_dict = fluid.load_dygraph( + os.path.join('saved_dy', 'emb_dy')) self.assertTrue(opti_state_dict == None) diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index 0dd767edc4c..24b61f514ce 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -609,7 +609,7 @@ class TestProgramStatePartial(unittest.TestCase): self.assertTrue(np.sum(np.abs(t)) != 0) base_map[var.name] = t - fluid.save(main_program, "./test_1") + fluid.save(main_program, os.path.join('some_dir', 'test_1')) # set var to zero for var in main_program.list_vars(): @@ -623,7 +623,8 @@ class TestProgramStatePartial(unittest.TestCase): self.assertTrue(np.sum(np.abs(new_t)) == 0) #fluid.load(test_program, "./test_1", None ) - program_state = fluid.load_program_state("./test_1") + program_state = fluid.load_program_state( + os.path.join('some_dir', 'test_1')) fluid.set_program_state(test_program, program_state) for var in test_program.list_vars(): -- GitLab