提交 4bf2ccaa 编写于 作者: S songyouwei 提交者: hong

fix save_dygraph & save with nonexistent dir (#22266)

* fix save_dygraph with nonexistent dir
test=develop

* minor fix
test=develop

* fix unittest
test=develop

* fix static save
test=develop
上级 9a7245de
......@@ -82,7 +82,12 @@ def save_dygraph(state_dict, model_path):
name_table[k] = v.name
model_dict["StructuredToParameterName@@"] = name_table
with open(model_path + suffix, 'wb') as f:
file_name = model_path + suffix
dir_name = os.path.dirname(file_name)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
with open(file_name, 'wb') as f:
pickle.dump(model_dict, f)
......@@ -113,7 +118,7 @@ def load_dygraph(model_path, keep_name_table=False):
adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000),
parameter_list = emb.parameters() )
state_dict = adam.state_dict()
fluid.save_dygraph( state_dict, "padle_dy")
fluid.save_dygraph( state_dict, "paddle_dy")
para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy")
......
......@@ -1518,6 +1518,10 @@ def save(program, model_path):
assert base_name != "", \
"model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str"
dir_name = os.path.dirname(model_path)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
def get_tensor(var):
t = global_scope().find_var(var.name).get_tensor()
return np.array(t)
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
......@@ -879,9 +880,10 @@ class TestDygraphPtbRnn(unittest.TestCase):
with fluid.dygraph.guard():
emb = fluid.dygraph.Embedding([10, 10])
state_dict = emb.state_dict()
fluid.save_dygraph(state_dict, "emb_dy")
fluid.save_dygraph(state_dict, os.path.join('saved_dy', 'emb_dy'))
para_state_dict, opti_state_dict = fluid.load_dygraph("emb_dy")
para_state_dict, opti_state_dict = fluid.load_dygraph(
os.path.join('saved_dy', 'emb_dy'))
self.assertTrue(opti_state_dict == None)
......
......@@ -609,7 +609,7 @@ class TestProgramStatePartial(unittest.TestCase):
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
fluid.save(main_program, "./test_1")
fluid.save(main_program, os.path.join('some_dir', 'test_1'))
# set var to zero
for var in main_program.list_vars():
......@@ -623,7 +623,8 @@ class TestProgramStatePartial(unittest.TestCase):
self.assertTrue(np.sum(np.abs(new_t)) == 0)
#fluid.load(test_program, "./test_1", None )
program_state = fluid.load_program_state("./test_1")
program_state = fluid.load_program_state(
os.path.join('some_dir', 'test_1'))
fluid.set_program_state(test_program, program_state)
for var in test_program.list_vars():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册