server.h 6.3 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
23

T
tangwei12 已提交
24 25 26
#include "butil/endpoint.h"
#include "google/protobuf/service.h"
#include "paddle/fluid/distributed/common/registerer.h"
27 28
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
29
#include "paddle/fluid/distributed/the_one_ps.pb.h"
T
tangwei12 已提交
30
#include "paddle/fluid/framework/channel.h"
31 32 33 34
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"

35 36 37 38 39 40 41 42 43 44 45
// Forward declarations of types named later in this header.
// RpcController appears in PsBaseService::service(); PSEnvironment appears in
// PSServer::Configure() and PSServer::Environment().
// NOTE(review): PSEnvironment is presumably also defined by the env.h include
// above, which would make this declaration redundant — verify before removing.
namespace google {
namespace protobuf {
class RpcController;
}  // namespace protobuf
}  // namespace google
namespace paddle {
namespace distributed {
class PSEnvironment;
}  // namespace distributed
}  // namespace paddle

46 47 48 49 50 51 52
// Forward declarations of framework types used by PSServer.
// ProgramDesc is used in PSServer::Configure(); Scope in PSServer::scope_.
// Executor is not referenced anywhere else in this header.
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
}  // namespace framework
}  // namespace paddle
T
tangwei12 已提交
53 54 55 56 57

namespace paddle {
namespace distributed {

class Table;
58

T
tangwei12 已提交
59 60
// RPC request/response message types used by PsBaseService below.
// NOTE(review): these using-declarations name members of the enclosing
// namespace (paddle::distributed) itself, so they appear to be redundant.
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
T
tangwei12 已提交
61 62 63 64 65 66 67 68

class PSServer {
 public:
  PSServer() {}
  virtual ~PSServer() {}
  PSServer(PSServer &&) = delete;
  PSServer(const PSServer &) = delete;

Z
zhaocaibei123 已提交
69
  virtual int32_t Configure(
70 71 72
      const PSParameter &config,
      PSEnvironment &env,
      size_t server_rank,
T
Thunderbrook 已提交
73
      const std::vector<framework::ProgramDesc> &server_sub_program = {});
T
tangwei12 已提交
74

Z
zhaocaibei123 已提交
75 76
  virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
  virtual int32_t Stop() = 0;
T
tangwei12 已提交
77

Z
zhaocaibei123 已提交
78
  inline size_t Rank() const { return _rank; }
T
tangwei12 已提交
79

Z
zhaocaibei123 已提交
80
  inline PSEnvironment *Environment() { return _environment; }
T
tangwei12 已提交
81

Z
zhaocaibei123 已提交
82 83
  inline const ServerParameter *Config() const { return &_config; }
  inline Table *GetTable(size_t table_id) {
T
tangwei12 已提交
84 85 86 87 88 89 90
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    return NULL;
  }

Z
zhaocaibei123 已提交
91
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
T
tangwei12 已提交
92 93 94
    return &_table_map;
  }

Z
zhaocaibei123 已提交
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
  // for cache
  virtual int32_t StartS2S() { return 0; }

  virtual ::std::future<int32_t> SendPServer2PServerMsg(
      int msg_type, int to_pserver_id, const std::string &msg) {
    LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
  virtual int RegistePServer2PServerMsgHandler(int msg_type,
                                               MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }
113 114
  virtual int HandlePServer2PServerMsg(int msg_type,
                                       int from_pserver_id,
Z
zhaocaibei123 已提交
115 116 117 118 119 120 121 122 123 124 125 126
                                       const std::string &msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      if (msg_type == 101) {
        return ReceiveFromPServer(msg_type, from_pserver_id, msg);
      } else {
        LOG(WARNING) << "unknown pserver2pserver_msg type:" << msg_type;
        return -1;
      }
    }
    return itr->second(msg_type, from_pserver_id, msg);
  }
127 128
  virtual int32_t ReceiveFromPServer(int msg_type,
                                     int pserver_id,
Z
zhaocaibei123 已提交
129 130 131 132 133 134 135
                                     const std::string &msg) {
    LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer";
    return -1;
  }

  paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;

T
tangwei12 已提交
136
 protected:
Z
zhaocaibei123 已提交
137
  virtual int32_t Initialize() = 0;
T
tangwei12 已提交
138 139 140 141 142 143

 protected:
  size_t _rank;
  ServerParameter _config;
  PSEnvironment *_environment;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
Z
zhaocaibei123 已提交
144
  std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
145 146 147 148

 protected:
  std::shared_ptr<framework::Scope> scope_;
  platform::Place place_ = platform::CPUPlace();
T
tangwei12 已提交
149 150
};

T
tangwei12 已提交
151
// Registers PSServer as a base type with the pscore registerer (macro from
// paddle/fluid/distributed/common/registerer.h). Presumably this lets concrete
// servers be created by class name — confirm against the macro definition.
REGISTER_PSCORE_REGISTERER(PSServer);
T
tangwei12 已提交
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176

/// Callback stored by PServerClosure. It receives an opaque pointer whose
/// meaning is defined by the concrete closure; presumably it is invoked from a
/// subclass's Run() — confirm in the subclasses.
using PServerCallBack = std::function<void(void *)>;

/// Base protobuf closure for server-side RPC completion.
///
/// Stores a callback and a list of promises that can all be fulfilled at once
/// via set_promise_value(). It does not implement Closure::Run(); concrete
/// subclasses supply that.
class PServerClosure : public google::protobuf::Closure {
 public:
  explicit PServerClosure(PServerCallBack callback)
      : _callback(std::move(callback)) {}
  virtual ~PServerClosure() {}

  /// Fulfills every registered promise with `value`.
  /// NOTE: std::promise::set_value throws std::future_error if a promise is
  /// already satisfied, so call this at most once per set of promises.
  virtual void set_promise_value(int value) {
    for (auto &promise : _promises) {
      promise->set_value(value);
    }
  }

  /// Registers a promise to be fulfilled by set_promise_value(). The closure
  /// shares ownership of the promise.
  void add_promise(const std::shared_ptr<std::promise<int32_t>> &promise) {
    _promises.push_back(promise);
  }

 protected:
  PServerCallBack _callback;
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};

class PsBaseService : public PsService {
 public:
  PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
  virtual ~PsBaseService() {}
Z
zhaocaibei123 已提交
177 178
  virtual size_t GetRank() { return _rank; }
  virtual int32_t Configure(PSServer *server) {
T
tangwei12 已提交
179
    _server = server;
Z
zhaocaibei123 已提交
180 181
    _rank = _server->Rank();
    _config = _server->Config();
T
tangwei12 已提交
182 183 184
    return 0;
  }
  virtual void service(::google::protobuf::RpcController *controller,
T
tangwei12 已提交
185 186
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
T
tangwei12 已提交
187 188
                       ::google::protobuf::Closure *done) override = 0;

189 190
  virtual void set_response_code(PsResponseMessage &response,
                                 int err_code,
T
tangwei12 已提交
191 192 193 194 195 196
                                 const char *err_msg) {
    response.set_err_msg(err_msg);
    response.set_err_code(err_code);
    LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
  }

Z
zhaocaibei123 已提交
197 198
  virtual int32_t Initialize() = 0;
  PSServer *GetServer() { return _server; }
T
tangwei12 已提交
199 200 201 202 203 204

 protected:
  size_t _rank;
  PSServer *_server;
  const ServerParameter *_config;
};
T
tangwei12 已提交
205
// Registers PsBaseService as a base type with the pscore registerer (macro
// from paddle/fluid/distributed/common/registerer.h). Presumably this lets
// concrete services be created by class name — confirm against the macro.
REGISTER_PSCORE_REGISTERER(PsBaseService);
T
tangwei12 已提交
206 207 208

// Factory for PSServer instances built from a PSParameter configuration.
class PSServerFactory {
 public:
  // Defined out of line. Which PSServer subclass is returned, and who owns
  // the returned pointer, are decided by the implementation (not visible
  // here) — check it before relying on either.
  static PSServer *Create(const PSParameter &config);
};
}  // namespace distributed
}  // namespace paddle