提交 025e3e94 编写于 作者: Y Yu Yang

Add GradientMachine::start/finish to API

上级 ad6cb60d
...@@ -30,8 +30,13 @@ def main(): ...@@ -30,8 +30,13 @@ def main():
updater = api.ParameterUpdater.createLocalUpdater(opt_config) updater = api.ParameterUpdater.createLocalUpdater(opt_config)
assert isinstance(updater, api.ParameterUpdater) assert isinstance(updater, api.ParameterUpdater)
updater.init(m) updater.init(m)
m.start()
for _ in xrange(100):
updater.startPass() updater.startPass()
m.finish()
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -64,6 +64,10 @@ GradientMachine* GradientMachine::createByModelConfig( ...@@ -64,6 +64,10 @@ GradientMachine* GradientMachine::createByModelConfig(
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types); return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
} }
void GradientMachine::start() { m->machine->start(); }
void GradientMachine::finish() { m->machine->finish(); }
void GradientMachine::forward(const Arguments& inArgs, void GradientMachine::forward(const Arguments& inArgs,
Arguments* outArgs, Arguments* outArgs,
PassType passType) { PassType passType) {
......
...@@ -716,6 +716,13 @@ public: ...@@ -716,6 +716,13 @@ public:
GradientMatchineCreateMode mode = CREATE_MODE_NORMAL, GradientMatchineCreateMode mode = CREATE_MODE_NORMAL,
const std::vector<int>& parameterTypes = defaultParamTypes); const std::vector<int>& parameterTypes = defaultParamTypes);
/**
* @brief finish
*/
void finish();
void start();
/** /**
* The forward stage of GradientMachine. * The forward stage of GradientMachine.
* *
...@@ -790,6 +797,8 @@ public: ...@@ -790,6 +797,8 @@ public:
void startPass(); void startPass();
void finishPass();
private: private:
ParameterUpdaterPrivate* m; ParameterUpdaterPrivate* m;
}; };
......
...@@ -33,3 +33,5 @@ void ParameterUpdater::init(const GradientMachine &gm) { ...@@ -33,3 +33,5 @@ void ParameterUpdater::init(const GradientMachine &gm) {
} }
void ParameterUpdater::startPass() { m->updater->startPass(); } void ParameterUpdater::startPass() { m->updater->startPass(); }
void ParameterUpdater::finishPass() {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册