提交 b250fcea 编写于 作者: Y Yu Yang

Add save/load parameters.

上级 cdecd53b
...@@ -5,3 +5,4 @@ plot.png ...@@ -5,3 +5,4 @@ plot.png
train.log train.log
*pyc *pyc
.ipynb_checkpoints .ipynb_checkpoints
params.pkl
import paddle.v2 as paddle import paddle.v2 as paddle
import cPickle
def main(): def main():
...@@ -16,7 +17,11 @@ def main(): ...@@ -16,7 +17,11 @@ def main():
act=paddle.activation.Softmax()) act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=inference, label=label) cost = paddle.layer.classification_cost(input=inference, label=label)
parameters = paddle.parameters.create(cost) try:
with open('params.pkl', 'r') as f:
parameters = cPickle.load(f)
except IOError:
parameters = paddle.parameters.create(cost)
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
...@@ -34,6 +39,10 @@ def main(): ...@@ -34,6 +39,10 @@ def main():
event.pass_id, event.batch_id, event.cost, event.metrics, event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics) result.metrics)
with open('params.pkl', 'w') as f:
cPickle.dump(
parameters, f, protocol=cPickle.HIGHEST_PROTOCOL)
else: else:
pass pass
......
...@@ -222,6 +222,33 @@ class Parameters(object): ...@@ -222,6 +222,33 @@ class Parameters(object):
self.__gradient_machines__.append(gradient_machine) self.__gradient_machines__.append(gradient_machine)
def __getstate__(self):
params = {}
for name in self.names():
params[name] = self.get(name)
param_conf = {}
for name in self.__param_conf__:
conf = self.__param_conf__[name]
assert isinstance(conf, ParameterConfig)
param_conf[name] = conf.SerializeToString()
return {'conf': param_conf, 'params': params}
def __setstate__(self, obj):
Parameters.__init__(self)
def __impl__(conf, params):
for name in conf:
p = ParameterConfig()
p.ParseFromString(conf[name])
self.__append_config__(p)
for name in params:
shape = self.get_shape(name)
self.set(name, params[name].reshape(shape))
__impl__(**obj)
def __get_parameter_in_gradient_machine__(gradient_machine, name): def __get_parameter_in_gradient_machine__(gradient_machine, name):
""" """
......
...@@ -66,9 +66,9 @@ class SGD(ITrainer): ...@@ -66,9 +66,9 @@ class SGD(ITrainer):
self.__topology_in_proto__, api.CREATE_MODE_NORMAL, self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types()) self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine) assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
self.__gradient_machine__ = gm self.__gradient_machine__ = gm
self.__gradient_machine__.randParameters() self.__gradient_machine__.randParameters()
parameters.append_gradient_machine(gm)
def train(self, reader, num_passes=1, event_handler=None, reader_dict=None): def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册