diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index 5d4ef90f10d3d4faeb43f61a8c20862c2f8dbbd1..b061cfb2b8f1faa7979093af19288475b99c57ff 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -30,7 +30,12 @@ def main(): updater = api.ParameterUpdater.createLocalUpdater(opt_config) assert isinstance(updater, api.ParameterUpdater) updater.init(m) - updater.startPass() + m.start() + + for _ in xrange(100): + updater.startPass() + + m.finish() if __name__ == '__main__': diff --git a/paddle/api/GradientMachine.cpp b/paddle/api/GradientMachine.cpp index 297eaa19bb9981c7f07c90763d76494b7910af93..2cece2109795a986966d2decfdde27b2759e51cc 100644 --- a/paddle/api/GradientMachine.cpp +++ b/paddle/api/GradientMachine.cpp @@ -64,6 +64,10 @@ GradientMachine* GradientMachine::createByModelConfig( return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types); } +void GradientMachine::start() { m->machine->start(); } + +void GradientMachine::finish() { m->machine->finish(); } + void GradientMachine::forward(const Arguments& inArgs, Arguments* outArgs, PassType passType) { diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index bd413eb1e9d9a945965bdf6767da82b4d631bbb5..c074325091dee95ab4a07e2f06e2abeb7a5c76bc 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -716,6 +716,13 @@ public: GradientMatchineCreateMode mode = CREATE_MODE_NORMAL, const std::vector& parameterTypes = defaultParamTypes); + /** + * @brief finish + */ + void finish(); + + void start(); + /** * The forward stage of GradientMachine. * @@ -790,6 +797,8 @@ public: void startPass(); + void finishPass(); + private: ParameterUpdaterPrivate* m; }; diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index af5b746a7cd0825dcb6839b64e464228713efbd5..3b626c05071393740d198d9785879ea35d89647d 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -33,3 +33,5 @@ void ParameterUpdater::init(const GradientMachine &gm) { } void ParameterUpdater::startPass() { m->updater->startPass(); } + +void ParameterUpdater::finishPass() {}