diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore
index 8bd9837523ccf98e6e72d5b82934b7b104816217..9c552159be6f73e06233956ba0ed1e077ac1d02a 100644
--- a/demo/mnist/.gitignore
+++ b/demo/mnist/.gitignore
@@ -5,3 +5,4 @@ plot.png
 train.log
 *pyc
 .ipynb_checkpoints
+params.pkl
diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index a59b30ccdb2eddea6680d6ad5c790c857b9c5141..73fcb9d79d32cca5c679234f3e4a254f01861a72 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -1,4 +1,5 @@
 import paddle.v2 as paddle
+import cPickle
 
 
 def main():
@@ -16,7 +17,11 @@ def main():
         act=paddle.activation.Softmax())
     cost = paddle.layer.classification_cost(input=inference, label=label)
 
-    parameters = paddle.parameters.create(cost)
+    try:
+        with open('params.pkl', 'rb') as f:
+            parameters = cPickle.load(f)
+    except IOError:
+        parameters = paddle.parameters.create(cost)
 
     adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
 
@@ -34,6 +39,10 @@ def main():
                 event.pass_id, event.batch_id, event.cost, event.metrics,
                 result.metrics)
 
+            with open('params.pkl', 'wb') as f:
+                cPickle.dump(
+                    parameters, f, protocol=cPickle.HIGHEST_PROTOCOL)
+
         else:
             pass
 
diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py
index 2a6026bcab1c8a373d8dd5eac480dec62a8eb3b9..d8c3a73b0ea9e5df11dd0933a1c87cf13e0b5ded 100644
--- a/python/paddle/v2/parameters.py
+++ b/python/paddle/v2/parameters.py
@@ -222,6 +222,33 @@ class Parameters(object):
 
         self.__gradient_machines__.append(gradient_machine)
 
+    def __getstate__(self):
+        params = {}
+        for name in self.names():
+            params[name] = self.get(name)
+
+        param_conf = {}
+        for name in self.__param_conf__:
+            conf = self.__param_conf__[name]
+            assert isinstance(conf, ParameterConfig)
+            param_conf[name] = conf.SerializeToString()
+
+        return {'conf': param_conf, 'params': params}
+
+    def __setstate__(self, obj):
+        Parameters.__init__(self)
+
+        def __impl__(conf, params):
+            for name in conf:
+                p = ParameterConfig()
+                p.ParseFromString(conf[name])
+                self.__append_config__(p)
+            for name in params:
+                shape = self.get_shape(name)
+                self.set(name, params[name].reshape(shape))
+
+        __impl__(**obj)
+
 
 def __get_parameter_in_gradient_machine__(gradient_machine, name):
     """
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 5003f55f3e0d15149d28d1478e0487d6873d6e0a..709566ca4475f27dba2b64ee74041b9e6d7d9611 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -66,9 +66,9 @@ class SGD(ITrainer):
             self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
             self.__optimizer__.enable_types())
         assert isinstance(gm, api.GradientMachine)
-        parameters.append_gradient_machine(gm)
         self.__gradient_machine__ = gm
         self.__gradient_machine__.randParameters()
+        parameters.append_gradient_machine(gm)
 
     def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
         """