提交 025e3e94 编写于 作者: Y Yu Yang

Add GradientMachine::start/finish to API

上级 ad6cb60d
......@@ -30,8 +30,13 @@ def main():
updater = api.ParameterUpdater.createLocalUpdater(opt_config)
assert isinstance(updater, api.ParameterUpdater)
updater.init(m)
m.start()
for _ in xrange(100):
updater.startPass()
m.finish()
if __name__ == '__main__':
main()
......@@ -64,6 +64,10 @@ GradientMachine* GradientMachine::createByModelConfig(
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
}
void GradientMachine::start() { m->machine->start(); }
void GradientMachine::finish() { m->machine->finish(); }
void GradientMachine::forward(const Arguments& inArgs,
Arguments* outArgs,
PassType passType) {
......
......@@ -716,6 +716,13 @@ public:
GradientMatchineCreateMode mode = CREATE_MODE_NORMAL,
const std::vector<int>& parameterTypes = defaultParamTypes);
/**
* @brief finish
*/
void finish();
void start();
/**
* The forward stage of GradientMachine.
*
......@@ -790,6 +797,8 @@ public:
void startPass();
void finishPass();
private:
ParameterUpdaterPrivate* m;
};
......
......@@ -33,3 +33,5 @@ void ParameterUpdater::init(const GradientMachine &gm) {
}
void ParameterUpdater::startPass() { m->updater->startPass(); }
void ParameterUpdater::finishPass() {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册