未验证 提交 fba10b6b 编写于 作者: J Jiabin Yang 提交者: GitHub

test=develop, refine api (#17883)

* test=develop, refine api

* test=develop, fix bug when error occured on save_persistable with no optimizer

* test=develop, refine waring

* test=develop, refine example code and comments
上级 fbbdc9cc
...@@ -24,13 +24,10 @@ import warnings ...@@ -24,13 +24,10 @@ import warnings
__all__ = ['save_persistables', 'load_persistables'] __all__ = ['save_persistables', 'load_persistables']
def save_persistables(model_dict, def save_persistables(model_dict, dirname='save_dir', optimizers=None):
optimizer=None,
dirname='save_dir',
filename=None):
""" """
This function filters out all variables in layer.parameters from the This function filters out all variables in layer.parameters from the
give `layer` and then trys to load these variables from the folder give `layer`, and optimizer's learning rate decay and then trys to load these variables from the folder
`dirname` or the file `filename`. `dirname` or the file `filename`.
Use the `dirname` to specify the folder where persistable variables were Use the `dirname` to specify the folder where persistable variables were
...@@ -43,9 +40,7 @@ def save_persistables(model_dict, ...@@ -43,9 +40,7 @@ def save_persistables(model_dict,
be saved. If it is None, nothing be saved. If it is None, nothing
will be deal. will be deal.
dirname(str): The directory path. dirname(str): The directory path.
filename(str|None): The file which saved all variables. If variables were optimizers(fluid.Optimizer|list(fluid.Optimizer)|None): The optimizers to be saved
saved in different files, set it to None.
Default: None
Returns: Returns:
...@@ -57,7 +52,7 @@ def save_persistables(model_dict, ...@@ -57,7 +52,7 @@ def save_persistables(model_dict,
num_layers=num_layers, num_layers=num_layers,
num_steps=num_steps, num_steps=num_steps,
init_scale=init_scale) init_scale=init_scale)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
x_data = np.arange(12).reshape(4, 3).astype('int64') x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64') y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1)) x_data = x_data.reshape((-1, num_steps, 1))
...@@ -72,12 +67,14 @@ def save_persistables(model_dict, ...@@ -72,12 +67,14 @@ def save_persistables(model_dict,
init_cell = to_variable(init_cell_data) init_cell = to_variable(init_cell_data)
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden, dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
init_cell) init_cell)
dy_loss.backward()
sgd.minimize(dy_loss)
ptb_model.clear_gradient()
param_path = "./my_paddle_model" param_path = "./my_paddle_model"
fluid.dygraph.save_persistables(ptb_model.state_dict(), dirname=param_path, fluid.dygraph.save_persistables(ptb_model.state_dict(), dirname=param_path, sgd)
layer=ptb_model)
""" """
if isinstance(model_dict, collections.OrderedDict): if isinstance(model_dict, collections.OrderedDict):
_save_var_to_file(model_dict, optimizer, dirname, filename) _save_var_to_file(model_dict, optimizers, dirname, None)
def load_persistables(dirname='save_dir'): def load_persistables(dirname='save_dir'):
...@@ -92,18 +89,19 @@ def load_persistables(dirname='save_dir'): ...@@ -92,18 +89,19 @@ def load_persistables(dirname='save_dir'):
Args: Args:
dirname(str): The directory path. default is save_dir dirname(str): The directory path. default is save_dir
optimizer(Optimizer): Optimizer to be saved
Returns: Returns:
dict: The parameter-dict resumed from file dict: The parameter-dict resumed from file
optimizer dict: The optimizer
Examples: Examples:
.. code-block:: python .. code-block:: python
my_layer = layer(fluid.Layer) my_layer = layer(fluid.Layer)
param_path = "./my_paddle_model" param_path = "./my_paddle_model"
sgd = SGDOptimizer(learning_rate=1e-3)
param_dict = fluid.dygraph.load_persistables(my_layer.parameters(), param_path) param_dict, optimizer_dict = fluid.dygraph.load_persistables(my_layer.parameters(), param_path)
param_1 = param_dict['PtbModel_0.w_1'] param_1 = param_dict['PtbModel_0.w_1']
sgd.load(optimizer_dict)
""" """
return _load_var_from_file(dirname) return _load_var_from_file(dirname)
...@@ -123,11 +121,14 @@ def _save_var_to_file(stat_dict, optimizers, file_dir, file_name): ...@@ -123,11 +121,14 @@ def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
'file_path': os.path.join(file_dir, 'file_path': os.path.join(file_dir,
os.path.normpath(each_var.name)) os.path.normpath(each_var.name))
}) })
if optimizers is not None:
if isinstance(optimizers, (list, tuple)): if isinstance(optimizers, (list, tuple)):
optimizers = optimizers optimizers = optimizers
else: else:
optimizers = [optimizers] optimizers = [optimizers]
if os.path.exists(os.path.join(file_dir, os.path.normpath("optimizers"))): if os.path.exists(
os.path.join(file_dir, os.path.normpath("optimizers"))):
pass pass
else: else:
os.mkdir(os.path.join(file_dir, os.path.normpath("optimizers"))) os.mkdir(os.path.join(file_dir, os.path.normpath("optimizers")))
...@@ -137,7 +138,8 @@ def _save_var_to_file(stat_dict, optimizers, file_dir, file_name): ...@@ -137,7 +138,8 @@ def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
try: try:
f = open( f = open(
os.path.join(file_dir, "optimizers", os.path.join(file_dir, "optimizers",
os.path.normpath(str(optimizer._name))), "wb") os.path.normpath(str(optimizer._name))),
"wb")
pickle.dump(optimizer._learning_rate, f, 2) pickle.dump(optimizer._learning_rate, f, 2)
f.close() f.close()
except (): except ():
...@@ -149,6 +151,8 @@ def _save_var_to_file(stat_dict, optimizers, file_dir, file_name): ...@@ -149,6 +151,8 @@ def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
warnings.warn( warnings.warn(
"Optimizer not saved, Only optimizer with 'LearningRateDecay' under DyGraph mode need to be saved" "Optimizer not saved, Only optimizer with 'LearningRateDecay' under DyGraph mode need to be saved"
) )
else:
pass
if file_name is not None: if file_name is not None:
save_var_list = [] save_var_list = []
...@@ -213,7 +217,10 @@ def _load_var_from_file(file_dir): ...@@ -213,7 +217,10 @@ def _load_var_from_file(file_dir):
file_dir, "optimizers", file_dir, "optimizers",
os.path.normpath(str(optimizer._name)))) os.path.normpath(str(optimizer._name))))
if len(load_optimizer_map) == 0: if len(load_optimizer_map) == 0:
warnings.warn("No optimizer loaded") print(
"No optimizer loaded. If you didn't save optimizer, please ignore this. The program can still work with new optimizer. "
)
pass
return load_var_map, load_optimizer_map return load_var_map, load_optimizer_map
......
...@@ -144,7 +144,7 @@ class TestDygraphCheckpoint(unittest.TestCase): ...@@ -144,7 +144,7 @@ class TestDygraphCheckpoint(unittest.TestCase):
avg_loss.backward() avg_loss.backward()
sgd.minimize(avg_loss) sgd.minimize(avg_loss)
fluid.dygraph.save_persistables(mnist.state_dict(), [sgd], fluid.dygraph.save_persistables(mnist.state_dict(),
"save_dir") "save_dir")
mnist.clear_gradients() mnist.clear_gradients()
......
...@@ -78,8 +78,8 @@ class TestImperativeOptimizerBase(unittest.TestCase): ...@@ -78,8 +78,8 @@ class TestImperativeOptimizerBase(unittest.TestCase):
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
optimizer2.minimize(avg_loss) optimizer2.minimize(avg_loss)
mlp.clear_gradients() mlp.clear_gradients()
fluid.dygraph.save_persistables( fluid.dygraph.save_persistables(mlp.state_dict(), "save_dir_2",
mlp.state_dict(), [optimizer, optimizer2], "save_dir_2") [optimizer, optimizer2])
if batch_id == 2: if batch_id == 2:
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册