diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index c1439bd526d8e6f2d4af97463e65a1240ef73e66..ce75d79bebe3a86563021f98833316e0ebee77cd 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -157,6 +157,7 @@ def main(): updater.finishBatch(cost) # testing stage. use test data set to test current network. + updater.apply() test_evaluator.start() test_data_generator = input_order_converter(read_from_mnist(test_file)) for data_batch in generator_to_batch(test_data_generator, 128): @@ -167,6 +168,18 @@ def main(): # print error rate for test data set print 'Pass', pass_id, ' test evaluator: ', test_evaluator test_evaluator.finish() + updater.restore() + + updater.catchUpWith() + params = m.getParameters() + for each_param in params: + assert isinstance(each_param, api.Parameter) + value = each_param.getBuf(api.PARAMETER_VALUE) + value = value.toNumpyArrayInplace() + + # Here, we could save parameter to every where you want + print each_param.getName(), value + updater.finishPass() m.finish() diff --git a/demo/mnist/simple_mnist_network.py b/demo/mnist/simple_mnist_network.py index 41f4e51657d35bf72401f4076c53b7b3bf7d5b52..f5d1ea169e784e832d95c0d9d42bd37649bc164c 100644 --- a/demo/mnist/simple_mnist_network.py +++ b/demo/mnist/simple_mnist_network.py @@ -1,6 +1,11 @@ from paddle.trainer_config_helpers import * -settings(learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000) +settings( + learning_rate=1e-4, + learning_method=AdamOptimizer(), + batch_size=1000, + model_average=ModelAverage(average_window=0.5), + regularization=L2Regularization(rate=0.5)) imgs = data_layer(name='pixel', size=784) diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 413c38514646211befc18a83a2d7ce70644b5183..d94fd1e52ed0367d9ee5276b1a2480260d93bce1 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -809,6 +809,12 @@ public: void update(Parameter* param); + void restore(); + + void apply(); + + void catchUpWith(); + private: ParameterUpdaterPrivate* m; }; diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 91c839276280804bc9decc87c245728e0893de51..7cd8ed7e3907489a60f37090df6f51492def2612 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -48,3 +48,9 @@ void ParameterUpdater::update(Parameter *param) { auto paddleParam = param->m->getPtr(); m->updater->update(paddleParam); } + +void ParameterUpdater::restore() { m->updater->restore(); } + +void ParameterUpdater::apply() { m->updater->apply(); } + +void ParameterUpdater::catchUpWith() { m->updater->catchUpWith(); }