From b250fceab5e8f9f0c763d1faa054c078fc4db669 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 2 Mar 2017 14:28:15 +0800
Subject: [PATCH] Add save/load parameters.

---
 demo/mnist/.gitignore          |  1 +
 demo/mnist/api_train_v2.py     | 11 ++++++++++-
 python/paddle/v2/parameters.py | 27 +++++++++++++++++++++++++++
 python/paddle/v2/trainer.py    |  2 +-
 4 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore
index 8bd9837523c..9c552159be6 100644
--- a/demo/mnist/.gitignore
+++ b/demo/mnist/.gitignore
@@ -5,3 +5,4 @@ plot.png
 train.log
 *pyc
 .ipynb_checkpoints
+params.pkl
diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index a59b30ccdb2..73fcb9d79d3 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -1,4 +1,5 @@
 import paddle.v2 as paddle
+import cPickle
 
 
 def main():
@@ -16,7 +17,11 @@ def main():
         act=paddle.activation.Softmax())
     cost = paddle.layer.classification_cost(input=inference, label=label)
 
-    parameters = paddle.parameters.create(cost)
+    try:
+        with open('params.pkl', 'r') as f:
+            parameters = cPickle.load(f)
+    except IOError:
+        parameters = paddle.parameters.create(cost)
 
     adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
 
@@ -34,6 +39,10 @@ def main():
                     event.pass_id, event.batch_id, event.cost, event.metrics,
                     result.metrics)
 
+                with open('params.pkl', 'w') as f:
+                    cPickle.dump(
+                        parameters, f, protocol=cPickle.HIGHEST_PROTOCOL)
+
         else:
             pass
 
diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py
index 2a6026bcab1..d8c3a73b0ea 100644
--- a/python/paddle/v2/parameters.py
+++ b/python/paddle/v2/parameters.py
@@ -222,6 +222,33 @@ class Parameters(object):
 
         self.__gradient_machines__.append(gradient_machine)
 
+    def __getstate__(self):
+        params = {}
+        for name in self.names():
+            params[name] = self.get(name)
+
+        param_conf = {}
+        for name in self.__param_conf__:
+            conf = self.__param_conf__[name]
+            assert isinstance(conf, ParameterConfig)
+            param_conf[name] = conf.SerializeToString()
+
+        return {'conf': param_conf, 'params': params}
+
+    def __setstate__(self, obj):
+        Parameters.__init__(self)
+
+        def __impl__(conf, params):
+            for name in conf:
+                p = ParameterConfig()
+                p.ParseFromString(conf[name])
+                self.__append_config__(p)
+            for name in params:
+                shape = self.get_shape(name)
+                self.set(name, params[name].reshape(shape))
+
+        __impl__(**obj)
+
 
 def __get_parameter_in_gradient_machine__(gradient_machine, name):
     """
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 5003f55f3e0..709566ca447 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -66,9 +66,9 @@ class SGD(ITrainer):
             self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
             self.__optimizer__.enable_types())
         assert isinstance(gm, api.GradientMachine)
-        parameters.append_gradient_machine(gm)
         self.__gradient_machine__ = gm
         self.__gradient_machine__.randParameters()
+        parameters.append_gradient_machine(gm)
 
     def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
         """
--
GitLab
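
The patch makes Parameters picklable by hand-implementing __getstate__/__setstate__: pickling records each ParameterConfig as a serialized protobuf string alongside the parameter values as numpy arrays, and unpickling rebuilds a fresh Parameters object from that dict. In trainer.py, parameters.append_gradient_machine(gm) moves after randParameters(), presumably so the random initialization cannot clobber values that were loaded from params.pkl before training starts.

Below is a minimal, self-contained sketch of the same load-or-create checkpoint pattern. It is not paddle.v2 code: Params, load_or_create, and the 'fc.w' parameter name are hypothetical stand-ins, and it uses Python 3's pickle with binary file modes ('rb'/'wb'), which are safer than the text modes ('r'/'w') the Python 2 patch uses.

import os
import pickle

import numpy as np


class Params(object):
    """Toy parameter store made picklable via __getstate__/__setstate__."""

    def __init__(self):
        self.__values__ = {}

    def set(self, name, value):
        self.__values__[name] = value

    def get(self, name):
        return self.__values__[name]

    def __getstate__(self):
        # Pickle a plain dict so the on-disk format does not depend on
        # object internals.
        return {'params': dict(self.__values__)}

    def __setstate__(self, obj):
        # Re-initialize, then replay the stored values, mirroring how the
        # patch rebuilds Parameters in __setstate__.
        Params.__init__(self)
        for name, value in obj['params'].items():
            self.set(name, value)


def load_or_create(path):
    # Binary modes ('rb'/'wb') are required for pickle on Python 3.
    if os.path.exists(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
    params = Params()
    params.set('fc.w', np.zeros((784, 10)))  # hypothetical initial value
    return params


if __name__ == '__main__':
    params = load_or_create('params.pkl')
    # ... training would update params here ...
    with open('params.pkl', 'wb') as f:
        pickle.dump(params, f, protocol=pickle.HIGHEST_PROTOCOL)

With this in place, killing and restarting the training script resumes from the last saved params.pkl instead of re-initializing, which is the behavior the try/except IOError block in api_train_v2.py provides.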