From ac32bf6f77d45275418e45b0d6579f96877dfaf1 Mon Sep 17 00:00:00 2001 From: lujun Date: Thu, 21 Mar 2019 19:32:48 +0800 Subject: [PATCH] update input params type, test=develop --- python/paddle/fluid/imperative/checkpoint.py | 29 ++++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/imperative/checkpoint.py b/python/paddle/fluid/imperative/checkpoint.py index 97bad771793..37c43f29d2a 100644 --- a/python/paddle/fluid/imperative/checkpoint.py +++ b/python/paddle/fluid/imperative/checkpoint.py @@ -17,13 +17,12 @@ from __future__ import print_function import os import collections from .. import core -from ..framework import Variable, Parameter, default_main_program -from .layers import Layer +from ..framework import Variable, default_main_program __all__ = ['save_persistables', 'load_persistables'] -def save_persistables(obj, dirname, filename=None): +def save_persistables(vardict, dirname, filename=None): """ This function filters out all variables in layer.parameters from the give `layer` and then trys to load these variables from the folder @@ -35,7 +34,7 @@ def save_persistables(obj, dirname, filename=None): the file name. Args: - var_list(dict of Parameters|Layer): The parameters will + vardict(dict of Parameters): The parameters will be saved. If it is None, nothing will be deal. dirname(str): The directory path. @@ -69,17 +68,14 @@ def save_persistables(obj, dirname, filename=None): dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden, init_cell) param_path = "./my_paddle_model" - fluid.imperative.checkpoint.save_persistables(ptb_model.parameters(), dirname=param_path, + fluid.imperative.checkpoint.save_persistables(ptb_model.state_dict(), dirname=param_path, layer=ptb_model) """ - if isinstance(obj, collections.OrderedDict): - _save_var_to_file(obj, dirname, filename) - elif isinstance(obj, Layer): - _save_var_to_file( - obj.state_dict(include_sublayers=True), dirname, filename) + if isinstance(vardict, collections.OrderedDict): + _save_var_to_file(vardict, dirname, filename) -def load_persistables(obj, dirname, filename=None): +def load_persistables(vardict, dirname, filename=None): """ This function trys to load persistable variables from the folder `dirname` or the file `filename`. @@ -90,7 +86,7 @@ def load_persistables(obj, dirname, filename=None): the file name. Args: - obj(dict of Parameters|Layer): The parameters will be loaded. + vardict(dict of Parameters): The parameters will be loaded. dirname(str): The directory path. filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were saved in differnet files, set it to None. @@ -111,16 +107,13 @@ def load_persistables(obj, dirname, filename=None): my_layer = layer(fluid.imperative.Layer) param_path = "./my_paddle_model" filename = "model.file" - param_dict = fluid.imperative.checkpoint.load_persistables(my_layer, var_list, param_path, + param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.state_dict(), param_path, filename=filename) param_1 = param_dict['PtbModel_0.w_1'] """ - if isinstance(obj, collections.OrderedDict): - return _load_var_from_file(obj, dirname, filename) - elif isinstance(obj, Layer): - return _load_var_from_file( - obj.state_dict(include_sublayers=True), dirname, filename) + if isinstance(vardict, collections.OrderedDict): + return _load_var_from_file(vardict, dirname, filename) return {} -- GitLab