提交 b250fcea 编写于 作者: Y Yu Yang

Add save/load parameters.

上级 cdecd53b
......@@ -5,3 +5,4 @@ plot.png
train.log
*pyc
.ipynb_checkpoints
params.pkl
import paddle.v2 as paddle
import cPickle
def main():
......@@ -16,6 +17,10 @@ def main():
act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=inference, label=label)
try:
with open('params.pkl', 'r') as f:
parameters = cPickle.load(f)
except IOError:
parameters = paddle.parameters.create(cost)
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
......@@ -34,6 +39,10 @@ def main():
event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics)
with open('params.pkl', 'w') as f:
cPickle.dump(
parameters, f, protocol=cPickle.HIGHEST_PROTOCOL)
else:
pass
......
......@@ -222,6 +222,33 @@ class Parameters(object):
self.__gradient_machines__.append(gradient_machine)
def __getstate__(self):
params = {}
for name in self.names():
params[name] = self.get(name)
param_conf = {}
for name in self.__param_conf__:
conf = self.__param_conf__[name]
assert isinstance(conf, ParameterConfig)
param_conf[name] = conf.SerializeToString()
return {'conf': param_conf, 'params': params}
def __setstate__(self, obj):
Parameters.__init__(self)
def __impl__(conf, params):
for name in conf:
p = ParameterConfig()
p.ParseFromString(conf[name])
self.__append_config__(p)
for name in params:
shape = self.get_shape(name)
self.set(name, params[name].reshape(shape))
__impl__(**obj)
def __get_parameter_in_gradient_machine__(gradient_machine, name):
"""
......
......@@ -66,9 +66,9 @@ class SGD(ITrainer):
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
self.__gradient_machine__ = gm
self.__gradient_machine__.randParameters()
parameters.append_gradient_machine(gm)
def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册