/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/pserver/ProtoServer.h" #include "paddle/math/Matrix.h" #include "paddle/utils/Queue.h" #include "paddle/utils/TypeDefs.h" #include "ParameterService.pb.h" namespace paddle { /** * it manages all connections to pservers. * it exists two modes to manage connections to all pservers. Firstly, one * connection owns two threads that separately manage to send and receive * data. Secondly, each thread uses one connection for all activation in it. * the first solution arms with sendThreads_/recvThreads_ and sendJobQueue_/ * recvJobQueue_. the second solution use some shared thread pool to manage * connections. * In addition to pserver, metric learning also uses network to exchange * features within multi-machines, so this class just abstracts some basic * threads and queue buffer creation for them */ class BaseClient { protected: typedef std::unique_ptr ThreadPtr; typedef std::vector> InputIovs; typedef std::vector SendRequest; typedef std::vector SendDataRequestVec; // TODO(yanfei): // refine data structure to unify parameter and features communication struct SendJob { /// store parameters related blocks data InputIovs parallelInputIovs; /// store protobuf request SendRequest parallelRequests; /// store data, such as features for metric learning SendDataRequestVec parallelDataRequests; }; public: explicit BaseClient(bool separate = false, int numPorts = FLAGS_ports_num); virtual ~BaseClient(); typedef std::shared_ptr SendJobPtr; typedef Queue SendQueue; /// send data to server, support only synchronize template void putData(int clientId, SendDataType type, DataType* datas, size_t size, DataUpdateMode mode) { synchronize(SYNC_DATA); sendData(clientId, type, mode, datas, size); recvData(); synchronize(SYNC_DATA); } template void putOwnData(int clientId, SendDataType type, DataType* datas, size_t size) { putData(clientId, type, datas, size, DATA_UPDATE_MODE_SET_OWN); } template void getAllData(int clientId, SendDataType type, DataType* datas, size_t size) { sendData(clientId, type, DATA_UPDATE_MODE_GET_ALL, reinterpret_cast(NULL), 0); recvData(); size_t dataOffset = 0; for (auto& recvMem : recvDataMems_) { CHECK_LE(dataOffset, size); size_t memSize = std::min(recvMem.get()->getSize(), sizeof(DataType) * (size - dataOffset)); CHECK_EQ(memSize % sizeof(DataType), size_t(0)); memcpy(datas + dataOffset, recvMem.get()->getBuf(), memSize); dataOffset += memSize / sizeof(DataType); } CHECK_EQ(dataOffset, size); } /** * Reduces values on all clients. * This reduce just support SUM. * The results are saved in recvBuf of rootId client */ template void reduce(DataType* sendBuf, DataType* recvBuf, size_t size, int clientId, int rootId) { putOwnData(clientId, DATA_REDUCE_SUM, sendBuf, size); if (rootId == clientId) { getAllData(clientId, DATA_REDUCE_SUM, recvBuf, size); } } /** * return trans data type according to the input type */ virtual TransDataType getTransDtype(const std::type_info& info) { TransDataType dataType; if (typeid(int*) == info) { // NOLINT dataType = TRANS_INT32; } else if (typeid(uint32_t*) == info) { // NOLINT dataType = TRANS_UINT32_T; } else if (typeid(int64_t*) == info) { // NOLINT dataType = TRANS_INT64_T; } else if (typeid(uint64_t*) == info) { // NOLINT dataType = TRANS_UINT64_T; } else if (typeid(float*) == info) { // NOLINT dataType = TRANS_FLOAT; } else if (typeid(double*) == info) { // NOLINT dataType = TRANS_DOUBLE; } else { LOG(FATAL) << "not supported"; } return dataType; } protected: /// for a > 0, b > 0: /// return the smallest x s.t. b*x >= a static int divup(int a, int b) { return (a + b - 1) / b; } int calcClientId(int i, int serviceNum) { return (i + FLAGS_trainer_id * numPorts_) % serviceNum; } /// start threads in sendThreads_ and recvThreads_ void startThreads(); /// finish threads in sendThreads_ and recvThreads_ void finishThreads(); template void prepareData(int clientId, SendDataType type, DataUpdateMode updateMode, DataType* datas, size_t size, SendJob* sendJob) { sendJob->parallelDataRequests.resize(serviceNum_); sendJob->parallelInputIovs.resize(serviceNum_); for (int i = 0; i < serviceNum_; ++i) { auto& request = sendJob->parallelDataRequests[i]; request.set_update_mode(updateMode); request.set_type(type); request.set_client_id(clientId); request.set_server_id(i); } /// split datas which need send to Server into serviceNum_ pieces if (!datas) { CHECK(!size) << "ownSize should be zero since datas is nullptr"; } size_t baseSize = size / serviceNum_; size_t dataOffset = 0; for (int i = 0; i < serviceNum_; ++i) { auto& request = sendJob->parallelDataRequests[i]; DataBlock* block = request.add_blocks(); size_t ownSize = size_t(i) < size % serviceNum_ ? baseSize + 1 : baseSize; size_t realSize = datas ? std::max(ownSize, size_t(1)) : 0; block->set_total_size(realSize * sizeof(DataType)); block->set_data_size(sizeof(DataType)); // TODO(yuyang18): The getTransDtype can be rewritten as template method // to reduce runtime overhead. block->set_data_type(getTransDtype(typeid(DataType*))); // NOLINT if (datas) { sendJob->parallelInputIovs[i].push_back( {datas + dataOffset, realSize * sizeof(DataType)}); } dataOffset += ownSize; } CHECK_EQ(dataOffset, size); } /** * @brief send data to all data servers * * @note each trainer sends all its data to all data servers * it's for broadcast data synchronization, such as features * synchronization in metric learning. */ template void sendData(int clientId, SendDataType type, DataUpdateMode updateMode, DataType* datas, size_t size) { SendJobPtr sendJob = std::make_shared(); prepareData(clientId, type, updateMode, datas, size, sendJob.get()); for (int i = 0; i < threadNum_; ++i) { sendJobQueue_[i]->enqueue(sendJob); } } /** * @brief recv data from all data servers * * @note synchronize all recv threads */ void recvData(); /// send request, and recv responses template void multiCall(const char* funcName, const ProtoIn& request, std::vector* responses) { responses->resize(clients_.size()); size_t numClients = clients_.size(); for (size_t i = 0; i < numClients; ++i) { clients_[i].send(funcName, request); } for (size_t i = 0; i < numClients; ++i) { clients_[i].recv(&(*responses)[i]); } } /** * @brief synchronize all trainers and pservers * * @note used to ensure that data of all trainers have been received */ void synchronize(SyncObject syncObjectId = SYNC_DEFAULT); /** * @brief use multithread to separately send data * * @note each thread should read its own JobQueue to handle requests * each thread should calcClientId() to retrieve connections * managed by himself. * send and recv are implemented in child class. */ virtual void send(int threadId) = 0; /** * @brief use multithread to separately receive data * * @note almost same as send() */ virtual void recv(int threadId) = 0; protected: bool stopping_; /// nodes * ports that means the number of real pservers int serviceNum_; /** * threads num for managing all services. Normally the * number of pservers are relatively less than several * hundreds so that using thread-based parallelization * can benifit traffic performance and pserver's sgd * optimization performance. */ int threadNum_; /// the connection manager at client end std::vector clients_; /// send threads for parallelization std::vector sendThreads_; /// recv threads for parallelization std::vector recvThreads_; std::unique_ptr recvSyncBarrier_; // TODO(yanfei): // current pserver's will return value until all parameters' // optimization are finished so that recv are not overlapped // in reality. More robust implimentation should be to pipeline // all send/recv action based on parameter unit level, and // it will benifits deep and larger model training in future, // especially local node compution power surpasses inter-connection // such as GPU cluster, even with BOX GPU cluster. // queue for buffering send request /** * send/recv queue cooperates with each other to accomplish * overlapping communication with forwardBackward action. */ std::vector> sendJobQueue_; /// queue for buffering recv request std::vector> recvJobQueue_; /// specific for dserver SendJob sendJob_; /// port num for each node int numPorts_; /// if set, overlapped optimization is disabled bool separateSendAndRecv_; std::vector recvDataMems_; }; } // namespace paddle