// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <thrust/host_vector.h>
#include <chrono>
#include <condition_variable>
#include <map>
#include <memory>
#include <mutex>
#include <vector>
#include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
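// Graph table for the GPU parameter server. It keeps one GpuPsCommGraph shard
// per device (gpu_graph_list) and reuses the HeterComm<int64_t, int, int>
// machinery (int64_t node ids as keys) to route sampling requests and results
// across devices; a CPU-side GraphTable (cpu_graph_table) can back the
// cpu_switch / cpu_query_switch sampling paths declared below.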
class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
 public:
  GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
      : HeterComm<int64_t, int, int>(1, resource) {
    load_factor_ = 0.25;
    rw_lock.reset(new pthread_rwlock_t());
    gpu_num = resource_->total_device();
    cpu_table_status = -1;
    if (topo_aware) {
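      // Override the default communication paths with topology-aware ones.
      // Note: the relay selection below assumes an 8-GPU node in which device
      // d and device (d + 4) % 8 belong to different interconnect groups;
      // when need_transfer(from, to) holds, traffic is staged through one of
      // those peer devices instead of going directly.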
      int total_gpu = resource_->total_device();
      std::map<int, int> device_map;
      for (int i = 0; i < total_gpu; i++) {
        device_map[resource_->dev_id(i)] = i;
        VLOG(1) << " device " << resource_->dev_id(i) << " is stored on " << i;
      }
      path_.clear();
      path_.resize(total_gpu);
      VLOG(1) << "topo aware override";
      for (int i = 0; i < total_gpu; ++i) {
        path_[i].resize(total_gpu);
        for (int j = 0; j < total_gpu; ++j) {
          auto &nodes = path_[i][j].nodes_;
          nodes.clear();
          int from = resource_->dev_id(i);
          int to = resource_->dev_id(j);
          int transfer_id = i;
          if (need_transfer(from, to) &&
              (device_map.find((from + 4) % 8) != device_map.end() ||
               device_map.find((to + 4) % 8) != device_map.end())) {
            transfer_id = (device_map.find((from + 4) % 8) != device_map.end())
                              ? ((from + 4) % 8)
                              : ((to + 4) % 8);
            transfer_id = device_map[transfer_id];
            nodes.push_back(Node());
            Node &node = nodes.back();
            node.in_stream = resource_->comm_stream(i, transfer_id);
            node.out_stream = resource_->comm_stream(transfer_id, i);
            node.key_storage = NULL;
            node.val_storage = NULL;
            node.sync = 0;
            node.dev_num = transfer_id;
          }
          nodes.push_back(Node());
          Node &node = nodes.back();
          node.in_stream = resource_->comm_stream(i, transfer_id);
          node.out_stream = resource_->comm_stream(transfer_id, i);
          node.key_storage = NULL;
          node.val_storage = NULL;
          node.sync = 0;
          node.dev_num = j;
        }
      }
    }
  }
  ~GpuPsGraphTable() {
    // if (cpu_table_status != -1) {
    //   end_graph_sampling();
    // }
  }
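  // Illustrative call sequence (a sketch only; identifiers other than the
  // methods declared below are placeholders, and key buffers are assumed to
  // live on the device selected by gpu_id):
  //
  //   GpuPsGraphTable table(resource, /*topo_aware=*/0);
  //   table.init_cpu_table(graph_parameter);
  //   table.build_graph_from_cpu(per_gpu_graphs);  // one GpuPsCommGraph per GPU
  //   auto res = table.graph_neighbor_sample(/*gpu_id=*/0, d_keys,
  //                                          /*sample_size=*/10, /*len=*/n);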
  void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
  NodeQueryResult graph_node_sample(int gpu_id, int sample_size);
  NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
                                                bool cpu_switch);
  NeighborSampleResult graph_neighbor_sample(int gpu_id, int64_t *key,
                                             int sample_size, int len);
  NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int64_t *key,
                                                int sample_size, int len,
                                                bool cpu_query_switch);
  NodeQueryResult query_node_list(int gpu_id, int start, int query_size);
  void clear_graph_info();
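  // Copies per-device neighbor-sampling results back to the querying GPU.
  // Following the HeterComm convention, h_left[i] / h_right[i] bound the index
  // range of keys that were routed to device i, and actual_sample_size records
  // how many neighbors were actually returned for each key.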
  void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
                                                 int sample_size, int *h_left,
                                                 int *h_right,
                                                 int64_t *src_sample_res,
                                                 int *actual_sample_size);
  // void move_neighbor_sample_result_to_source_gpu(
  //     int gpu_id, int gpu_num, int *h_left, int *h_right,
  //     int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size);
  // void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num,
  //                                              int *h_left, int *h_right,
  //                                              int *actual_sample_size,
  //                                              int *total_sample_size);
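  // Initializes the CPU-side GraphTable from the given GraphParameter; this
  // table backs the cpu_switch / cpu_query_switch paths of the samplers above.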
  int init_cpu_table(const paddle::distributed::GraphParameter &graph);
  // int load(const std::string &path, const std::string &param);
  // virtual int32_t end_graph_sampling() {
  //   return cpu_graph_table->end_graph_sampling();
  // }
  int gpu_num;
  std::vector<GpuPsCommGraph> gpu_graph_list;
  std::vector<int *> sample_status;
  const int parallel_sample_size = 1;
  const int dim_y = 256;
  std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
  std::shared_ptr<pthread_rwlock_t> rw_lock;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  int cpu_table_status;
};
}  // namespace framework
}  // namespace paddle
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h"
#endif  // PADDLE_WITH_HETERPS