env.h 8.0 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <arpa/inet.h>
#include <glog/logging.h>
#include <netinet/in.h>
#include <stdio.h>
#include <algorithm>
#include <cstdint>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

#include "gflags/gflags.h"
T
tangwei12 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41

namespace paddle {
namespace distributed {

// One parameter-server endpoint: an address string, a port and a rank.
//
// A host can be converted to and from two wire forms:
//   * a packed 64-bit label (SerializeToUint64 / ParseFromUint64), where the
//     address is interpreted as IPv4 text by inet_addr;
//   * an "ip:port:rank" string (SerializeToString / ParseFromString).
struct PSHost {
  std::string ip;
  uint32_t port = 0;
  uint32_t rank = 0;

  PSHost() = default;
  PSHost(const std::string &ip, uint32_t port, uint32_t rank)
      : ip(ip), port(port), rank(rank) {}

  // |---ip---|---port---|--rank--|
  // |-32bit--|--20bit---|--12bit-|

  // Packs this host into the 64-bit layout shown above.
  //
  // The upper 32 bits are the value returned by inet_addr(ip) (so an address
  // inet_addr cannot parse is encoded as whatever inet_addr reports for it).
  // port and rank are masked to 20 and 12 bits respectively so that an
  // oversized rank can no longer spill into the port field.
  uint64_t SerializeToUint64() const {
    constexpr uint64_t kRankMask = (uint64_t{1} << 12) - 1;
    constexpr uint64_t kPortMask = (uint64_t{1} << 20) - 1;

    const uint64_t ip_bits = static_cast<uint32_t>(inet_addr(ip.c_str()));
    const uint64_t port_bits = static_cast<uint64_t>(port) & kPortMask;
    const uint64_t rank_bits = static_cast<uint64_t>(rank) & kRankMask;
    return (ip_bits << 32) | (port_bits << 12) | rank_bits;
  }

  // Inverse of SerializeToUint64: overwrites ip, port and rank from
  // host_label.
  void ParseFromUint64(uint64_t host_label) {
    constexpr uint64_t kRankMask = (uint64_t{1} << 12) - 1;
    constexpr uint64_t kPortMask = (uint64_t{1} << 20) - 1;

    rank = static_cast<uint32_t>(host_label & kRankMask);
    port = static_cast<uint32_t>((host_label >> 12) & kPortMask);

    // inet_ntop writes into a caller-owned buffer, unlike inet_ntoa which
    // returns a pointer to storage shared by every caller.
    in_addr addr;
    addr.s_addr = static_cast<in_addr_t>(host_label >> 32);
    char text[INET_ADDRSTRLEN] = {0};
    const char *printed = inet_ntop(AF_INET, &addr, text, sizeof(text));
    ip = (printed != nullptr) ? printed : "";
  }

  // Human-readable dump:
  // "host: <ip> port: <port> rank: <rank> uint: <packed label>".
  std::string ToString() const {
    std::string text = "host: " + ip;
    text += " port: " + std::to_string(port);
    text += " rank: " + std::to_string(rank);
    text += " uint: " + std::to_string(SerializeToUint64());
    return text;
  }

  // for open source parameter server
  // Returns "<ip>:<port>:<rank>".
  std::string SerializeToString() const {
    return ip + ":" + std::to_string(port) + ":" + std::to_string(rank);
  }

  // Parses an "<ip>:<port>:<rank>" endpoint; fields after the third are
  // ignored.
  //
  // Throws std::invalid_argument when fewer than three ':'-separated fields
  // are present; std::stoi may additionally throw std::invalid_argument or
  // std::out_of_range for a non-numeric or oversized port/rank. The object is
  // left unchanged when any of these is thrown.
  void ParseFromString(const std::string &endpoint) {
    std::vector<std::string> endpoint_info;
    StringSplit(endpoint, ':', &endpoint_info);
    if (endpoint_info.size() < 3) {
      throw std::invalid_argument(
          "PSHost endpoint must have the form ip:port:rank, got: " + endpoint);
    }
    // Convert before assigning so a failed conversion cannot leave the host
    // half-updated.
    const uint32_t parsed_port =
        static_cast<uint32_t>(std::stoi(endpoint_info[1]));
    const uint32_t parsed_rank =
        static_cast<uint32_t>(std::stoi(endpoint_info[2]));
    ip = endpoint_info[0];
    port = parsed_port;
    rank = parsed_rank;
  }

  // Splits str on sep into *pieces (which is cleared first).
  //
  // Empty fields between two separators are kept; an empty field after the
  // last separator is dropped. ignore_null only affects an empty str: when
  // false, a single empty piece is produced instead of none.
  static void StringSplit(const std::string &str, char sep,
                          std::vector<std::string> *pieces,
                          bool ignore_null = true) {
    pieces->clear();
    if (str.empty()) {
      if (!ignore_null) {
        pieces->push_back(str);
      }
      return;
    }
    size_t pos = 0;
    for (size_t next = str.find(sep, pos); next != std::string::npos;
         next = str.find(sep, pos)) {
      pieces->push_back(str.substr(pos, next - pos));
      pos = next + 1;
    }
    if (pos < str.size()) {
      pieces->push_back(str.substr(pos));
    }
  }
};

class PSEnvironment {
 public:
111
  explicit PSEnvironment() {}  // NOLINT
T
tangwei12 已提交
112 113
  virtual ~PSEnvironment() {}

Z
zhaocaibei123 已提交
114
  virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
T
tangwei12 已提交
115 116
    return 0;
  }
Z
zhaocaibei123 已提交
117
  virtual int32_t SetPsServers(
T
tangwei12 已提交
118 119 120 121
      const std::vector<std::string> *host_endpoint_list, int node_num) {
    return 0;
  }

Z
zhaocaibei123 已提交
122
  virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
T
tangwei12 已提交
123 124 125
    return 0;
  }

Z
zhaocaibei123 已提交
126
  virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
T
tangwei12 已提交
127 128
    return 0;
  }
Z
zhaocaibei123 已提交
129 130 131 132 133
  virtual uint64_t GetLocalHostSign() { return 0; }
  virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
  virtual int32_t RegistePsServer(const std::string &ip, uint32_t port,
                                  int32_t rank) {
    return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
T
tangwei12 已提交
134 135
  }

Z
zhaocaibei123 已提交
136 137 138 139
  virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
  virtual int32_t RegistePsClient(const std::string &ip, uint32_t port,
                                  int32_t rank) {
    return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
T
tangwei12 已提交
140 141
  }

Z
zhaocaibei123 已提交
142
  virtual std::vector<uint64_t> GetClientInfo() {
T
tangwei12 已提交
143
    std::vector<uint64_t> client_info;
144
    for (auto &i : _ps_client_list) {
Z
zhaocaibei123 已提交
145
      client_info.push_back(i.SerializeToUint64());
T
tangwei12 已提交
146 147 148 149
    }
    return client_info;
  }

Z
zhaocaibei123 已提交
150
  virtual std::vector<std::string> GetClientInfo(bool use_string_endpoint) {
T
tangwei12 已提交
151 152 153
    if (use_string_endpoint) {
      std::vector<std::string> client_info;
      for (auto &i : _ps_client_list) {
Z
zhaocaibei123 已提交
154
        client_info.push_back(i.SerializeToString());
T
tangwei12 已提交
155 156 157 158 159 160
      }
      return client_info;
    }
    return {};
  }

Z
zhaocaibei123 已提交
161
  virtual void SetTrainers(int trainers) { trainers_ = trainers; }
T
tangwei12 已提交
162

Z
zhaocaibei123 已提交
163
  virtual int GetTrainers() { return trainers_; }
T
tangwei12 已提交
164

T
tangwei12 已提交
165
 protected:
166
  //注册一个host //  NOLINT
Z
zhaocaibei123 已提交
167
  virtual int32_t RegistePsHost(
168 169 170
      const std::string &ip, uint32_t port, int32_t rank,
      std::vector<PSHost> &host_list,            // NOLINT
      std::unordered_set<uint64_t> &sign_set) {  // NOLINT
T
tangwei12 已提交
171 172 173 174
    PSHost host;
    host.ip = ip;
    host.port = port;
    host.rank = rank;
T
tangwei12 已提交
175 176

    if (sign_set.count(rank) == 0) {
T
tangwei12 已提交
177 178 179
      host_list.push_back(host);
      sign_set.insert(rank);
    }
T
tangwei12 已提交
180

T
tangwei12 已提交
181 182 183
    return 0;
  }

T
tangwei12 已提交
184 185
  int trainers_ = 0;

T
tangwei12 已提交
186 187 188 189 190 191 192 193 194
  std::vector<PSHost> _ps_client_list;
  std::unordered_set<uint64_t> _ps_client_sign_set;  // for unique filter

  std::vector<PSHost> _ps_server_list;
  std::unordered_set<uint64_t> _ps_server_sign_set;  // for unique filter
};

class PaddlePSEnvironment : public PSEnvironment {
 public:
195
  explicit PaddlePSEnvironment() {}  // NOLINT
T
tangwei12 已提交
196 197
  virtual ~PaddlePSEnvironment() {}

Z
zhaocaibei123 已提交
198
  virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
T
tangwei12 已提交
199 200 201 202 203
    _ps_server_list.clear();
    _ps_server_sign_set.clear();
    for (int i = 0; i < node_num; ++i) {
      if (host_sign_list[i] > 0) {
        PSHost host;
Z
zhaocaibei123 已提交
204
        host.ParseFromUint64(host_sign_list[i]);
T
tangwei12 已提交
205
        _ps_server_list.push_back(host);
Z
zhaocaibei123 已提交
206
        _ps_server_sign_set.insert(host.SerializeToUint64());
T
tangwei12 已提交
207 208 209 210 211 212 213 214
      }
    }
    std::sort(
        _ps_server_list.begin(), _ps_server_list.end(),
        [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
    return 0;
  }

Z
zhaocaibei123 已提交
215 216
  virtual int32_t SetPsServers(const std::vector<std::string> *host_sign_list,
                               int node_num) {
T
tangwei12 已提交
217 218 219 220 221
    _ps_server_list.clear();
    _ps_server_sign_set.clear();
    for (int i = 0; i < node_num; ++i) {
      if (host_sign_list->at(i) != "") {
        PSHost host;
Z
zhaocaibei123 已提交
222
        host.ParseFromString(host_sign_list->at(i));
T
tangwei12 已提交
223 224 225 226 227 228 229 230 231 232
        _ps_server_list.push_back(host);
        _ps_server_sign_set.insert(host.rank);
      }
    }
    std::sort(
        _ps_server_list.begin(), _ps_server_list.end(),
        [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
    return 0;
  }

Z
zhaocaibei123 已提交
233
  virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
T
tangwei12 已提交
234 235 236 237 238
    _ps_client_list.clear();
    _ps_client_sign_set.clear();
    for (int i = 0; i < node_num; ++i) {
      if (host_sign_list[i] > 0) {
        PSHost host;
Z
zhaocaibei123 已提交
239
        host.ParseFromUint64(host_sign_list[i]);
T
tangwei12 已提交
240
        _ps_client_list.push_back(host);
Z
zhaocaibei123 已提交
241
        _ps_client_sign_set.insert(host.SerializeToUint64());
T
tangwei12 已提交
242 243 244 245 246 247 248 249
      }
    }
    std::sort(
        _ps_client_list.begin(), _ps_client_list.end(),
        [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
    return 0;
  }

Z
zhaocaibei123 已提交
250 251
  virtual int32_t SetPsClients(const std::vector<std::string> *host_sign_list,
                               int node_num) {
T
tangwei12 已提交
252 253 254 255 256
    _ps_client_list.clear();
    _ps_client_sign_set.clear();
    for (int i = 0; i < node_num; ++i) {
      if (host_sign_list->at(i) != "") {
        PSHost host;
Z
zhaocaibei123 已提交
257
        host.ParseFromString(host_sign_list->at(i));
T
tangwei12 已提交
258 259 260 261 262 263 264
        _ps_client_list.push_back(host);
        _ps_client_sign_set.insert(host.rank);
      }
    }
    std::sort(
        _ps_client_list.begin(), _ps_client_list.end(),
        [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
265
    VLOG(1) << "env.set_ps_clients done\n";
T
tangwei12 已提交
266 267 268
    return 0;
  }

Z
zhaocaibei123 已提交
269
  virtual uint64_t GetLocalHostSign() {
T
tangwei12 已提交
270
    if (_ps_client_list.size() > 0) {
Z
zhaocaibei123 已提交
271
      return _ps_client_list[0].SerializeToUint64();
T
tangwei12 已提交
272 273 274 275 276 277 278 279
    } else {
      return 0;
    }
  }
};

}  // namespace distributed
}  // namespace paddle