// BaseClient.h (10.0 KB) -- committed by zhangjinchao01
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/pserver/ProtoServer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/TypeDefs.h"
#include "ParameterService.pb.h"

namespace paddle {

/**
 * it manages all connections to pservers.
 * it exists two modes to manage connections to all pservers. Firstly, one
 * connection owns two threads that separately manage to send and receive
 * data. Secondly, each thread uses one connection for all activation in it.
 * the first solution arms with sendThreads_/recvThreads_ and sendJobQueue_/
 * recvJobQueue_. the second solution use some shared thread pool to manage
 * connections.
 * In addition to pserver, metric learning also uses network to exchange
 * features within multi-machines, so this class just abstracts some basic
 * threads and queue buffer creation for them
 */
class BaseClient {
protected:
  typedef std::unique_ptr<std::thread> ThreadPtr;
  typedef std::vector<std::vector<iovec>> InputIovs;
  typedef std::vector<SendParameterRequest> SendRequest;
  typedef std::vector<SendDataRequest> SendDataRequestVec;

  // TODO(yanfei):
  // refine data structure to unify parameter and features communication
  struct SendJob {
    /// store parameters related blocks data
    InputIovs parallelInputIovs;
    /// store protobuf request
    SendRequest parallelRequests;
    /// store data, such as features for metric learning
    SendDataRequestVec parallelDataRequests;
  };

public:
  explicit BaseClient(bool separate = false, int numPorts = FLAGS_ports_num);

  virtual ~BaseClient();

  typedef std::shared_ptr<SendJob> SendJobPtr;
  typedef Queue<SendJobPtr> SendQueue;

  /// send data to server, support only synchronize
  template <class DataType>
  void putData(int clientId, SendDataType type, DataType* datas, size_t size,
               DataUpdateMode mode) {
    synchronize(SYNC_DATA);
    sendData(clientId, type, mode, datas, size);
    recvData();
    synchronize(SYNC_DATA);
  }

  template <class DataType>
  void putOwnData(int clientId, SendDataType type, DataType* datas,
                  size_t size) {
    putData(clientId, type, datas, size, DATA_UPDATE_MODE_SET_OWN);
  }

  template <class DataType>
  void getAllData(int clientId, SendDataType type, DataType* datas,
                  size_t size) {
    sendData(clientId, type, DATA_UPDATE_MODE_GET_ALL,
             reinterpret_cast<DataType*>(NULL), 0);
    recvData();
    size_t dataOffset = 0;
    for (auto& recvMem : recvDataMems_) {
      CHECK_LE(dataOffset, size);
      size_t memSize = std::min(recvMem.get()->getSize(),
                                sizeof(DataType) * (size - dataOffset));
      CHECK_EQ(memSize % sizeof(DataType), size_t(0));
      memcpy(datas + dataOffset, recvMem.get()->getBuf(), memSize);
      dataOffset += memSize / sizeof(DataType);
    }
    CHECK_EQ(dataOffset, size);
  }

  /**
   * Reduces values on all clients.
   * This reduce just support SUM.
   * The results are saved in recvBuf of rootId client
   */
  template <class DataType>
  void reduce(DataType* sendBuf, DataType* recvBuf, size_t size, int clientId,
              int rootId) {
    putOwnData(clientId, DATA_REDUCE_SUM, sendBuf, size);
    if (rootId == clientId) {
      getAllData(clientId, DATA_REDUCE_SUM, recvBuf, size);
    }
  }

  /**
   * return trans data type according to the input type
   */
  virtual TransDataType getTransDtype(const std::type_info& info) {
    TransDataType dataType;
    if (typeid(int*) == info) {  // NOLINT
      dataType = TRANS_INT32;
    } else if (typeid(uint32_t*) == info) {  // NOLINT
      dataType = TRANS_UINT32_T;
    } else if (typeid(int64_t*) == info) {  // NOLINT
      dataType = TRANS_INT64_T;
    } else if (typeid(uint64_t*) == info) {  // NOLINT
      dataType = TRANS_UINT64_T;
    } else if (typeid(float*) == info) {  // NOLINT
      dataType = TRANS_FLOAT;
    } else if (typeid(double*) == info) {  // NOLINT
      dataType = TRANS_DOUBLE;
    } else {
      LOG(FATAL) << "not supported";
    }
    return dataType;
  }

protected:
  /// for a > 0, b > 0:
  /// return the smallest x s.t. b*x >= a
  static int divup(int a, int b) { return (a + b - 1) / b; }

  int calcClientId(int i, int serviceNum) {
    return (i + FLAGS_trainer_id * numPorts_) % serviceNum;
  }

  /// start threads in sendThreads_ and recvThreads_
  void startThreads();

  /// finish threads in sendThreads_ and recvThreads_
  void finishThreads();

  template <class DataType>
  void prepareData(int clientId, SendDataType type, DataUpdateMode updateMode,
                   DataType* datas, size_t size, SendJob* sendJob) {
    sendJob->parallelDataRequests.resize(serviceNum_);
    sendJob->parallelInputIovs.resize(serviceNum_);
    for (int i = 0; i < serviceNum_; ++i) {
      auto& request = sendJob->parallelDataRequests[i];
      request.set_update_mode(updateMode);
      request.set_type(type);
      request.set_client_id(clientId);
      request.set_server_id(i);
    }

    /// split datas which need send to Server into serviceNum_ pieces
    if (!datas) {
      CHECK(!size) << "ownSize should be zero since datas is nullptr";
    }
    size_t baseSize = size / serviceNum_;
    size_t dataOffset = 0;
    for (int i = 0; i < serviceNum_; ++i) {
      auto& request = sendJob->parallelDataRequests[i];
      DataBlock* block = request.add_blocks();
      size_t ownSize = size_t(i) < size % serviceNum_ ? baseSize + 1 : baseSize;
      size_t realSize = datas ? std::max(ownSize, size_t(1)) : 0;
      block->set_total_size(realSize * sizeof(DataType));
      block->set_data_size(sizeof(DataType));
      // TODO(yuyang18): The getTransDtype can be rewritten as template method
      //                 to reduce runtime overhead.
      block->set_data_type(getTransDtype(typeid(DataType*)));  // NOLINT
      if (datas) {
        sendJob->parallelInputIovs[i].push_back(
            {datas + dataOffset, realSize * sizeof(DataType)});
      }
      dataOffset += ownSize;
    }
    CHECK_EQ(dataOffset, size);
  }

  /**
   * @brief send data to all data servers
   *
   * @note  each trainer sends all its data to all data servers
   *        it's for broadcast data synchronization, such as features
   *        synchronization in metric learning.
   */
  template <class DataType>
  void sendData(int clientId, SendDataType type, DataUpdateMode updateMode,
                DataType* datas, size_t size) {
    SendJobPtr sendJob = std::make_shared<SendJob>();
    prepareData(clientId, type, updateMode, datas, size, sendJob.get());
    for (int i = 0; i < threadNum_; ++i) {
      sendJobQueue_[i]->enqueue(sendJob);
    }
  }

  /**
   * @brief recv data from all data servers
   *
   * @note  synchronize all recv threads
   */
  void recvData();

  /// send request, and recv responses
  template <typename ProtoIn, typename ProtoOut>
  void multiCall(const char* funcName, const ProtoIn& request,
                 std::vector<ProtoOut>* responses) {
    responses->resize(clients_.size());
    size_t numClients = clients_.size();
    for (size_t i = 0; i < numClients; ++i) {
      clients_[i].send(funcName, request);
    }
    for (size_t i = 0; i < numClients; ++i) {
      clients_[i].recv(&(*responses)[i]);
    }
  }

  /**
   * @brief synchronize all trainers and pservers
   *
   * @note  used to ensure that data of all trainers have been received
   */
  void synchronize(SyncObject syncObjectId = SYNC_DEFAULT);

  /**
   * @brief use multithread to separately send data
   *
   * @note  each thread should read its own JobQueue to handle requests
   *        each thread should calcClientId() to retrieve connections
   *        managed by himself.
   *        send and recv are implemented in child class.
   */
  virtual void send(int threadId) = 0;

  /**
   * @brief use multithread to separately receive data
   *
   * @note  almost same as send()
   */
  virtual void recv(int threadId) = 0;

protected:
  bool stopping_;
  /// nodes * ports that means the number of real pservers
  int serviceNum_;
  /**
   * threads num for managing all services. Normally the
   * number of pservers are relatively less than several
   * hundreds so that using thread-based parallelization
   * can benifit traffic performance and pserver's sgd
   * optimization performance.
   */
  int threadNum_;
  /// the connection manager at client end
  std::vector<ProtoClient> clients_;
  /// send threads for parallelization
  std::vector<ThreadPtr> sendThreads_;
  /// recv threads for parallelization
  std::vector<ThreadPtr> recvThreads_;
  std::unique_ptr<ThreadBarrier> recvSyncBarrier_;

  // TODO(yanfei):
  // current pserver's will return value until all parameters'
  // optimization are finished so that recv are not overlapped
  // in reality. More robust implimentation should be to pipeline
  // all send/recv action based on parameter unit level, and
  // it will benifits deep and larger model training in future,
  // especially local node compution power surpasses inter-connection
  // such as GPU cluster, even with BOX GPU cluster.
  // queue for buffering send request
  /**
   * send/recv queue cooperates with each other to accomplish
   * overlapping communication with forwardBackward action.
   */
  std::vector<std::unique_ptr<SendQueue>> sendJobQueue_;
  /// queue for buffering recv request
  std::vector<std::unique_ptr<SendQueue>> recvJobQueue_;
  /// specific for dserver
  SendJob sendJob_;
  /// port num for each node
  int numPorts_;
  /// if set, overlapped optimization is disabled
  bool separateSendAndRecv_;
  std::vector<CpuMemHandlePtr> recvDataMems_;
};
}  // namespace paddle