// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"

#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/server.h"
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
class GraphBrpcServer : public PSServer {
 public:
  GraphBrpcServer() {}
  virtual ~GraphBrpcServer() {}
  PsBaseService *get_service() { return _service.get(); }
  virtual uint64_t start(const std::string &ip, uint32_t port);
S
seemingwang 已提交
35 36
  virtual int32_t build_peer2peer_connection(int rank);
  virtual brpc::Channel *get_cmd_channel(size_t server_index);
S
seemingwang 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
  virtual int32_t stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_) return 0;
    stoped_ = true;
    // cv_.notify_all();
    _server.Stop(1000);
    _server.Join();
    return 0;
  }
  virtual int32_t port();

  std::condition_variable *export_cv() { return &cv_; }

 private:
  virtual int32_t initialize();
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool stoped_ = false;
S
seemingwang 已提交
55
  int rank;
S
seemingwang 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};

class GraphBrpcService;

typedef int32_t (GraphBrpcService::*serviceFunc)(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl);

class GraphBrpcService : public PsBaseService {
 public:
  virtual int32_t initialize() override;

  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override;

 protected:
  std::unordered_map<int32_t, serviceFunc> _service_handler_map;
  int32_t initialize_shard_info();
  int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
                          PsResponseMessage &response, brpc::Controller *cntl);
81 82 83 84
  int32_t graph_random_sample_neighbors(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl);
S
seemingwang 已提交
85 86 87 88
  int32_t graph_random_sample_nodes(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl);
S
seemingwang 已提交
89

S
seemingwang 已提交
90 91 92
  int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
S
seemingwang 已提交
93 94 95
  int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
96 97 98 99 100 101 102
  int32_t clear_nodes(Table *table, const PsRequestMessage &request,
                      PsResponseMessage &response, brpc::Controller *cntl);
  int32_t add_graph_node(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
                            PsResponseMessage &response,
                            brpc::Controller *cntl);
S
seemingwang 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
  int32_t barrier(Table *table, const PsRequestMessage &request,
                  PsResponseMessage &response, brpc::Controller *cntl);
  int32_t load_one_table(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  int32_t load_all_table(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  int32_t stop_server(Table *table, const PsRequestMessage &request,
                      PsResponseMessage &response, brpc::Controller *cntl);
  int32_t start_profiler(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  int32_t stop_profiler(Table *table, const PsRequestMessage &request,
                        PsResponseMessage &response, brpc::Controller *cntl);

  int32_t print_table_stat(Table *table, const PsRequestMessage &request,
                           PsResponseMessage &response, brpc::Controller *cntl);

119 120 121 122 123 124 125 126 127
  int32_t sample_neighbors_across_multi_servers(Table *table,
                                                const PsRequestMessage &request,
                                                PsResponseMessage &response,
                                                brpc::Controller *cntl);

  int32_t use_neighbors_sample_cache(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl);
S
seemingwang 已提交
128

129 130 131 132
  int32_t load_graph_split_config(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl);

S
seemingwang 已提交
133 134 135 136 137 138
 private:
  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
  const int sample_nodes_ranges = 23;
S
seemingwang 已提交
139 140
  size_t server_size;
  std::shared_ptr<::ThreadPool> task_pool;
S
seemingwang 已提交
141 142 143 144
};

}  // namespace distributed
}  // namespace paddle