提交 07b62fac 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #432 from pengli09/fix-model-saving-bug

fix a bug in saving model (invoke the wrong function).
...@@ -214,7 +214,7 @@ def event_handler_plot(event): ...@@ -214,7 +214,7 @@ def event_handler_plot(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
if event.pass_id % 10 == 0: if event.pass_id % 10 == 0:
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
``` ```
### 开始训练 ### 开始训练
......
...@@ -220,7 +220,7 @@ def event_handler_plot(event): ...@@ -220,7 +220,7 @@ def event_handler_plot(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
if event.pass_id % 10 == 0: if event.pass_id % 10 == 0:
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
``` ```
### Start Training ### Start Training
......
...@@ -256,7 +256,7 @@ def event_handler_plot(event): ...@@ -256,7 +256,7 @@ def event_handler_plot(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
if event.pass_id % 10 == 0: if event.pass_id % 10 == 0:
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
``` ```
### 开始训练 ### 开始训练
......
...@@ -262,7 +262,7 @@ def event_handler_plot(event): ...@@ -262,7 +262,7 @@ def event_handler_plot(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
if event.pass_id % 10 == 0: if event.pass_id % 10 == 0:
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
``` ```
### Start Training ### Start Training
......
...@@ -41,7 +41,7 @@ def main(): ...@@ -41,7 +41,7 @@ def main():
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
if event.pass_id % 10 == 0: if event.pass_id % 10 == 0:
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test( result = trainer.test(
reader=paddle.batch(uci_housing.test(), batch_size=2), reader=paddle.batch(uci_housing.test(), batch_size=2),
feeding=feeding) feeding=feeding)
......
...@@ -256,7 +256,7 @@ def event_handler_plot(event): ...@@ -256,7 +256,7 @@ def event_handler_plot(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
...@@ -275,7 +275,7 @@ def event_handler(event): ...@@ -275,7 +275,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
......
...@@ -249,7 +249,7 @@ def event_handler_plot(event): ...@@ -249,7 +249,7 @@ def event_handler_plot(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
...@@ -270,7 +270,7 @@ def event_handler(event): ...@@ -270,7 +270,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
......
...@@ -298,7 +298,7 @@ def event_handler_plot(event): ...@@ -298,7 +298,7 @@ def event_handler_plot(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
...@@ -317,7 +317,7 @@ def event_handler(event): ...@@ -317,7 +317,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
......
...@@ -291,7 +291,7 @@ def event_handler_plot(event): ...@@ -291,7 +291,7 @@ def event_handler_plot(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
...@@ -312,7 +312,7 @@ def event_handler(event): ...@@ -312,7 +312,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
......
...@@ -87,7 +87,7 @@ def main(): ...@@ -87,7 +87,7 @@ def main():
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
......
...@@ -432,7 +432,7 @@ def event_handler(event): ...@@ -432,7 +432,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test( result = trainer.test(
reader=paddle.batch( reader=paddle.batch(
......
...@@ -438,7 +438,7 @@ def event_handler(event): ...@@ -438,7 +438,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test( result = trainer.test(
reader=paddle.batch( reader=paddle.batch(
......
...@@ -474,7 +474,7 @@ def event_handler(event): ...@@ -474,7 +474,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test( result = trainer.test(
reader=paddle.batch( reader=paddle.batch(
......
...@@ -480,7 +480,7 @@ def event_handler(event): ...@@ -480,7 +480,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test( result = trainer.test(
reader=paddle.batch( reader=paddle.batch(
......
...@@ -57,6 +57,10 @@ def main(): ...@@ -57,6 +57,10 @@ def main():
learning_rate_decay_b=50000 * 100, learning_rate_decay_b=50000 * 100,
learning_rate_schedule='discexp') learning_rate_schedule='discexp')
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=momentum_optimizer)
# End batch and end pass event handler # End batch and end pass event handler
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
...@@ -69,7 +73,7 @@ def main(): ...@@ -69,7 +73,7 @@ def main():
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test( result = trainer.test(
reader=paddle.batch( reader=paddle.batch(
...@@ -78,10 +82,6 @@ def main(): ...@@ -78,10 +82,6 @@ def main():
'label': 1}) 'label': 1})
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=momentum_optimizer)
# Save the inference topology to protobuf. # Save the inference topology to protobuf.
inference_topology = paddle.topology.Topology(layers=out) inference_topology = paddle.topology.Topology(layers=out)
with open("inference_topology.pkl", 'wb') as f: with open("inference_topology.pkl", 'wb') as f:
......
...@@ -336,7 +336,7 @@ def event_handler(event): ...@@ -336,7 +336,7 @@ def event_handler(event):
paddle.dataset.imikolov.test(word_dict, N), 32)) paddle.dataset.imikolov.test(word_dict, N), 32))
print "Pass %d, Testing metrics %s" % (event.pass_id, result.metrics) print "Pass %d, Testing metrics %s" % (event.pass_id, result.metrics)
with open("model_%d.tar"%event.pass_id, 'w') as f: with open("model_%d.tar"%event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32), paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
......
...@@ -348,7 +348,7 @@ def event_handler(event): ...@@ -348,7 +348,7 @@ def event_handler(event):
paddle.dataset.imikolov.test(word_dict, N), 32)) paddle.dataset.imikolov.test(word_dict, N), 32))
print "Pass %d, Testing metrics %s" % (event.pass_id, result.metrics) print "Pass %d, Testing metrics %s" % (event.pass_id, result.metrics)
with open("model_%d.tar"%event.pass_id, 'w') as f: with open("model_%d.tar"%event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32), paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
......
...@@ -378,7 +378,7 @@ def event_handler(event): ...@@ -378,7 +378,7 @@ def event_handler(event):
paddle.dataset.imikolov.test(word_dict, N), 32)) paddle.dataset.imikolov.test(word_dict, N), 32))
print "Pass %d, Testing metrics %s" % (event.pass_id, result.metrics) print "Pass %d, Testing metrics %s" % (event.pass_id, result.metrics)
with open("model_%d.tar"%event.pass_id, 'w') as f: with open("model_%d.tar"%event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32), paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
......
...@@ -390,7 +390,7 @@ def event_handler(event): ...@@ -390,7 +390,7 @@ def event_handler(event):
paddle.dataset.imikolov.test(word_dict, N), 32)) paddle.dataset.imikolov.test(word_dict, N), 32))
print "Pass %d, Testing metrics %s" % (event.pass_id, result.metrics) print "Pass %d, Testing metrics %s" % (event.pass_id, result.metrics)
with open("model_%d.tar"%event.pass_id, 'w') as f: with open("model_%d.tar"%event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32), paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
......
...@@ -76,6 +76,13 @@ def main(): ...@@ -76,6 +76,13 @@ def main():
bias_attr=paddle.attr.Param(learning_rate=2), bias_attr=paddle.attr.Param(learning_rate=2),
act=paddle.activation.Softmax()) act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=predictword, label=nextword)
parameters = paddle.parameters.create(cost)
adagrad = paddle.optimizer.AdaGrad(
learning_rate=3e-3,
regularization=paddle.optimizer.L2Regularization(8e-4))
trainer = paddle.trainer.SGD(cost, parameters, adagrad)
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
...@@ -88,14 +95,8 @@ def main(): ...@@ -88,14 +95,8 @@ def main():
print "Pass %d, Testing metrics %s" % (event.pass_id, print "Pass %d, Testing metrics %s" % (event.pass_id,
result.metrics) result.metrics)
with open("model_%d.tar" % event.pass_id, 'w') as f: with open("model_%d.tar" % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
cost = paddle.layer.classification_cost(input=predictword, label=nextword)
parameters = paddle.parameters.create(cost)
adagrad = paddle.optimizer.AdaGrad(
learning_rate=3e-3,
regularization=paddle.optimizer.L2Regularization(8e-4))
trainer = paddle.trainer.SGD(cost, parameters, adagrad)
trainer.train( trainer.train(
paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32), paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
num_passes=100, num_passes=100,
......
...@@ -291,7 +291,7 @@ Paddle中提供了一系列优化算法的API,这里使用Adam优化算法。 ...@@ -291,7 +291,7 @@ Paddle中提供了一系列优化算法的API,这里使用Adam优化算法。
sys.stdout.flush() sys.stdout.flush()
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
with open('./params_pass_%d.tar' % event.pass_id, 'w') as f: with open('./params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -310,7 +310,7 @@ def event_handler(event): ...@@ -310,7 +310,7 @@ def event_handler(event):
sys.stdout.flush() sys.stdout.flush()
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
with open('./params_pass_%d.tar' % event.pass_id, 'w') as f: with open('./params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -333,7 +333,7 @@ Paddle中提供了一系列优化算法的API,这里使用Adam优化算法。 ...@@ -333,7 +333,7 @@ Paddle中提供了一系列优化算法的API,这里使用Adam优化算法。
sys.stdout.flush() sys.stdout.flush()
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
with open('./params_pass_%d.tar' % event.pass_id, 'w') as f: with open('./params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -352,7 +352,7 @@ def event_handler(event): ...@@ -352,7 +352,7 @@ def event_handler(event):
sys.stdout.flush() sys.stdout.flush()
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
with open('./params_pass_%d.tar' % event.pass_id, 'w') as f: with open('./params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -129,6 +129,10 @@ if __name__ == '__main__': ...@@ -129,6 +129,10 @@ if __name__ == '__main__':
regularization=paddle.optimizer.L2Regularization(rate=8e-4), regularization=paddle.optimizer.L2Regularization(rate=8e-4),
model_average=paddle.optimizer.ModelAverage(average_window=0.5)) model_average=paddle.optimizer.ModelAverage(average_window=0.5))
# create trainer
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=adam_optimizer)
# End batch and end pass event handler # End batch and end pass event handler
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
...@@ -140,14 +144,11 @@ if __name__ == '__main__': ...@@ -140,14 +144,11 @@ if __name__ == '__main__':
sys.stdout.flush() sys.stdout.flush()
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
with open('./params_pass_%d.tar' % event.pass_id, 'w') as f: with open('./params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# create trainer
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=adam_optimizer)
# Save the inference topology to protobuf. # Save the inference topology to protobuf.
inference_topology = paddle.topology.Topology(layers=output) inference_topology = paddle.topology.Topology(layers=output)
with open("./inference_topology.pkl", 'wb') as f: with open("./inference_topology.pkl", 'wb') as f:
......
...@@ -448,7 +448,7 @@ def event_handler(event): ...@@ -448,7 +448,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=reader, feeding=feeding) result = trainer.test(reader=reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -466,7 +466,7 @@ def event_handler(event): ...@@ -466,7 +466,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=reader, feeding=feeding) result = trainer.test(reader=reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -490,7 +490,7 @@ def event_handler(event): ...@@ -490,7 +490,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=reader, feeding=feeding) result = trainer.test(reader=reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -508,7 +508,7 @@ def event_handler(event): ...@@ -508,7 +508,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=reader, feeding=feeding) result = trainer.test(reader=reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -188,7 +188,7 @@ def main(): ...@@ -188,7 +188,7 @@ def main():
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f: with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
......
...@@ -5,9 +5,9 @@ import paddle.v2 as paddle ...@@ -5,9 +5,9 @@ import paddle.v2 as paddle
with_gpu = os.getenv('WITH_GPU', '0') != '0' with_gpu = os.getenv('WITH_GPU', '0') != '0'
def save_model(parameters, save_path): def save_model(trainer, parameters, save_path):
with open(save_path, 'w') as f: with open(save_path, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
def seq_to_seq_net(source_dict_dim, def seq_to_seq_net(source_dict_dim,
...@@ -175,12 +175,12 @@ def main(): ...@@ -175,12 +175,12 @@ def main():
if not event.batch_id % 10: if not event.batch_id % 10:
save_path = 'params_pass_%05d_batch_%05d.tar' % ( save_path = 'params_pass_%05d_batch_%05d.tar' % (
event.pass_id, event.batch_id) event.pass_id, event.batch_id)
save_model(parameters, save_path) save_model(trainer, parameters, save_path)
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
save_path = 'params_pass_%05d.tar' % (event.pass_id) save_path = 'params_pass_%05d.tar' % (event.pass_id)
save_model(parameters, save_path) save_model(trainer, parameters, save_path)
# start to train # start to train
trainer.train( trainer.train(
......
...@@ -61,12 +61,12 @@ PaddlePaddle stores the topology and parameter separately. ...@@ -61,12 +61,12 @@ PaddlePaddle stores the topology and parameter separately.
inference_topology.serialize_for_inference(f) inference_topology.serialize_for_inference(f)
``` ```
2. To save a parameter, we need to invoke `to_tar` method in Parameter 2. To save a parameter, we need to invoke `save_parameter_to_tar` method of
class. `trainer`.
```python ```python
with open('param.tar', 'w') as f: with open('param.tar', 'w') as f:
params.to_tar(f) trainer.save_parameter_to_tar(f)
``` ```
After serializing the parameter and topology into two files, we could After serializing the parameter and topology into two files, we could
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册