提交 e4156f66 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1018 from jacquesqiao/remote-updater

[in progress]add RemoteUpdater in api for cluster training
...@@ -178,6 +178,7 @@ namespace std { ...@@ -178,6 +178,7 @@ namespace std {
%newobject ParameterOptimizer::create; %newobject ParameterOptimizer::create;
%newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater; %newobject ParameterUpdater::createLocalUpdater;
%newobject ParameterUpdater::createRemoteUpdater;
%feature("director") UpdateCallback; %feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide %feature("autodoc", 1); // To generate method stub, for code hint in ide
......
...@@ -803,6 +803,8 @@ private: ...@@ -803,6 +803,8 @@ private:
public: public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config); static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
int passCount);
~ParameterUpdater(); ~ParameterUpdater();
/** /**
......
...@@ -15,15 +15,25 @@ limitations under the License. */ ...@@ -15,15 +15,25 @@ limitations under the License. */
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "PaddleAPIPrivate.h" #include "PaddleAPIPrivate.h"
#include "paddle/trainer/RemoteParameterUpdater.h"
#include "paddle/trainer/ThreadParameterUpdater.h" #include "paddle/trainer/ThreadParameterUpdater.h"
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {} ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}
ParameterUpdater *ParameterUpdater::createLocalUpdater( ParameterUpdater *ParameterUpdater::createLocalUpdater(
OptimizationConfig *config) { OptimizationConfig *config) {
auto param = new ParameterUpdater(); auto updater = new ParameterUpdater();
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig())); updater->m->updater.reset(
return param; new paddle::SgdThreadUpdater(config->m->getConfig()));
return updater;
}
ParameterUpdater *ParameterUpdater::createRemoteUpdater(
OptimizationConfig *config, int passCount) {
auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::RemoteParameterUpdater(
config->m->getConfig(), passCount, nullptr));
return updater;
} }
ParameterUpdater::~ParameterUpdater() { delete m; } ParameterUpdater::~ParameterUpdater() { delete m; }
......
...@@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater { ...@@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
public: public:
RemoteParameterUpdater( RemoteParameterUpdater(
const OptimizationConfig& config, const OptimizationConfig& config,
int expectedPpassCount, int expectedPassCount,
std::unique_ptr<ParameterUpdater>&& localUpdater = nullptr); std::unique_ptr<ParameterUpdater>&& localUpdater = nullptr);
~RemoteParameterUpdater() { ~RemoteParameterUpdater() {
if (controllerThread_) { if (controllerThread_) {
...@@ -146,7 +146,7 @@ protected: ...@@ -146,7 +146,7 @@ protected:
BatchStatus batchStatus_; BatchStatus batchStatus_;
/// controller thread for sync-sgd /// controller thread for sync-sgd
std::unique_ptr<std::thread> controllerThread_; std::unique_ptr<std::thread> controllerThread_;
/// passed alread finished /// passed already finished
int64_t passCount_; int64_t passCount_;
/// expected passes to finished /// expected passes to finished
int64_t expectedPassCount_; int64_t expectedPassCount_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册