From 680dd92bde2e4d6c2173f47d6da3263d827050e8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 22 Dec 2016 11:31:31 +0800 Subject: [PATCH] Add AverageOptimizer, Add save parameter --- demo/mnist/api_train.py | 13 +++++++++++++ demo/mnist/simple_mnist_network.py | 7 ++++++- paddle/api/PaddleAPI.h | 6 ++++++ paddle/api/ParameterUpdater.cpp | 6 ++++++ 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index c1439bd526..ce75d79beb 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -157,6 +157,7 @@ def main(): updater.finishBatch(cost) # testing stage. use test data set to test current network. + updater.apply() test_evaluator.start() test_data_generator = input_order_converter(read_from_mnist(test_file)) for data_batch in generator_to_batch(test_data_generator, 128): @@ -167,6 +168,18 @@ def main(): # print error rate for test data set print 'Pass', pass_id, ' test evaluator: ', test_evaluator test_evaluator.finish() + updater.restore() + + updater.catchUpWith() + params = m.getParameters() + for each_param in params: + assert isinstance(each_param, api.Parameter) + value = each_param.getBuf(api.PARAMETER_VALUE) + value = value.toNumpyArrayInplace() + + # Here, we could save parameter to every where you want + print each_param.getName(), value + updater.finishPass() m.finish() diff --git a/demo/mnist/simple_mnist_network.py b/demo/mnist/simple_mnist_network.py index 41f4e51657..f5d1ea169e 100644 --- a/demo/mnist/simple_mnist_network.py +++ b/demo/mnist/simple_mnist_network.py @@ -1,6 +1,11 @@ from paddle.trainer_config_helpers import * -settings(learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000) +settings( + learning_rate=1e-4, + learning_method=AdamOptimizer(), + batch_size=1000, + model_average=ModelAverage(average_window=0.5), + regularization=L2Regularization(rate=0.5)) imgs = data_layer(name='pixel', size=784) diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 413c385146..d94fd1e52e 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -809,6 +809,12 @@ public: void update(Parameter* param); + void restore(); + + void apply(); + + void catchUpWith(); + private: ParameterUpdaterPrivate* m; }; diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 91c8392762..7cd8ed7e39 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -48,3 +48,9 @@ void ParameterUpdater::update(Parameter *param) { auto paddleParam = param->m->getPtr(); m->updater->update(paddleParam); } + +void ParameterUpdater::restore() { m->updater->restore(); } + +void ParameterUpdater::apply() { m->updater->apply(); } + +void ParameterUpdater::catchUpWith() { m->updater->catchUpWith(); } -- GitLab