server.h 6.4 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
23

T
tangwei12 已提交
24 25 26
#include "butil/endpoint.h"
#include "google/protobuf/service.h"
#include "paddle/fluid/distributed/common/registerer.h"
27 28
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
29
#include "paddle/fluid/distributed/the_one_ps.pb.h"
T
tangwei12 已提交
30
#include "paddle/fluid/framework/channel.h"
31 32 33
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
34
#include "paddle/phi/core/macros.h"
35

36 37 38 39 40 41 42 43 44 45 46
// Forward declaration of google::protobuf::RpcController, which appears as a
// pointer parameter in PsBaseService::service() further down in this header.
namespace google {
namespace protobuf {
class RpcController;
}  // namespace protobuf
}  // namespace google
// Forward declaration of PSEnvironment; PSServer below only refers to it by
// reference (Configure) and by pointer (Environment(), _environment).
namespace paddle {
namespace distributed {
class PSEnvironment;
}  // namespace distributed
}  // namespace paddle

47 48 49 50 51 52 53
// Forward declarations of framework types. In this header, ProgramDesc is
// named in PSServer::Configure's signature and Scope in PSServer::scope_;
// Executor is not referenced in the portion of the file visible here.
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
}  // namespace framework
}  // namespace paddle
T
tangwei12 已提交
54 55 56 57 58

namespace paddle {
namespace distributed {

// Forward declaration; PSServer only stores and returns Table through
// pointers (std::shared_ptr<Table> / Table *).
class Table;
59

T
tangwei12 已提交
60 61
// Request/response message types used by PsBaseService::service() below.
// NOTE(review): these using-declarations name members of the namespace they
// already sit in (paddle::distributed), so they look redundant here; the
// types are presumably generated in sendrecv.pb.h -- confirm.
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
T
tangwei12 已提交
62 63 64 65 66 67 68 69

class PSServer {
 public:
  PSServer() {}
  virtual ~PSServer() {}
  PSServer(PSServer &&) = delete;
  PSServer(const PSServer &) = delete;

Z
zhaocaibei123 已提交
70
  virtual int32_t Configure(
71
      const PSParameter &config,
72
      PSEnvironment &env,  // NOLINT
73
      size_t server_rank,
T
Thunderbrook 已提交
74
      const std::vector<framework::ProgramDesc> &server_sub_program = {});
T
tangwei12 已提交
75

Z
zhaocaibei123 已提交
76 77
  virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
  virtual int32_t Stop() = 0;
T
tangwei12 已提交
78

Z
zhaocaibei123 已提交
79
  inline size_t Rank() const { return _rank; }
T
tangwei12 已提交
80

Z
zhaocaibei123 已提交
81
  inline PSEnvironment *Environment() { return _environment; }
T
tangwei12 已提交
82

Z
zhaocaibei123 已提交
83 84
  inline const ServerParameter *Config() const { return &_config; }
  inline Table *GetTable(size_t table_id) {
T
tangwei12 已提交
85 86 87 88 89 90 91
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    return NULL;
  }

Z
zhaocaibei123 已提交
92
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
T
tangwei12 已提交
93 94 95
    return &_table_map;
  }

Z
zhaocaibei123 已提交
96 97 98 99
  // for cache
  virtual int32_t StartS2S() { return 0; }

  virtual ::std::future<int32_t> SendPServer2PServerMsg(
100 101 102
      int msg_type UNUSED,
      int to_pserver_id UNUSED,
      const std::string &msg UNUSED) {
Z
zhaocaibei123 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115
    LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
  virtual int RegistePServer2PServerMsgHandler(int msg_type,
                                               MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }
116 117
  virtual int HandlePServer2PServerMsg(int msg_type,
                                       int from_pserver_id,
Z
zhaocaibei123 已提交
118 119 120 121 122 123 124 125 126 127 128 129
                                       const std::string &msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      if (msg_type == 101) {
        return ReceiveFromPServer(msg_type, from_pserver_id, msg);
      } else {
        LOG(WARNING) << "unknown pserver2pserver_msg type:" << msg_type;
        return -1;
      }
    }
    return itr->second(msg_type, from_pserver_id, msg);
  }
130 131 132
  virtual int32_t ReceiveFromPServer(int msg_type UNUSED,
                                     int pserver_id UNUSED,
                                     const std::string &msg UNUSED) {
Z
zhaocaibei123 已提交
133 134 135 136 137 138
    LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer";
    return -1;
  }

  paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;

T
tangwei12 已提交
139
 protected:
Z
zhaocaibei123 已提交
140
  virtual int32_t Initialize() = 0;
T
tangwei12 已提交
141 142 143 144 145 146

 protected:
  size_t _rank;
  ServerParameter _config;
  PSEnvironment *_environment;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
Z
zhaocaibei123 已提交
147
  std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
148 149 150 151

 protected:
  std::shared_ptr<framework::Scope> scope_;
  platform::Place place_ = platform::CPUPlace();
T
tangwei12 已提交
152 153
};

T
tangwei12 已提交
154
// NOTE(review): macro presumably provided by
// paddle/fluid/distributed/common/registerer.h; it appears to set up a
// registry for PSServer implementations -- confirm in that header.
REGISTER_PSCORE_REGISTERER(PSServer);
T
tangwei12 已提交
155 156 157 158 159

// Callback type stored by PServerClosure: takes one opaque void* argument
// and returns nothing.
typedef std::function<void(void *)> PServerCallBack;

class PServerClosure : public google::protobuf::Closure {
 public:
160
  explicit PServerClosure(PServerCallBack callback) : _callback(callback) {}
T
tangwei12 已提交
161 162 163 164 165 166
  virtual ~PServerClosure() {}
  virtual void set_promise_value(int value) {
    for (auto &promise : _promises) {
      promise->set_value(value);
    }
  }
167
  void add_promise(const std::shared_ptr<std::promise<int32_t>> &promise) {
T
tangwei12 已提交
168 169 170 171 172 173 174 175 176 177 178 179
    _promises.push_back(promise);
  }

 protected:
  PServerCallBack _callback;
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};

class PsBaseService : public PsService {
 public:
  PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
  virtual ~PsBaseService() {}
Z
zhaocaibei123 已提交
180 181
  virtual size_t GetRank() { return _rank; }
  virtual int32_t Configure(PSServer *server) {
T
tangwei12 已提交
182
    _server = server;
Z
zhaocaibei123 已提交
183 184
    _rank = _server->Rank();
    _config = _server->Config();
T
tangwei12 已提交
185 186
    return 0;
  }
187 188 189 190
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request,
               PsResponseMessage *response,
               ::google::protobuf::Closure *done) override = 0;
T
tangwei12 已提交
191

192
  virtual void set_response_code(PsResponseMessage &response,  // NOLINT
193
                                 int err_code,
T
tangwei12 已提交
194 195 196 197 198 199
                                 const char *err_msg) {
    response.set_err_msg(err_msg);
    response.set_err_code(err_code);
    LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
  }

Z
zhaocaibei123 已提交
200 201
  virtual int32_t Initialize() = 0;
  PSServer *GetServer() { return _server; }
T
tangwei12 已提交
202 203 204 205 206 207

 protected:
  size_t _rank;
  PSServer *_server;
  const ServerParameter *_config;
};
T
tangwei12 已提交
208
// NOTE(review): macro presumably provided by
// paddle/fluid/distributed/common/registerer.h; it appears to set up a
// registry for PsBaseService implementations -- confirm in that header.
REGISTER_PSCORE_REGISTERER(PsBaseService);
T
tangwei12 已提交
209 210 211

class PSServerFactory {
 public:
Z
zhaocaibei123 已提交
212
  static PSServer *Create(const PSParameter &config);
T
tangwei12 已提交
213 214 215
};
}  // namespace distributed
}  // namespace paddle