From 99dc60642d6b94a5ccc92be21917cfa866d6e7f8 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 7 Jun 2017 23:42:11 +0800 Subject: [PATCH] new parameterupdater use paddle pserver cclient of go --- CMakeLists.txt | 1 + .../cluster_train/remote_parameter_updater.md | 21 ++++ go/cmake/golang.cmake | 8 +- go/pserver/cclient/CMakeLists.txt | 12 +- go/pserver/cclient/test/CMakeLists.txt | 13 ++- go/pserver/cclient/test/main.c | 19 ++-- go/pserver/cclient/test/test_train.py | 60 ++++++++++ paddle/api/CMakeLists.txt | 4 +- paddle/api/Paddle.i | 1 + paddle/api/PaddleAPI.h | 2 + paddle/api/ParameterUpdater.cpp | 9 ++ paddle/trainer/CMakeLists.txt | 11 +- paddle/trainer/NewRemoteParameterUpdater.cpp | 88 +++++++++++++++ paddle/trainer/NewRemoteParameterUpdater.h | 105 ++++++++++++++++++ python/paddle/v2/optimizer.py | 15 ++- python/paddle/v2/trainer.py | 7 +- 16 files changed, 352 insertions(+), 24 deletions(-) create mode 100644 doc/design/cluster_train/remote_parameter_updater.md create mode 100644 go/pserver/cclient/test/test_train.py create mode 100644 paddle/trainer/NewRemoteParameterUpdater.cpp create mode 100644 paddle/trainer/NewRemoteParameterUpdater.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 79210d0436..c2218be5ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -127,6 +127,7 @@ endif(WITH_GPU) add_subdirectory(proto) add_subdirectory(paddle) add_subdirectory(python) +add_subdirectory(go/pserver/cclient) if(WITH_DOC) add_subdirectory(doc) diff --git a/doc/design/cluster_train/remote_parameter_updater.md b/doc/design/cluster_train/remote_parameter_updater.md new file mode 100644 index 0000000000..6e8e593845 --- /dev/null +++ b/doc/design/cluster_train/remote_parameter_updater.md @@ -0,0 +1,21 @@ +# Design Doc: Remote Parameter Updater for Cluster Train + +For an overview of distribute training, please refer to [distributed training design doc](README.md). In this design doc, we will discuss the parameter updater that will use parameter server cclient [The Client Library of Parameter Server Design Doc](pserver_client.md) to manage and update parameters. + +## Parameter Updater + +Parameter Updater is used by trainer to manage and update parameter, there are mainly two kind of parameter updater: local and remote, since this design is for cluster train, we will only discuss remote parameter updater here. + +### Remote Parameter Updater + +Remote Parameter Updater manage parameters through remote parameter server with the client that communicate with pserver([The Client Library of Parameter Server Design Doc](pserver_client.md)) + +In PaddlePaddle Python V2 API, trainer is implemented in python, and the trainer will hold a instance of parameter updater and call it's functions directly. In this design, we will also expose the api of RemoteParameterUpdater to python with swig. + +#### Sparse Remote Parameter Updater + +Since we will only implement dense parameter management new, the mechanism for sparse parameter will be discussed in next stage. + +### Interface Design + +TBD diff --git a/go/cmake/golang.cmake b/go/cmake/golang.cmake index d38d06de23..7c85fb6298 100644 --- a/go/cmake/golang.cmake +++ b/go/cmake/golang.cmake @@ -17,7 +17,7 @@ function(GO_LIBRARY NAME BUILD_TYPE) endif() file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") - file(RELATIVE_PATH rel ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) + file(RELATIVE_PATH rel ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) # find Paddle directory. get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) @@ -32,12 +32,14 @@ function(GO_LIBRARY NAME BUILD_TYPE) # will use the local changes in Paddle rather than checkout Paddle # in github. add_custom_target(copyPaddle - COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}) + COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle + COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle) add_dependencies(goGet copyPaddle) add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} - -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" + -gcflags=-shared -asmflags=-shared -installsuffix=_shared -a + -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" ${CMAKE_GO_FLAGS} ${GO_SOURCE} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/go/pserver/cclient/CMakeLists.txt b/go/pserver/cclient/CMakeLists.txt index c017d74656..e00dd6b14a 100644 --- a/go/pserver/cclient/CMakeLists.txt +++ b/go/pserver/cclient/CMakeLists.txt @@ -9,5 +9,15 @@ project(cxx_go C Go) include(golang) include(flags) -go_library(client STATIC) +go_library(paddle_pserver_cclient STATIC) + +if(PROJ_ROOT) + add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/trainer/libpaddle_pserver_cclient.a + COMMAND cp ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.h ${PROJ_ROOT}/paddle/trainer/ + COMMAND cp ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a ${PROJ_ROOT}/paddle/trainer/ + WORKING_DIRECTORY ${PROJ_ROOT}/paddle + DEPENDS paddle_pserver_cclient) + add_custom_target(paddle_pserver_cclient_lib ALL DEPENDS ${PROJ_ROOT}/paddle/trainer/libpaddle_pserver_cclient.a) +endif(PROJ_ROOT) + add_subdirectory(test) diff --git a/go/pserver/cclient/test/CMakeLists.txt b/go/pserver/cclient/test/CMakeLists.txt index 16f84648c1..762772812f 100644 --- a/go/pserver/cclient/test/CMakeLists.txt +++ b/go/pserver/cclient/test/CMakeLists.txt @@ -1,11 +1,16 @@ cmake_minimum_required(VERSION 3.0) -include_directories(${CMAKE_BINARY_DIR}) - add_executable(main main.c) -add_dependencies(main client) +add_dependencies(main paddle_pserver_cclient) if(APPLE) set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") endif() -target_link_libraries(main ${CMAKE_BINARY_DIR}/libclient.a) + +if(PROJ_ROOT) + include_directories(${CMAKE_BINARY_DIR}/go/pserver/cclient/) + target_link_libraries(main ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a pthread) +else(PROJ_ROOT) + include_directories(${CMAKE_BINARY_DIR}) + target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread) +endif(PROJ_ROOT) diff --git a/go/pserver/cclient/test/main.c b/go/pserver/cclient/test/main.c index f75a2110b9..0ad890daa2 100644 --- a/go/pserver/cclient/test/main.c +++ b/go/pserver/cclient/test/main.c @@ -1,6 +1,6 @@ #include -#include "libclient.h" +#include "libpaddle_pserver_cclient.h" void fail() { // TODO(helin): fix: gtest using cmake is not working, using this @@ -14,10 +14,11 @@ int main() { client c = paddle_new_pserver_client(addr, 1); retry: if (paddle_begin_init_params(c)) { + paddle_parameter param; char name_a[] = "param_a"; char name_b[] = "param_b"; - unsigned char content[] = {0x00, 0x11, 0x22}; + unsigned char content[] = {0x00, 0x00, 0x00}; param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; param.name = name_a; param.content = content; @@ -32,6 +33,7 @@ retry: if (paddle_init_param(c, param, NULL, 0) != 0) { goto retry; } + if (paddle_finish_init_params(c) != 0) { goto retry; } @@ -41,30 +43,31 @@ retry: unsigned char content[] = {0x00, 0x11, 0x22}; paddle_gradient grads[2] = { - {"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3}, - {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; + {"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}, + {"param_b", PADDLE_ELEMENT_TYPE_INT32, content, 3}}; - if (!paddle_send_grads(c, grads, 2)) { + if (paddle_send_grads(c, grads, 2) != 0) { fail(); } paddle_parameter* params[2] = {NULL, NULL}; char* names[] = {"param_a", "param_b"}; - if (!paddle_get_params(c, names, params, 2)) { + if (paddle_get_params(c, names, params, 2) != 0) { fail(); } // get parameters again by reusing the allocated parameter buffers. - if (!paddle_get_params(c, names, params, 2)) { + if (paddle_get_params(c, names, params, 2) != 0) { fail(); } paddle_release_param(params[0]); paddle_release_param(params[1]); - if (!paddle_save_model(c, "/tmp/")) { + if (paddle_save_model(c, "/tmp/") != 0) { fail(); } + printf("test success!\n"); return 0; } diff --git a/go/pserver/cclient/test/test_train.py b/go/pserver/cclient/test/test_train.py new file mode 100644 index 0000000000..ddd6371e0c --- /dev/null +++ b/go/pserver/cclient/test/test_train.py @@ -0,0 +1,60 @@ +import paddle.v2 as paddle +import paddle.v2.dataset.uci_housing as uci_housing + + +def main(): + # init + paddle.init(use_gpu=False, trainer_count=1, trainer_id=1) + + # network config + x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) + y_predict = paddle.layer.fc(input=x, + param_attr=paddle.attr.Param(name='w'), + size=1, + act=paddle.activation.Linear(), + bias_attr=paddle.attr.Param(name='b')) + y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) + cost = paddle.layer.mse_cost(input=y_predict, label=y) + + # create parameters + parameters = paddle.parameters.create(cost) + + # create optimizer + optimizer = paddle.optimizer.Momentum(momentum=0) + + trainer = paddle.trainer.SGD(cost=cost, + parameters=parameters, + update_equation=optimizer, + is_local=False, + pserver_spec="localhost:3000") + + # event_handler to print training and testing info + def event_handler(event): + if isinstance(event, paddle.event.EndIteration): + if event.batch_id % 100 == 0: + print "Pass %d, Batch %d, Cost %f" % ( + event.pass_id, event.batch_id, event.cost) + + if isinstance(event, paddle.event.EndPass): + if (event.pass_id + 1) % 10 == 0: + result = trainer.test( + reader=paddle.batch( + uci_housing.test(), batch_size=2), + feeding={'x': 0, + 'y': 1}) + print "Test %d, %.2f" % (event.pass_id, result.cost) + + # training + trainer.train( + reader=paddle.batch( + paddle.reader.shuffle( + uci_housing.train(), buf_size=500), + batch_size=2), + feeding={'x': 0, + 'y': 1}, + event_handler=event_handler, + num_passes=30) + + +if __name__ == '__main__': + main() diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index e147659566..c258a15240 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -16,7 +16,7 @@ set(API_HEADER Internal.h) add_library(paddle_api STATIC ${API_SOURCES}) -add_dependencies(paddle_api gen_proto_cpp) +add_dependencies(paddle_api gen_proto_cpp paddle_pserver_cclient_lib) INCLUDE(${SWIG_USE_FILE}) INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle) @@ -44,7 +44,7 @@ SET(SWIG_MODULE_swig_paddle_EXTRA_DEPS ) IF(APPLE) - SET(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load") + SET(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load -framework CoreFoundation -framework Security") ELSE(APPLE) SET(START_GROUP "-Xlinker -start-group") SET(END_GROUP "-Xlinker -end-group") diff --git a/paddle/api/Paddle.i b/paddle/api/Paddle.i index 068ba286c0..3237e73745 100644 --- a/paddle/api/Paddle.i +++ b/paddle/api/Paddle.i @@ -179,6 +179,7 @@ namespace std { %newobject ParameterOptimizer::needSpecialTraversal; %newobject ParameterUpdater::createLocalUpdater; %newobject ParameterUpdater::createRemoteUpdater; +%newobject ParameterUpdater::createNewRemoteUpdater; %feature("director") UpdateCallback; %feature("autodoc", 1); // To generate method stub, for code hint in ide diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index da0f157abd..7565ea51fe 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -841,6 +841,8 @@ public: static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config, int passCount, bool useSparseUpdater); + static ParameterUpdater* createNewRemoteUpdater( + OptimizationConfig* config, const std::string pserverSpec); ~ParameterUpdater(); /** diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 79921ea6e7..eaf8518ae2 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "PaddleAPI.h" #include "PaddleAPIPrivate.h" +#include "paddle/trainer/NewRemoteParameterUpdater.h" #include "paddle/trainer/RemoteParameterUpdater.h" #include "paddle/trainer/ThreadParameterUpdater.h" @@ -28,6 +29,14 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( return updater; } +ParameterUpdater *ParameterUpdater::createNewRemoteUpdater( + OptimizationConfig *config, const std::string pserverSpec) { + auto updater = new ParameterUpdater(); + updater->m->updater.reset(new paddle::NewRemoteParameterUpdater( + config->m->getConfig(), pserverSpec)); + return updater; +} + ParameterUpdater *ParameterUpdater::createRemoteUpdater( OptimizationConfig *config, int passCount, bool useSparseUpdater) { auto updater = new ParameterUpdater(); diff --git a/paddle/trainer/CMakeLists.txt b/paddle/trainer/CMakeLists.txt index 06c019f0a9..9d246b6690 100644 --- a/paddle/trainer/CMakeLists.txt +++ b/paddle/trainer/CMakeLists.txt @@ -4,6 +4,7 @@ set(TRAINER_SOURCES ParameterUpdater.cpp ParamUtil.cpp RemoteParameterUpdater.cpp + NewRemoteParameterUpdater.cpp Tester.cpp Trainer.cpp TrainerInternal.cpp @@ -16,6 +17,7 @@ set(TRAINER_HEADERS ParameterUpdater.h ParamUtil.h RemoteParameterUpdater.h + NewRemoteParameterUpdater.h Tester.h TesterConfig.h Trainer.h @@ -32,7 +34,7 @@ add_style_check_target(paddle_trainer_lib add_style_check_target(paddle_trainer_lib ${TRAINER_HEADERS}) add_dependencies(paddle_trainer_lib - gen_proto_cpp) + gen_proto_cpp paddle_pserver_cclient_lib) macro(add_paddle_exe TARGET_NAME) add_executable(${TARGET_NAME} ${ARGN}) @@ -56,3 +58,10 @@ install(TARGETS paddle_trainer paddle_merge_model set_target_properties(paddle_trainer PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) set_target_properties(paddle_merge_model PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) + +if(APPLE) + set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") +endif() + +target_link_libraries(paddle_trainer ${CMAKE_CURRENT_SOURCE_DIR}/libpaddle_pserver_cclient.a) +target_link_libraries(paddle_trainer_lib ${CMAKE_CURRENT_SOURCE_DIR}/libpaddle_pserver_cclient.a) diff --git a/paddle/trainer/NewRemoteParameterUpdater.cpp b/paddle/trainer/NewRemoteParameterUpdater.cpp new file mode 100644 index 0000000000..9060052e11 --- /dev/null +++ b/paddle/trainer/NewRemoteParameterUpdater.cpp @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "NewRemoteParameterUpdater.h" +#include "Trainer.h" +#include "paddle/utils/Stat.h" + +DECLARE_int32(trainer_id); +DECLARE_string(save_dir); + +namespace paddle { +NewRemoteParameterUpdater::NewRemoteParameterUpdater( + const OptimizationConfig &config, const std::string pserverSpec) + : pserverSpec_(pserverSpec) {} + +void NewRemoteParameterUpdater::init( + const std::vector ¶meters) { + ParameterUpdater::init(parameters); + LOG(INFO) << "NewRemoteParameterUpdater init in"; + + for (auto ¶ : parameters_) { + para->getBuf(PARAMETER_VALUE)->zeroMem(); + para->getBuf(PARAMETER_GRADIENT)->zeroMem(); + } + + // create parameter server client. + parameterClient_ = + paddle_new_pserver_client((char *)pserverSpec_.c_str(), FLAGS_trainer_id); + + // init names_ for get parameter through paddle_cclient + names_ = (char **)malloc(parameterSize() * sizeof(char *)); + for (int i = 0; i < parameterSize(); ++i) { + names_[i] = (char *)parameters_[i]->getName().c_str(); + } + + // init new parameter and gradient. + initNewParameter(newParameters_, PARAMETER_VALUE); + initNewParameter(newGradients_, PARAMETER_GRADIENT); + + // init parameter, one trainer will get the opportunity to int parameter and + // send them to parameter server. Others will get the initialized parameter + // from parameter server + if (paddle_begin_init_params(parameterClient_)) { + LOG(INFO) << "paddle_begin_init_params start"; + for (int i = 0; i < parameterSize(); ++i) { + paddle_init_param(parameterClient_, *newParameters_[i], NULL, 0); + } + paddle_finish_init_params(parameterClient_); + LOG(INFO) << "paddle_begin_init_params done"; + } else { + paddle_get_params( + parameterClient_, names_, newParameters_, (int)parameters_.size()); + } + + LOG(INFO) << "NewRemoteParameterUpdater initialized"; +} + +void NewRemoteParameterUpdater::updateImpl(Parameter *para) {} + +void NewRemoteParameterUpdater::finishBatch(real cost) { + LOG(INFO) << "finishBatch in, cost: " << cost; + + // send gradient to parameter server. + paddle_send_grads(parameterClient_, *newGradients_, parameterSize()); + // get the updated parameter from parameterClient. + paddle_get_params(parameterClient_, names_, newParameters_, parameterSize()); + + // clear gradient after update parameter. + for (auto ¶ : parameters_) { + para->getBuf(PARAMETER_GRADIENT)->zeroMem(); + } +} + +void NewRemoteParameterUpdater::startPass() {} + +bool NewRemoteParameterUpdater::finishPass() { return true; } +} diff --git a/paddle/trainer/NewRemoteParameterUpdater.h b/paddle/trainer/NewRemoteParameterUpdater.h new file mode 100644 index 0000000000..33640bc8a3 --- /dev/null +++ b/paddle/trainer/NewRemoteParameterUpdater.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "ParameterUpdater.h" +#include "libpaddle_pserver_cclient.h" +#include "paddle/pserver/ParameterClient2.h" +#include "paddle/utils/Queue.h" +#include "paddle/utils/Util.h" + +namespace paddle { + +/** + * New remote parameter updater for dense parameters that use cclient of go. + */ +class NewRemoteParameterUpdater : public ParameterUpdater { +public: + NewRemoteParameterUpdater(const OptimizationConfig& config, + const std::string pserverSpec); + ~NewRemoteParameterUpdater() { + if (newGradients_) { + paddle_pserver_client_release(parameterClient_); + } + } + + /** + * initialize the internal parameter client and itself. + */ + virtual void init(const std::vector& parameters); + /** + * @brief start batch + * + * @note one batch training exhibits stateful feature to help + * to do performance tuning, sgd optimization if necessary. + */ + virtual PassType startBatch(int64_t batchSize) { return PASS_TRAIN; } + + /** + * send parameters to pservers and get returned parameters + * from all pservers if necessary. + */ + virtual void finishBatch(real cost); + virtual void startPass(); + virtual bool finishPass(); + + int parameterSize() { return (int)parameters_.size(); } + + /** + * init parameter of paddle pserver cclient. + * @param new_paras + * @param type + */ + void initNewParameter(paddle_parameter**& new_paras, ParameterType type) { + new_paras = + (paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize()); + for (int i = 0; i < parameterSize(); ++i) { + new_paras[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter)); + memset(new_paras[i], 0, sizeof(paddle_parameter)); + } + + for (int i = 0; i < parameterSize(); ++i) { + ParameterPtr para = parameters_[i]; + new_paras[i]->content_len = 10; + new_paras[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32; + new_paras[i]->name = (char*)para->getName().c_str(); + new_paras[i]->content = + (unsigned char*)(para->getBuf(type).get()->getData()); + new_paras[i]->content_len = (int)para->getBuf(type).get()->getSize(); + } + } + +protected: + /** + * work need to do after finishBatch + */ + virtual void updateImpl(Parameter* para); + +protected: + /// internal parameter client object for exchanging data with pserver + client parameterClient_ = -1; + /// the parameters for new pserver client + paddle_parameter** newParameters_; + /// the gradinets for new pserver client + paddle_parameter** newGradients_; + /// the names for new parameters. + char** names_; + /// the specification of parameter server "host1:port,host1:port" + std::string pserverSpec_; +}; + +} // namespace paddle diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py index 5e99d4a241..1ef2dceca9 100644 --- a/python/paddle/v2/optimizer.py +++ b/python/paddle/v2/optimizer.py @@ -45,7 +45,12 @@ class Optimizer(object): return swig_api.ParameterUpdater.createRemoteUpdater( self.__opt_conf__, pass_num, use_sparse_updater) - def create_updater(self, is_local, num_passes, use_sparse_updater): + def __create_new_remote_updater__(self, pserver_spec): + return swig_api.ParameterUpdater.createNewRemoteUpdater( + self.__opt_conf__, pserver_spec) + + def create_updater(self, is_local, num_passes, use_sparse_updater, + pserver_spec): """ create proper parameter_updater by configuration. :param is_local: create local or remote parameter updater @@ -64,8 +69,12 @@ class Optimizer(object): if is_local: parameter_updater = self.__create_local_updater__() else: - parameter_updater = self.__create_remote_updater__( - num_passes, use_sparse_updater) + if pserver_spec is None: + parameter_updater = self.__create_remote_updater__( + num_passes, use_sparse_updater) + else: + parameter_updater = self.__create_new_remote_updater__( + pserver_spec) return parameter_updater diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 8fdb67cc26..f9658a8c5d 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -49,7 +49,8 @@ class SGD(object): parameters, update_equation, extra_layers=None, - is_local=True): + is_local=True, + pserver_spec=None): if not isinstance(parameters, v2_parameters.Parameters): raise TypeError('parameters should be parameters') @@ -63,6 +64,7 @@ class SGD(object): self.__parameters__ = parameters self.__topology_in_proto__ = topology.proto() self.__is_local__ = is_local + self.__pserver_spec__ = pserver_spec self.__use_sparse_updater__ = self.__topology__.use_sparse_updater() # # In local mode, disable sparse_remote_update. @@ -126,7 +128,8 @@ class SGD(object): __check_train_args__(**locals()) self.__parameter_updater__ = self.__optimizer__.create_updater( - self.__is_local__, num_passes, self.__use_sparse_updater__) + self.__is_local__, num_passes, self.__use_sparse_updater__, + self.__pserver_spec__) self.__parameter_updater__.init(self.__gradient_machine__) self.__gradient_machine__.start() -- GitLab