// rpc_server.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#pragma once
#include <atomic>
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <glog/logging.h>

#include "paddle/fluid/framework/archive.h"

namespace paddle {
namespace framework {
class GlooWrapper;
}
namespace distributed {
namespace simple {
// Archive type used for every serialized RPC payload in this namespace.
using BinaryArchive = paddle::framework::BinaryArchive;

// Forward declarations: RpcMessageHead below only needs pointers to these.
class RpcRequest;
class RpcService;

// Header attached to every RPC message.
//
// Each field has a default member initializer, so a head declared without an
// initializer starts with null pointers and zero ids instead of indeterminate
// values. Field order and types are unchanged. The struct stays trivially
// copyable, and under C++14 and later it remains an aggregate, so existing
// brace-initialization and byte-wise copies keep working.
struct RpcMessageHead {
  // Target service; NOTE(review): presumably an address obtained via
  // RpcService::remote_pointer and valid on the receiving side — confirm.
  RpcService *service = nullptr;
  // Request object associated with this message (holds the callback).
  RpcRequest *request = nullptr;
  int client_id = 0;  // presumably the sender's rank — confirm
  int server_id = 0;  // presumably the receiver's rank — confirm
  // Distinguishes a request from a response.
  enum { REQUEST, RESPONSE } message_type = REQUEST;
  int consumer_id = 0;
};

// Callback signature stored by both RpcService and RpcRequest. It receives
// the message header and an archive (presumably holding the message
// payload — confirm against the server implementation).
using RpcCallback =
    std::function<void(const RpcMessageHead &, BinaryArchive &)>;  // NOLINT

// Service endpoint known to an RpcServer.
//
// Stores a callback (NOTE(review): presumably run for messages addressed to
// this service — confirm in the server implementation), a per-rank table of
// peer service pointers, and an atomic request counter adjusted through
// increase_request()/decrease_request().
class RpcService {
 public:
  RpcService() = default;
  explicit RpcService(RpcCallback callback);
  ~RpcService();

  // Returns the pointer stored for `rank` in the remote-pointer table.
  // Access is bounds-checked: a negative or too-large rank throws
  // std::out_of_range instead of reading outside the vector (which was
  // undefined behavior with operator[]).
  RpcService *remote_pointer(int rank) {
    return _remote_ptrs.at(static_cast<size_t>(rank));
  }

  // Mutable access to the stored callback.
  RpcCallback &callback() { return _callback; }

  // Atomic increment/decrement of the request counter.
  void increase_request() { ++_request_counter; }
  void decrease_request() { --_request_counter; }

 protected:
  // Indexed by rank; filled in outside this header.
  std::vector<RpcService *> _remote_ptrs;
  RpcCallback _callback;
  std::atomic<int> _request_counter{0};
};

class RpcRequest {
 public:
  explicit RpcRequest(RpcCallback callback) : _callback(std::move(callback)) {}
  RpcCallback &callback() { return _callback; }

 protected:
  RpcCallback _callback;
};

// Abstract transport for the simple RPC layer.
//
// Concrete servers implement initialize()/finalize() and the three raw send
// primitives. This base class stores connection and threading configuration
// and provides single-archive convenience wrappers around the primitives.
// NOTE(review): how each configuration field is consumed is up to the
// concrete implementation; the values below are only the initial defaults.
class RpcServer {
 public:
  RpcServer();
  virtual ~RpcServer();

 public:
  // Configuration setters. set_connection_num()/set_thread_num() are defined
  // out of line (presumably they update _conn_num/_thread_num — confirm).
  // The remaining setters store their argument verbatim.
  void set_connection_num(int n);
  void set_thread_num(int n);
  void set_connection_idle_timeout_sec(int timeout_sec) {
    _connection_idle_timeout_sec = timeout_sec;
  }
  void set_max_retry(int retry_cnt) { _max_retry = retry_cnt; }
  void set_connect_timeout_ms(int timeout_ms) {
    _connect_timeout_ms = timeout_ms;
  }
  void set_connection_type(const std::string &conn_type) {
    _connection_type = conn_type;
  }
  void set_client_timeout_ms(int timeout_ms) {
    _client_timeout_ms = timeout_ms;
  }

 public:
  // Transport primitives implemented by concrete servers. `oars` is
  // presumably an array of `n` archives forming the payload; the wrappers
  // below always pass a single archive with n == 1.
  virtual void initialize() = 0;
  virtual void finalize() = 0;
  virtual void send_request(int server_id,
                            void *service_,
                            const size_t n,
                            BinaryArchive *oars,
                            RpcCallback callback) = 0;
  virtual void send_response(RpcMessageHead head,
                             const size_t n,
                             BinaryArchive *oars) = 0;
  virtual void send_request_ex(int server_id,
                               int consumer_id,
                               void *service_,
                               const size_t n,
                               BinaryArchive *oars,
                               RpcCallback callback) = 0;

 public:
  // Defined out of line. The returned void* is presumably the opaque handle
  // later passed as `service` to the send functions — confirm.
  // NOTE: default arguments on virtual functions bind to the static type of
  // the call expression, so `simplex = true` applies to any call made
  // through RpcServer, regardless of an override's own default.
  virtual void *add_service(RpcCallback callback, bool simplex = true);
  virtual void remove_service(void *service);

 public:
  // Single-archive convenience overloads (forward with n == 1).
  // NOTE: a derived class that overrides send_response(head, n, oars) hides
  // the two-argument send_response below unless it adds
  // `using RpcServer::send_response;`.
  void send_request_wrapper(int server_id,
                            void *service,
                            BinaryArchive &oar,  // NOLINT
                            RpcCallback callback) {
    send_request(server_id, service, 1, &oar, std::move(callback));
  }
  void send_request_consumer(int server_id,
                             int consumer_id,
                             void *service,
                             BinaryArchive &oar,  // NOLINT
                             RpcCallback callback) {
    send_request_ex(
        server_id, consumer_id, service, 1, &oar, std::move(callback));
  }
  void send_response(RpcMessageHead head, BinaryArchive &oar) {  // NOLINT
    send_response(head, 1, &oar);
  }

 protected:
  int _conn_num = 1;
  int _thread_num = 10;
  std::vector<uint32_t> _ips;
  paddle::framework::GlooWrapper *_gloo = nullptr;
  // RPC configuration (initial values; see the setters above).
  // NOTE(review): -1 presumably means "use the transport default" — confirm.
  int _connection_idle_timeout_sec = 3600;
  int _max_retry = 1000;
  int _connect_timeout_ms = -1;
  std::string _connection_type = "pooled";
  int _client_timeout_ms = -1;
};

// Returns the RpcServer instance exposed through this accessor; the
// definition lives outside this header (function declarations are
// implicitly extern).
RpcServer &global_rpc_server();
}  // namespace simple
}  // namespace distributed
}  // namespace paddle
#endif