diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index b061cfb2b8f1faa7979093af19288475b99c57ff..59043ce6c42085fe1988ceba8358b97cb01f74c7 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -35,6 +35,8 @@ def main(): for _ in xrange(100): updater.startPass() + updater.finishPass() + m.finish() diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 3b626c05071393740d198d9785879ea35d89647d..4edec78b4a3d42f226a155663a03f84b1840e03b 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -34,4 +34,4 @@ void ParameterUpdater::init(const GradientMachine &gm) { void ParameterUpdater::startPass() { m->updater->startPass(); } -void ParameterUpdater::finishPass() {} +void ParameterUpdater::finishPass() { m->updater->finishPass(); }