From 025e3e94d2b216cc278de103cbef27b851274bf5 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 20 Dec 2016 23:00:34 +0800 Subject: [PATCH] Add GradientMachine::start/finish to API --- demo/mnist/api_train.py | 7 ++++++- paddle/api/GradientMachine.cpp | 4 ++++ paddle/api/PaddleAPI.h | 9 +++++++++ paddle/api/ParameterUpdater.cpp | 2 ++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index 5d4ef90f10d..b061cfb2b8f 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -30,7 +30,12 @@ def main(): updater = api.ParameterUpdater.createLocalUpdater(opt_config) assert isinstance(updater, api.ParameterUpdater) updater.init(m) - updater.startPass() + m.start() + + for _ in xrange(100): + updater.startPass() + + m.finish() if __name__ == '__main__': diff --git a/paddle/api/GradientMachine.cpp b/paddle/api/GradientMachine.cpp index 297eaa19bb9..2cece210979 100644 --- a/paddle/api/GradientMachine.cpp +++ b/paddle/api/GradientMachine.cpp @@ -64,6 +64,10 @@ GradientMachine* GradientMachine::createByModelConfig( return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types); } +void GradientMachine::start() { m->machine->start(); } + +void GradientMachine::finish() { m->machine->finish(); } + void GradientMachine::forward(const Arguments& inArgs, Arguments* outArgs, PassType passType) { diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index bd413eb1e9d..c074325091d 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -716,6 +716,13 @@ public: GradientMatchineCreateMode mode = CREATE_MODE_NORMAL, const std::vector& parameterTypes = defaultParamTypes); + /** + * @brief finish + */ + void finish(); + + void start(); + /** * The forward stage of GradientMachine. * @@ -790,6 +797,8 @@ public: void startPass(); + void finishPass(); + private: ParameterUpdaterPrivate* m; }; diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index af5b746a7cd..3b626c05071 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -33,3 +33,5 @@ void ParameterUpdater::init(const GradientMachine &gm) { } void ParameterUpdater::startPass() { m->updater->startPass(); } + +void ParameterUpdater::finishPass() {} -- GitLab