Commit 99dc6064 authored by Q qiaolongfei

New parameter updater using the Paddle pserver cclient written in Go

Parent 29a0bc83
...@@ -127,6 +127,7 @@ endif(WITH_GPU) ...@@ -127,6 +127,7 @@ endif(WITH_GPU)
add_subdirectory(proto) add_subdirectory(proto)
add_subdirectory(paddle) add_subdirectory(paddle)
add_subdirectory(python) add_subdirectory(python)
add_subdirectory(go/pserver/cclient)
if(WITH_DOC) if(WITH_DOC)
add_subdirectory(doc) add_subdirectory(doc)
......
# Design Doc: Remote Parameter Updater for Cluster Train
For an overview of distributed training, please refer to the [distributed training design doc](README.md). In this design doc, we discuss the parameter updater that uses the parameter server cclient ([The Client Library of Parameter Server Design Doc](pserver_client.md)) to manage and update parameters.
## Parameter Updater
The parameter updater is used by the trainer to manage and update parameters. There are mainly two kinds of parameter updaters: local and remote. Since this design is for cluster training, we only discuss the remote parameter updater here.
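A trainer drives an updater through a small lifecycle: `init`, then `startPass`/`startBatch`/`finishBatch`/`finishPass` per pass and batch. The sketch below is a minimal, self-contained illustration of that call sequence; the class is a stand-in that only mirrors the method names of the `ParameterUpdater` interface, not PaddlePaddle code.

```python
# A stand-in updater that mirrors the ParameterUpdater lifecycle methods.
class SketchUpdater(object):
    def init(self, parameters):
        # receive the list of parameters to manage
        self.parameters = parameters

    def startPass(self):
        pass  # per-pass bookkeeping

    def startBatch(self, batch_size):
        return "PASS_TRAIN"  # tells the trainer how to run this batch

    def finishBatch(self, cost):
        pass  # a remote updater would send gradients and fetch new values here

    def finishPass(self):
        return True


# Simplified view of how a trainer drives the updater.
updater = SketchUpdater()
updater.init(["w", "b"])
for pass_id in range(2):
    updater.startPass()
    for batch_id in range(3):
        updater.startBatch(batch_size=2)
        updater.finishBatch(cost=0.0)  # cost comes from forward/backward in real training
    updater.finishPass()
```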
### Remote Parameter Updater
The remote parameter updater manages parameters through a remote parameter server, using the client that communicates with the pserver ([The Client Library of Parameter Server Design Doc](pserver_client.md)).
In the PaddlePaddle Python V2 API, the trainer is implemented in Python; it holds an instance of the parameter updater and calls its functions directly. In this design, we also expose the API of RemoteParameterUpdater to Python through SWIG.
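For example, with the SWIG binding added in this change, the Python layer reaches the new remote updater roughly as sketched below. This mirrors the `optimizer.py` change later in this commit; `opt_conf` and `gradient_machine` are placeholders for objects the v2 API builds internally, and a running pserver is assumed at the given address.

```python
# Sketch only: the v2 Optimizer/SGD classes normally make these calls when
# SGD(..., is_local=False, pserver_spec="localhost:3000") is used.
import py_paddle.swig_paddle as swig_api

updater = swig_api.ParameterUpdater.createNewRemoteUpdater(
    opt_conf, "localhost:3000")  # opt_conf: OptimizationConfig placeholder
updater.init(gradient_machine)   # gradient_machine: placeholder built by the trainer
```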
#### Sparse Remote Parameter Updater
Since we only implement dense parameter management for now, the mechanism for sparse parameters will be discussed in the next stage.
### Interface Design
TBD
...@@ -17,7 +17,7 @@ function(GO_LIBRARY NAME BUILD_TYPE) ...@@ -17,7 +17,7 @@ function(GO_LIBRARY NAME BUILD_TYPE)
endif() endif()
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
file(RELATIVE_PATH rel ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) file(RELATIVE_PATH rel ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
# find Paddle directory. # find Paddle directory.
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
...@@ -32,12 +32,14 @@ function(GO_LIBRARY NAME BUILD_TYPE) ...@@ -32,12 +32,14 @@ function(GO_LIBRARY NAME BUILD_TYPE)
# will use the local changes in Paddle rather than checkout Paddle # will use the local changes in Paddle rather than checkout Paddle
# in github. # in github.
add_custom_target(copyPaddle add_custom_target(copyPaddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
add_dependencies(goGet copyPaddle) add_dependencies(goGet copyPaddle)
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" -gcflags=-shared -asmflags=-shared -installsuffix=_shared -a
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE} ${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
......
...@@ -9,5 +9,15 @@ project(cxx_go C Go) ...@@ -9,5 +9,15 @@ project(cxx_go C Go)
include(golang) include(golang)
include(flags) include(flags)
go_library(client STATIC) go_library(paddle_pserver_cclient STATIC)
if(PROJ_ROOT)
add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/trainer/libpaddle_pserver_cclient.a
COMMAND cp ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.h ${PROJ_ROOT}/paddle/trainer/
COMMAND cp ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a ${PROJ_ROOT}/paddle/trainer/
WORKING_DIRECTORY ${PROJ_ROOT}/paddle
DEPENDS paddle_pserver_cclient)
add_custom_target(paddle_pserver_cclient_lib ALL DEPENDS ${PROJ_ROOT}/paddle/trainer/libpaddle_pserver_cclient.a)
endif(PROJ_ROOT)
add_subdirectory(test) add_subdirectory(test)
cmake_minimum_required(VERSION 3.0) cmake_minimum_required(VERSION 3.0)
include_directories(${CMAKE_BINARY_DIR})
add_executable(main main.c) add_executable(main main.c)
add_dependencies(main client) add_dependencies(main paddle_pserver_cclient)
if(APPLE) if(APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
endif() endif()
target_link_libraries(main ${CMAKE_BINARY_DIR}/libclient.a)
if(PROJ_ROOT)
include_directories(${CMAKE_BINARY_DIR}/go/pserver/cclient/)
target_link_libraries(main ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a pthread)
else(PROJ_ROOT)
include_directories(${CMAKE_BINARY_DIR})
target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
endif(PROJ_ROOT)
#include <stdio.h> #include <stdio.h>
#include "libclient.h" #include "libpaddle_pserver_cclient.h"
void fail() { void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this // TODO(helin): fix: gtest using cmake is not working, using this
...@@ -14,10 +14,11 @@ int main() { ...@@ -14,10 +14,11 @@ int main() {
client c = paddle_new_pserver_client(addr, 1); client c = paddle_new_pserver_client(addr, 1);
retry: retry:
if (paddle_begin_init_params(c)) { if (paddle_begin_init_params(c)) {
paddle_parameter param; paddle_parameter param;
char name_a[] = "param_a"; char name_a[] = "param_a";
char name_b[] = "param_b"; char name_b[] = "param_b";
unsigned char content[] = {0x00, 0x11, 0x22}; unsigned char content[] = {0x00, 0x00, 0x00};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a; param.name = name_a;
param.content = content; param.content = content;
...@@ -32,6 +33,7 @@ retry: ...@@ -32,6 +33,7 @@ retry:
if (paddle_init_param(c, param, NULL, 0) != 0) { if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry; goto retry;
} }
if (paddle_finish_init_params(c) != 0) { if (paddle_finish_init_params(c) != 0) {
goto retry; goto retry;
} }
...@@ -41,30 +43,31 @@ retry: ...@@ -41,30 +43,31 @@ retry:
unsigned char content[] = {0x00, 0x11, 0x22}; unsigned char content[] = {0x00, 0x11, 0x22};
paddle_gradient grads[2] = { paddle_gradient grads[2] = {
{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3}, {"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3},
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; {"param_b", PADDLE_ELEMENT_TYPE_INT32, content, 3}};
if (!paddle_send_grads(c, grads, 2)) { if (paddle_send_grads(c, grads, 2) != 0) {
fail(); fail();
} }
paddle_parameter* params[2] = {NULL, NULL}; paddle_parameter* params[2] = {NULL, NULL};
char* names[] = {"param_a", "param_b"}; char* names[] = {"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) { if (paddle_get_params(c, names, params, 2) != 0) {
fail(); fail();
} }
// get parameters again by reusing the allocated parameter buffers. // get parameters again by reusing the allocated parameter buffers.
if (!paddle_get_params(c, names, params, 2)) { if (paddle_get_params(c, names, params, 2) != 0) {
fail(); fail();
} }
paddle_release_param(params[0]); paddle_release_param(params[0]);
paddle_release_param(params[1]); paddle_release_param(params[1]);
if (!paddle_save_model(c, "/tmp/")) { if (paddle_save_model(c, "/tmp/") != 0) {
fail(); fail();
} }
printf("test success!\n");
return 0; return 0;
} }
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing


def main():
    # init
    paddle.init(use_gpu=False, trainer_count=1, trainer_id=1)

    # network config
    x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
    y_predict = paddle.layer.fc(input=x,
                                param_attr=paddle.attr.Param(name='w'),
                                size=1,
                                act=paddle.activation.Linear(),
                                bias_attr=paddle.attr.Param(name='b'))
    y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
    cost = paddle.layer.mse_cost(input=y_predict, label=y)

    # create parameters
    parameters = paddle.parameters.create(cost)

    # create optimizer
    optimizer = paddle.optimizer.Momentum(momentum=0)

    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=optimizer,
                                 is_local=False,
                                 pserver_spec="localhost:3000")

    # event_handler to print training and testing info
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print "Pass %d, Batch %d, Cost %f" % (
                    event.pass_id, event.batch_id, event.cost)

        if isinstance(event, paddle.event.EndPass):
            if (event.pass_id + 1) % 10 == 0:
                result = trainer.test(
                    reader=paddle.batch(
                        uci_housing.test(), batch_size=2),
                    feeding={'x': 0,
                             'y': 1})
                print "Test %d, %.2f" % (event.pass_id, result.cost)

    # training
    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                uci_housing.train(), buf_size=500),
            batch_size=2),
        feeding={'x': 0,
                 'y': 1},
        event_handler=event_handler,
        num_passes=30)


if __name__ == '__main__':
    main()
...@@ -16,7 +16,7 @@ set(API_HEADER ...@@ -16,7 +16,7 @@ set(API_HEADER
Internal.h) Internal.h)
add_library(paddle_api STATIC ${API_SOURCES}) add_library(paddle_api STATIC ${API_SOURCES})
add_dependencies(paddle_api gen_proto_cpp) add_dependencies(paddle_api gen_proto_cpp paddle_pserver_cclient_lib)
INCLUDE(${SWIG_USE_FILE}) INCLUDE(${SWIG_USE_FILE})
INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle) INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle)
...@@ -44,7 +44,7 @@ SET(SWIG_MODULE_swig_paddle_EXTRA_DEPS ...@@ -44,7 +44,7 @@ SET(SWIG_MODULE_swig_paddle_EXTRA_DEPS
) )
IF(APPLE) IF(APPLE)
SET(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load") SET(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load -framework CoreFoundation -framework Security")
ELSE(APPLE) ELSE(APPLE)
SET(START_GROUP "-Xlinker -start-group") SET(START_GROUP "-Xlinker -start-group")
SET(END_GROUP "-Xlinker -end-group") SET(END_GROUP "-Xlinker -end-group")
......
...@@ -179,6 +179,7 @@ namespace std { ...@@ -179,6 +179,7 @@ namespace std {
%newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater; %newobject ParameterUpdater::createLocalUpdater;
%newobject ParameterUpdater::createRemoteUpdater; %newobject ParameterUpdater::createRemoteUpdater;
%newobject ParameterUpdater::createNewRemoteUpdater;
%feature("director") UpdateCallback; %feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide %feature("autodoc", 1); // To generate method stub, for code hint in ide
......
...@@ -841,6 +841,8 @@ public: ...@@ -841,6 +841,8 @@ public:
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config, static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
int passCount, int passCount,
bool useSparseUpdater); bool useSparseUpdater);
static ParameterUpdater* createNewRemoteUpdater(
OptimizationConfig* config, const std::string pserverSpec);
~ParameterUpdater(); ~ParameterUpdater();
/** /**
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "PaddleAPIPrivate.h" #include "PaddleAPIPrivate.h"
#include "paddle/trainer/NewRemoteParameterUpdater.h"
#include "paddle/trainer/RemoteParameterUpdater.h" #include "paddle/trainer/RemoteParameterUpdater.h"
#include "paddle/trainer/ThreadParameterUpdater.h" #include "paddle/trainer/ThreadParameterUpdater.h"
...@@ -28,6 +29,14 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( ...@@ -28,6 +29,14 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
return updater; return updater;
} }
ParameterUpdater *ParameterUpdater::createNewRemoteUpdater(
OptimizationConfig *config, const std::string pserverSpec) {
auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::NewRemoteParameterUpdater(
config->m->getConfig(), pserverSpec));
return updater;
}
ParameterUpdater *ParameterUpdater::createRemoteUpdater( ParameterUpdater *ParameterUpdater::createRemoteUpdater(
OptimizationConfig *config, int passCount, bool useSparseUpdater) { OptimizationConfig *config, int passCount, bool useSparseUpdater) {
auto updater = new ParameterUpdater(); auto updater = new ParameterUpdater();
......
...@@ -4,6 +4,7 @@ set(TRAINER_SOURCES ...@@ -4,6 +4,7 @@ set(TRAINER_SOURCES
ParameterUpdater.cpp ParameterUpdater.cpp
ParamUtil.cpp ParamUtil.cpp
RemoteParameterUpdater.cpp RemoteParameterUpdater.cpp
NewRemoteParameterUpdater.cpp
Tester.cpp Tester.cpp
Trainer.cpp Trainer.cpp
TrainerInternal.cpp TrainerInternal.cpp
...@@ -16,6 +17,7 @@ set(TRAINER_HEADERS ...@@ -16,6 +17,7 @@ set(TRAINER_HEADERS
ParameterUpdater.h ParameterUpdater.h
ParamUtil.h ParamUtil.h
RemoteParameterUpdater.h RemoteParameterUpdater.h
NewRemoteParameterUpdater.h
Tester.h Tester.h
TesterConfig.h TesterConfig.h
Trainer.h Trainer.h
...@@ -32,7 +34,7 @@ add_style_check_target(paddle_trainer_lib ...@@ -32,7 +34,7 @@ add_style_check_target(paddle_trainer_lib
add_style_check_target(paddle_trainer_lib add_style_check_target(paddle_trainer_lib
${TRAINER_HEADERS}) ${TRAINER_HEADERS})
add_dependencies(paddle_trainer_lib add_dependencies(paddle_trainer_lib
gen_proto_cpp) gen_proto_cpp paddle_pserver_cclient_lib)
macro(add_paddle_exe TARGET_NAME) macro(add_paddle_exe TARGET_NAME)
add_executable(${TARGET_NAME} ${ARGN}) add_executable(${TARGET_NAME} ${ARGN})
...@@ -56,3 +58,10 @@ install(TARGETS paddle_trainer paddle_merge_model ...@@ -56,3 +58,10 @@ install(TARGETS paddle_trainer paddle_merge_model
set_target_properties(paddle_trainer PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) set_target_properties(paddle_trainer PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE)
set_target_properties(paddle_merge_model PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) set_target_properties(paddle_merge_model PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE)
if(APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
endif()
target_link_libraries(paddle_trainer ${CMAKE_CURRENT_SOURCE_DIR}/libpaddle_pserver_cclient.a)
target_link_libraries(paddle_trainer_lib ${CMAKE_CURRENT_SOURCE_DIR}/libpaddle_pserver_cclient.a)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "NewRemoteParameterUpdater.h"
#include "Trainer.h"
#include "paddle/utils/Stat.h"
DECLARE_int32(trainer_id);
DECLARE_string(save_dir);
namespace paddle {
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
const OptimizationConfig &config, const std::string pserverSpec)
: pserverSpec_(pserverSpec) {}
void NewRemoteParameterUpdater::init(
const std::vector<ParameterPtr> &parameters) {
ParameterUpdater::init(parameters);
LOG(INFO) << "NewRemoteParameterUpdater init in";
for (auto &para : parameters_) {
para->getBuf(PARAMETER_VALUE)->zeroMem();
para->getBuf(PARAMETER_GRADIENT)->zeroMem();
}
// create parameter server client.
parameterClient_ =
paddle_new_pserver_client((char *)pserverSpec_.c_str(), FLAGS_trainer_id);
// init names_ for getting parameters through the paddle pserver cclient
names_ = (char **)malloc(parameterSize() * sizeof(char *));
for (int i = 0; i < parameterSize(); ++i) {
names_[i] = (char *)parameters_[i]->getName().c_str();
}
// init new parameters and gradients.
initNewParameter(newParameters_, PARAMETER_VALUE);
initNewParameter(newGradients_, PARAMETER_GRADIENT);
// init parameters; one trainer gets the opportunity to initialize the
// parameters and send them to the parameter server, while the others get
// the initialized parameters from the parameter server.
if (paddle_begin_init_params(parameterClient_)) {
LOG(INFO) << "paddle_begin_init_params start";
for (int i = 0; i < parameterSize(); ++i) {
paddle_init_param(parameterClient_, *newParameters_[i], NULL, 0);
}
paddle_finish_init_params(parameterClient_);
LOG(INFO) << "paddle_begin_init_params done";
} else {
paddle_get_params(
parameterClient_, names_, newParameters_, (int)parameters_.size());
}
LOG(INFO) << "NewRemoteParameterUpdater initialized";
}
void NewRemoteParameterUpdater::updateImpl(Parameter *para) {}
void NewRemoteParameterUpdater::finishBatch(real cost) {
LOG(INFO) << "finishBatch in, cost: " << cost;
// send gradients to the parameter server.
paddle_send_grads(parameterClient_, *newGradients_, parameterSize());
// get the updated parameters from the parameter server.
paddle_get_params(parameterClient_, names_, newParameters_, parameterSize());
// clear gradients after updating the parameters.
for (auto &para : parameters_) {
para->getBuf(PARAMETER_GRADIENT)->zeroMem();
}
}
void NewRemoteParameterUpdater::startPass() {}
bool NewRemoteParameterUpdater::finishPass() { return true; }
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <thread>
#include "ParameterUpdater.h"
#include "libpaddle_pserver_cclient.h"
#include "paddle/pserver/ParameterClient2.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/Util.h"
namespace paddle {
/**
* New remote parameter updater for dense parameters that uses the Go pserver cclient.
*/
class NewRemoteParameterUpdater : public ParameterUpdater {
public:
NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec);
~NewRemoteParameterUpdater() {
if (newGradients_) {
paddle_pserver_client_release(parameterClient_);
}
}
/**
* initialize the updater itself and its internal parameter client.
*/
virtual void init(const std::vector<ParameterPtr>& parameters);
/**
* @brief start batch
*
* @note batch training keeps per-batch state, which helps with performance
* tuning and sgd optimization when necessary.
*/
virtual PassType startBatch(int64_t batchSize) { return PASS_TRAIN; }
/**
* send gradients to pservers and get the updated parameters back
* from all pservers if necessary.
*/
virtual void finishBatch(real cost);
virtual void startPass();
virtual bool finishPass();
int parameterSize() { return (int)parameters_.size(); }
/**
* initialize the parameter structs passed to the paddle pserver cclient.
* @param new_paras array of cclient parameter structs, allocated and filled here
* @param type      which parameter buffer (value or gradient) to expose
*/
void initNewParameter(paddle_parameter**& new_paras, ParameterType type) {
new_paras =
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize());
for (int i = 0; i < parameterSize(); ++i) {
new_paras[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
memset(new_paras[i], 0, sizeof(paddle_parameter));
}
for (int i = 0; i < parameterSize(); ++i) {
ParameterPtr para = parameters_[i];
new_paras[i]->content_len = 10;
new_paras[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
new_paras[i]->name = (char*)para->getName().c_str();
new_paras[i]->content =
(unsigned char*)(para->getBuf(type).get()->getData());
new_paras[i]->content_len = (int)para->getBuf(type).get()->getSize();
}
}
protected:
/**
* work that needs to be done after finishBatch
*/
virtual void updateImpl(Parameter* para);
protected:
/// internal parameter client object for exchanging data with pserver
client parameterClient_ = -1;
/// the parameters for new pserver client
paddle_parameter** newParameters_;
/// the gradients for the new pserver client
paddle_parameter** newGradients_;
/// the names for new parameters.
char** names_;
/// the specification of the parameter servers, e.g. "host1:port,host2:port"
std::string pserverSpec_;
};
} // namespace paddle
...@@ -45,7 +45,12 @@ class Optimizer(object): ...@@ -45,7 +45,12 @@ class Optimizer(object):
return swig_api.ParameterUpdater.createRemoteUpdater( return swig_api.ParameterUpdater.createRemoteUpdater(
self.__opt_conf__, pass_num, use_sparse_updater) self.__opt_conf__, pass_num, use_sparse_updater)
def create_updater(self, is_local, num_passes, use_sparse_updater): def __create_new_remote_updater__(self, pserver_spec):
return swig_api.ParameterUpdater.createNewRemoteUpdater(
self.__opt_conf__, pserver_spec)
def create_updater(self, is_local, num_passes, use_sparse_updater,
pserver_spec):
""" """
create proper parameter_updater by configuration. create proper parameter_updater by configuration.
:param is_local: create local or remote parameter updater :param is_local: create local or remote parameter updater
...@@ -64,8 +69,12 @@ class Optimizer(object): ...@@ -64,8 +69,12 @@ class Optimizer(object):
if is_local: if is_local:
parameter_updater = self.__create_local_updater__() parameter_updater = self.__create_local_updater__()
else: else:
parameter_updater = self.__create_remote_updater__( if pserver_spec is None:
num_passes, use_sparse_updater) parameter_updater = self.__create_remote_updater__(
num_passes, use_sparse_updater)
else:
parameter_updater = self.__create_new_remote_updater__(
pserver_spec)
return parameter_updater return parameter_updater
......
...@@ -49,7 +49,8 @@ class SGD(object): ...@@ -49,7 +49,8 @@ class SGD(object):
parameters, parameters,
update_equation, update_equation,
extra_layers=None, extra_layers=None,
is_local=True): is_local=True,
pserver_spec=None):
if not isinstance(parameters, v2_parameters.Parameters): if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters') raise TypeError('parameters should be parameters')
...@@ -63,6 +64,7 @@ class SGD(object): ...@@ -63,6 +64,7 @@ class SGD(object):
self.__parameters__ = parameters self.__parameters__ = parameters
self.__topology_in_proto__ = topology.proto() self.__topology_in_proto__ = topology.proto()
self.__is_local__ = is_local self.__is_local__ = is_local
self.__pserver_spec__ = pserver_spec
self.__use_sparse_updater__ = self.__topology__.use_sparse_updater() self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
# # In local mode, disable sparse_remote_update. # # In local mode, disable sparse_remote_update.
...@@ -126,7 +128,8 @@ class SGD(object): ...@@ -126,7 +128,8 @@ class SGD(object):
__check_train_args__(**locals()) __check_train_args__(**locals())
self.__parameter_updater__ = self.__optimizer__.create_updater( self.__parameter_updater__ = self.__optimizer__.create_updater(
self.__is_local__, num_passes, self.__use_sparse_updater__) self.__is_local__, num_passes, self.__use_sparse_updater__,
self.__pserver_spec__)
self.__parameter_updater__.init(self.__gradient_machine__) self.__parameter_updater__.init(self.__gradient_machine__)
self.__gradient_machine__.start() self.__gradient_machine__.start()
......