server.h 4.6 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "butil/endpoint.h"
#include "google/protobuf/service.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
27 28
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
T
tangwei12 已提交
29
#include "paddle/fluid/framework/channel.h"
30 31 32 33
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"

34 35 36 37 38 39 40 41 42 43 44
// Forward declarations for types this header only uses by pointer or
// reference (RpcController in PsBaseService::service(), PSEnvironment in
// PSServer::Configure()/Environment()).
namespace google {
namespace protobuf {
class RpcController;
}  // namespace protobuf
}  // namespace google
namespace paddle {
namespace distributed {
// NOTE(review): presumably also declared by the included ps/service/env.h;
// this redeclaration is harmless.
class PSEnvironment;
}  // namespace distributed
}  // namespace paddle

45 46 47 48 49 50 51
// Forward declarations of framework types referenced by this header.
// ProgramDesc appears in PSServer::Configure(); Scope is held via
// std::shared_ptr. NOTE(review): Executor is not referenced in the visible
// part of this header, and Scope is also provided by the included scope.h.
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
}  // namespace framework
}  // namespace paddle
T
tangwei12 已提交
52 53 54 55 56

namespace paddle {
namespace distributed {

// Opaque table type: this header only stores tables via std::shared_ptr and
// hands out raw Table pointers, so the complete type is not needed here.
class Table;
57

T
tangwei12 已提交
58 59
// Request/response message types used by PsBaseService::service() below;
// presumably generated into sendrecv.pb.h.
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
T
tangwei12 已提交
60 61 62 63 64 65 66 67

class PSServer {
 public:
  PSServer() {}
  virtual ~PSServer() {}
  PSServer(PSServer &&) = delete;
  PSServer(const PSServer &) = delete;

Z
zhaocaibei123 已提交
68
  virtual int32_t Configure(
69
      const PSParameter &config, PSEnvironment &env, size_t server_rank,
T
Thunderbrook 已提交
70
      const std::vector<framework::ProgramDesc> &server_sub_program = {});
T
tangwei12 已提交
71

Z
zhaocaibei123 已提交
72 73
  virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
  virtual int32_t Stop() = 0;
T
tangwei12 已提交
74

Z
zhaocaibei123 已提交
75
  inline size_t Rank() const { return _rank; }
T
tangwei12 已提交
76

Z
zhaocaibei123 已提交
77
  inline PSEnvironment *Environment() { return _environment; }
T
tangwei12 已提交
78

Z
zhaocaibei123 已提交
79 80
  inline const ServerParameter *Config() const { return &_config; }
  inline Table *GetTable(size_t table_id) {
T
tangwei12 已提交
81 82 83 84 85 86 87
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    return NULL;
  }

Z
zhaocaibei123 已提交
88
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
T
tangwei12 已提交
89 90 91 92
    return &_table_map;
  }

 protected:
Z
zhaocaibei123 已提交
93
  virtual int32_t Initialize() = 0;
T
tangwei12 已提交
94 95 96 97 98 99

 protected:
  size_t _rank;
  ServerParameter _config;
  PSEnvironment *_environment;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
100 101 102 103

 protected:
  std::shared_ptr<framework::Scope> scope_;
  platform::Place place_ = platform::CPUPlace();
T
tangwei12 已提交
104 105
};

T
tangwei12 已提交
106
// Registers PSServer with the pscore registerer machinery. The macro is
// defined in common/registerer.h; presumably it lets concrete servers be
// looked up and created by class name.
REGISTER_PSCORE_REGISTERER(PSServer);
T
tangwei12 已提交
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131

// Callback type stored by PServerClosure. It receives a single untyped
// pointer whose meaning is defined by whoever invokes it.
using PServerCallBack = std::function<void(void *)>;

/// Base class for server-side protobuf closures.
///
/// Holds a callback for subclasses to invoke, plus a list of promises that
/// set_promise_value() fulfils together. Run() is not overridden here.
class PServerClosure : public google::protobuf::Closure {
 public:
  /// @param callback Callback kept in `_callback` for subclasses; it is moved
  ///                 in rather than copied.
  explicit PServerClosure(PServerCallBack callback)
      : _callback(std::move(callback)) {}
  virtual ~PServerClosure() = default;

  /// Fulfils every registered promise with `value`.
  /// Note: std::promise::set_value throws std::future_error if a promise has
  /// already been satisfied, so this is expected to be called once.
  virtual void set_promise_value(int value) {
    for (auto &promise : _promises) {
      promise->set_value(value);
    }
  }

  /// Registers a promise to be fulfilled by set_promise_value().
  /// Taken by value (sink) so both lvalues and temporaries are accepted.
  void add_promise(std::shared_ptr<std::promise<int32_t>> promise) {
    _promises.push_back(std::move(promise));
  }

 protected:
  PServerCallBack _callback;
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};

/// Base class of the PsService RPC implementations.
///
/// Configure() caches the owning PSServer together with its rank and config.
/// Request dispatch (service()) and Initialize() are left to subclasses.
class PsBaseService : public PsService {
 public:
  PsBaseService() : _rank(0), _server(nullptr), _config(nullptr) {}
  virtual ~PsBaseService() = default;

  virtual size_t GetRank() { return _rank; }

  /// Binds this service to `server` and caches its rank and configuration.
  /// @param server Owning server, stored as a non-owning raw pointer.
  /// @return 0 on success, -1 if `server` is null (previously a null
  ///         dereference).
  virtual int32_t Configure(PSServer *server) {
    if (server == nullptr) {
      return -1;
    }
    _server = server;
    _rank = _server->Rank();
    _config = _server->Config();
    return 0;
  }

  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override = 0;

  /// Writes `err_code` / `err_msg` into `response` and logs a warning.
  /// A null `err_msg` is treated as an empty message, because building a
  /// std::string from (or streaming) a null char pointer is undefined behavior.
  virtual void set_response_code(PsResponseMessage &response, int err_code,
                                 const char *err_msg) {
    const char *msg = (err_msg != nullptr) ? err_msg : "";
    response.set_err_msg(msg);
    response.set_err_code(err_code);
    LOG(WARNING) << "Response err_code:" << err_code << " msg:" << msg;
  }

  virtual int32_t Initialize() = 0;
  PSServer *GetServer() { return _server; }

 protected:
  size_t _rank;
  PSServer *_server;                // not owned
  const ServerParameter *_config;   // points into *_server
};
T
tangwei12 已提交
159
// Registers PsBaseService with the pscore registerer machinery (macro from
// common/registerer.h); presumably lets concrete services be created by name.
REGISTER_PSCORE_REGISTERER(PsBaseService);
T
tangwei12 已提交
160 161 162

class PSServerFactory {
 public:
Z
zhaocaibei123 已提交
163
  static PSServer *Create(const PSParameter &config);
T
tangwei12 已提交
164 165 166
};
}  // namespace distributed
}  // namespace paddle