提交 ac32bf6f 编写于 作者: L lujun

update input params type, test=develop

上级 09442fb2
...@@ -17,13 +17,12 @@ from __future__ import print_function ...@@ -17,13 +17,12 @@ from __future__ import print_function
import os import os
import collections import collections
from .. import core from .. import core
from ..framework import Variable, Parameter, default_main_program from ..framework import Variable, default_main_program
from .layers import Layer
__all__ = ['save_persistables', 'load_persistables'] __all__ = ['save_persistables', 'load_persistables']
def save_persistables(obj, dirname, filename=None): def save_persistables(vardict, dirname, filename=None):
""" """
This function filters out all variables in layer.parameters from the This function filters out all variables in layer.parameters from the
give `layer` and then trys to load these variables from the folder give `layer` and then trys to load these variables from the folder
...@@ -35,7 +34,7 @@ def save_persistables(obj, dirname, filename=None): ...@@ -35,7 +34,7 @@ def save_persistables(obj, dirname, filename=None):
the file name. the file name.
Args: Args:
var_list(dict of Parameters|Layer): The parameters will vardict(dict of Parameters): The parameters will
be saved. If it is None, nothing be saved. If it is None, nothing
will be deal. will be deal.
dirname(str): The directory path. dirname(str): The directory path.
...@@ -69,17 +68,14 @@ def save_persistables(obj, dirname, filename=None): ...@@ -69,17 +68,14 @@ def save_persistables(obj, dirname, filename=None):
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden, dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
init_cell) init_cell)
param_path = "./my_paddle_model" param_path = "./my_paddle_model"
fluid.imperative.checkpoint.save_persistables(ptb_model.parameters(), dirname=param_path, fluid.imperative.checkpoint.save_persistables(ptb_model.state_dict(), dirname=param_path,
layer=ptb_model) layer=ptb_model)
""" """
if isinstance(obj, collections.OrderedDict): if isinstance(vardict, collections.OrderedDict):
_save_var_to_file(obj, dirname, filename) _save_var_to_file(vardict, dirname, filename)
elif isinstance(obj, Layer):
_save_var_to_file(
obj.state_dict(include_sublayers=True), dirname, filename)
def load_persistables(obj, dirname, filename=None): def load_persistables(vardict, dirname, filename=None):
""" """
This function trys to load persistable variables from the folder This function trys to load persistable variables from the folder
`dirname` or the file `filename`. `dirname` or the file `filename`.
...@@ -90,7 +86,7 @@ def load_persistables(obj, dirname, filename=None): ...@@ -90,7 +86,7 @@ def load_persistables(obj, dirname, filename=None):
the file name. the file name.
Args: Args:
obj(dict of Parameters|Layer): The parameters will be loaded. vardict(dict of Parameters): The parameters will be loaded.
dirname(str): The directory path. dirname(str): The directory path.
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
saved in differnet files, set it to None. saved in differnet files, set it to None.
...@@ -111,16 +107,13 @@ def load_persistables(obj, dirname, filename=None): ...@@ -111,16 +107,13 @@ def load_persistables(obj, dirname, filename=None):
my_layer = layer(fluid.imperative.Layer) my_layer = layer(fluid.imperative.Layer)
param_path = "./my_paddle_model" param_path = "./my_paddle_model"
filename = "model.file" filename = "model.file"
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer, var_list, param_path, param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.state_dict(), param_path,
filename=filename) filename=filename)
param_1 = param_dict['PtbModel_0.w_1'] param_1 = param_dict['PtbModel_0.w_1']
""" """
if isinstance(obj, collections.OrderedDict): if isinstance(vardict, collections.OrderedDict):
return _load_var_from_file(obj, dirname, filename) return _load_var_from_file(vardict, dirname, filename)
elif isinstance(obj, Layer):
return _load_var_from_file(
obj.state_dict(include_sublayers=True), dirname, filename)
return {} return {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册