提交 e4156f66 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1018 from jacquesqiao/remote-updater

[in progress]add RemoteUpdater in api for cluster training
......@@ -178,6 +178,7 @@ namespace std {
%newobject ParameterOptimizer::create;
%newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater;
%newobject ParameterUpdater::createRemoteUpdater;
%feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide
......
......@@ -803,6 +803,8 @@ private:
public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
int passCount);
~ParameterUpdater();
/**
......
......@@ -15,15 +15,25 @@ limitations under the License. */
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/trainer/RemoteParameterUpdater.h"
#include "paddle/trainer/ThreadParameterUpdater.h"
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}
ParameterUpdater *ParameterUpdater::createLocalUpdater(
OptimizationConfig *config) {
auto param = new ParameterUpdater();
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
return param;
auto updater = new ParameterUpdater();
updater->m->updater.reset(
new paddle::SgdThreadUpdater(config->m->getConfig()));
return updater;
}
ParameterUpdater *ParameterUpdater::createRemoteUpdater(
OptimizationConfig *config, int passCount) {
auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::RemoteParameterUpdater(
config->m->getConfig(), passCount, nullptr));
return updater;
}
ParameterUpdater::~ParameterUpdater() { delete m; }
......
......@@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
public:
RemoteParameterUpdater(
const OptimizationConfig& config,
int expectedPpassCount,
int expectedPassCount,
std::unique_ptr<ParameterUpdater>&& localUpdater = nullptr);
~RemoteParameterUpdater() {
if (controllerThread_) {
......@@ -146,7 +146,7 @@ protected:
BatchStatus batchStatus_;
/// controller thread for sync-sgd
std::unique_ptr<std::thread> controllerThread_;
/// passed alread finished
/// passed already finished
int64_t passCount_;
/// expected passes to finished
int64_t expectedPassCount_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册