Commit 680dd92b authored by Yu Yang

Add AverageOptimizer, Add save parameter

Parent 06dc66b3
@@ -157,6 +157,7 @@ def main():
             updater.finishBatch(cost)
 
         # testing stage. use test data set to test current network.
+        updater.apply()
         test_evaluator.start()
         test_data_generator = input_order_converter(read_from_mnist(test_file))
         for data_batch in generator_to_batch(test_data_generator, 128):
@@ -167,6 +168,18 @@ def main():
         # print error rate for test data set
         print 'Pass', pass_id, ' test evaluator: ', test_evaluator
         test_evaluator.finish()
+
+        updater.restore()
+        updater.catchUpWith()
+        params = m.getParameters()
+        for each_param in params:
+            assert isinstance(each_param, api.Parameter)
+            value = each_param.getBuf(api.PARAMETER_VALUE)
+            value = value.toNumpyArrayInplace()
+
+            # Here we could save the parameter to anywhere you want
+            print each_param.getName(), value
+
         updater.finishPass()
 
     m.finish()
......
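In this flow, apply() makes the averaged parameter values the active ones for testing, restore() switches back to the live values before training resumes, and catchUpWith() brings lazily applied updates (such as regularization) up to date before the values are read out. A minimal sketch of the evaluation half of that contract, assuming the demo's updater object and a hypothetical run_test callable standing in for the test_evaluator loop:

def evaluate_with_averaged_params(updater, run_test):
    # Sketch only: `updater` is the demo's api.ParameterUpdater; `run_test`
    # is a hypothetical callable wrapping the test_evaluator loop above.
    updater.apply()        # averaged values become the active parameters
    try:
        run_test()         # evaluate with the smoothed weights
    finally:
        updater.restore()  # swap the live values back, even if testing fails

The try/finally guard is not in the demo itself, but it keeps a failed test run from leaving the averaged values in place.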
 from paddle.trainer_config_helpers import *
 
-settings(learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000)
+settings(
+    learning_rate=1e-4,
+    learning_method=AdamOptimizer(),
+    batch_size=1000,
+    model_average=ModelAverage(average_window=0.5),
+    regularization=L2Regularization(rate=0.5))
 
 imgs = data_layer(name='pixel', size=784)
......
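The new model_average=ModelAverage(average_window=0.5) setting is what gives apply()/restore() something to do: the trainer keeps a running average of recent parameter values alongside the live ones. As a rough illustration only, not Paddle's implementation (whose window size is tied to the amount of data seen per pass), a toy averager could look like:

import numpy as np

class ToyAverager(object):
    """Toy stand-in for AverageOptimizer: running mean plus apply/restore."""

    def __init__(self, shape):
        self.avg = np.zeros(shape)
        self.count = 0
        self._backup = None

    def update(self, param):
        # incremental mean over all parameter values seen so far
        self.count += 1
        self.avg += (param - self.avg) / self.count

    def apply(self, param):
        # expose the averaged values (e.g. for testing)
        self._backup = param.copy()
        param[...] = self.avg

    def restore(self, param):
        # put the live values back before training continues
        param[...] = self._backup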
@@ -809,6 +809,12 @@ public:
   void update(Parameter* param);
 
+  void restore();
+
+  void apply();
+
+  void catchUpWith();
+
 private:
   ParameterUpdaterPrivate* m;
 };
......
@@ -48,3 +48,9 @@ void ParameterUpdater::update(Parameter *param) {
   auto paddleParam = param->m->getPtr();
   m->updater->update(paddleParam);
 }
+
+void ParameterUpdater::restore() { m->updater->restore(); }
+
+void ParameterUpdater::apply() { m->updater->apply(); }
+
+void ParameterUpdater::catchUpWith() { m->updater->catchUpWith(); }