heter_comm.h 6.4 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
T
Thunderbrook 已提交
16
#include <thread>
T
Thunderbrook 已提交
17 18
#include <vector>
#include "cub/cub.cuh"
T
Thunderbrook 已提交
19
#include "cub/util_allocator.cuh"
20 21
#include "hashtable.h"       // NOLINT
#include "heter_resource.h"  // NOLINT
T
Thunderbrook 已提交
22
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
23
#include "paddle/fluid/memory/allocation/allocator.h"
T
Thunderbrook 已提交
24 25
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
26
#include "paddle/fluid/platform/dynload/nccl.h"
T
Thunderbrook 已提交
27 28 29
#include "paddle/fluid/platform/place.h"
#include "thrust/pair.h"

T
Thunderbrook 已提交
30
#ifdef PADDLE_WITH_HETERPS
T
Thunderbrook 已提交
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61

namespace paddle {
namespace framework {

// Binary functor that combines two records of type T field by field.
// T must provide the members slot, show, clk, lr_g and an mf_g sequence that
// can be indexed from 0 up to (but not including) MF_DIM.
struct CustomGradMerger {
  template <typename T>
  CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
  operator()(const T& a, const T& b) const {
    // Deliberately default-constructed rather than copied from `a`: only the
    // fields assigned below are written.
    T merged;

    // mf_g entries are summed element by element.
    for (int dim = 0; dim < MF_DIM; ++dim) {
      merged.mf_g[dim] = a.mf_g[dim] + b.mf_g[dim];
    }

    // Scalar fields are summed pairwise.
    merged.lr_g = a.lr_g + b.lr_g;
    merged.clk = a.clk + b.clk;
    merged.show = a.show + b.show;

    // The slot is taken from the left operand only; b.slot is ignored.
    merged.slot = a.slot;
    return merged;
  }
};

template <typename KeyType, typename ValType, typename GradType>
class HeterComm {
 public:
  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
  virtual ~HeterComm();
  HeterComm(const HeterComm&) = delete;
  HeterComm& operator=(const HeterComm&) = delete;

  void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len,
                            int* left, int* right, int gpu_num);
  void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
62
                  int& uniq_len);  // NOLINT
T
Thunderbrook 已提交
63 64 65 66 67 68 69 70 71
  void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
  void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
                size_t chunk_size, int stream_num);
  void dump();
  void show_one_table(int gpu_num);
  int get_index_by_devid(int devid);

  template <typename Sgd>
  void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
72
                   Sgd& sgd);  // NOLINT
T
Thunderbrook 已提交
73

74 75
  template <typename Sgd>
  void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
76
                              size_t len, Sgd& sgd);  // NOLINT
77 78 79

  template <typename Sgd>
  void update_one_table(int num, KeyType* d_keys, GradType* d_grads, size_t len,
80
                        Sgd& sgd);  // NOLINT
81 82 83 84 85 86 87

  int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads,
                           int len);

  int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads,
                             int len);

T
Thunderbrook 已提交
88
  int log2i(int x);
89 90 91 92 93 94 95 96 97

  void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                              const std::vector<ncclComm_t>& inter_comms,
                              int comm_size) {
    nccl_inner_comms_ = inner_comms;
    nccl_inter_comms_ = inter_comms;
    node_size_ = comm_size;
  }

98 99 100 101
  bool need_transfer(int send_id, int receive_id) {
    return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
  }

T
Thunderbrook 已提交
102 103 104 105
  // void dump_to_cpu(int index);

  void end_pass();

106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
  int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }

  struct Node {
    cudaStream_t in_stream;
    cudaStream_t out_stream;
    char* key_storage;
    char* val_storage;
    int sync;
    int key_bytes_len;
    int val_bytes_len;
    int gpu_num;
  };

  struct Path {
    std::vector<Node> nodes_;
  };

123 124 125 126 127 128
  struct CopyTask {
    Path* path;
    int step;
    CopyTask(Path* path_, int step_) : path(path_), step(step_) {}
  };

129 130 131 132 133 134 135 136 137 138 139
  struct LocalStorage {
    LocalStorage() {}
    void init(int size, int dev_id) {
      place_ = platform::CUDAPlace(dev_id);
      alloc(size, true);
    }

    void alloc(int size, bool force = false) {
      if (force || size > all_keys_mem->size()) {
        all_keys_mem.reset();
        all_grads_mem.reset();
140 141
        all_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
        all_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
142 143 144 145 146 147
        all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
        all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
      }
      if (force || size > local_keys_mem->size()) {
        local_keys_mem.reset();
        local_grads_mem.reset();
148 149
        local_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
        local_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
        local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
        local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
      }
    }

    platform::CUDAPlace place_;
    std::shared_ptr<memory::Allocation> all_keys_mem;
    std::shared_ptr<memory::Allocation> all_grads_mem;
    KeyType* all_keys;
    GradType* all_grads;

    std::shared_ptr<memory::Allocation> local_keys_mem;
    std::shared_ptr<memory::Allocation> local_grads_mem;
    KeyType* local_keys;
    GradType* local_grads;
  };

167
  void init_path();
T
Thunderbrook 已提交
168 169 170

  void create_storage(int start_index, int end_index, int keylen, int vallen);
  void destroy_storage(int start_index, int end_index);
171 172 173 174
  void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
                    KeyType* src_key, GradType* src_val);
  void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
                   ValType* src_val);
T
Thunderbrook 已提交
175 176 177 178 179 180 181 182

 private:
  using Table = HashTable<KeyType, ValType>;
  int block_size_{256};
  float load_factor_{0.75};
  std::vector<Table*> tables_;
  std::shared_ptr<HeterPsResource> resource_;
  CustomGradMerger merger_;
T
Thunderbrook 已提交
183
  int topo_aware_{0};
184
  std::vector<std::vector<Path>> path_;
185 186
  std::vector<LocalStorage> storage_;
  int feanum_{1800 * 2048};
T
Thunderbrook 已提交
187
  int multi_node_{0};
188 189 190
  std::vector<ncclComm_t> nccl_inner_comms_;
  std::vector<ncclComm_t> nccl_inter_comms_;
  int node_size_;
T
Thunderbrook 已提交
191
  std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
T
Thunderbrook 已提交
192 193 194 195
};

}  // end namespace framework
}  // end namespace paddle
T
Thunderbrook 已提交
196
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"
T
Thunderbrook 已提交
197
#endif