/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <thread>
#include <vector>

#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"

#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/timer.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
// #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include <xpu/runtime.h>

#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif

#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"

#ifdef PADDLE_WITH_HETERPS

namespace paddle {
namespace framework {

// Round LEN up to the next multiple of ALIGNVAL (ALIGNVAL must be a power of
// two); the arithmetic is carried out in 64-bit unsigned math.
#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((uint64_t)(ALIGNVAL)-1)) & ~((uint64_t)(ALIGNVAL)-1))

template <typename KeyType, typename ValType, typename GradType>
class HeterComm {
 public:
  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
  virtual ~HeterComm();
  HeterComm(const HeterComm&) = delete;
  HeterComm& operator=(const HeterComm&) = delete;

57 58 59 60 61 62 63 64 65 66
  void split_input_to_shard(KeyType* d_keys,
                            int* d_idx_ptr,
                            size_t len,
                            int* left,
                            int* right,
                            int gpu_num);
  void merge_grad(int gpu_num,
                  KeyType* d_keys,
                  GradType* d_grads,
                  size_t len,
67
                  int& uniq_len);  // NOLINT
68 69 70 71 72
  void dynamic_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          GradType* d_grads,
                          size_t len,
                          int& uniq_len);
T
Thunderbrook 已提交
73
  void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
74 75 76 77 78 79 80 81 82 83 84 85 86
  void build_ps(int num,
                KeyType* h_keys,
                ValType* h_vals,
                size_t len,
                size_t chunk_size,
                int stream_num);
  void build_ps(int num,
                KeyType* h_keys,
                char* pool,
                size_t len,
                size_t feature_value_size,
                size_t chunk_size,
                int stream_num);
T
Thunderbrook 已提交
87 88 89 90
  void dump();
  void show_one_table(int gpu_num);
  int get_index_by_devid(int devid);

91
#if defined(PADDLE_WITH_CUDA)
T
Thunderbrook 已提交
92
  template <typename Sgd>
93 94 95 96
  void push_sparse(int num,
                   KeyType* d_keys,
                   GradType* d_grads,
                   size_t len,
97
                   Sgd& sgd);  // NOLINT
98 99 100 101
#elif defined(PADDLE_WITH_XPU_KP)
  void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len);
#endif

102 103 104
  void set_sparse_sgd(const OptimizerConfig& optimizer_config);
  void set_embedx_sgd(const OptimizerConfig& optimizer_config);

105
  int log2i(int x);
T
Thunderbrook 已提交
106

107
  template <typename DstPlace, typename SrcPlace, typename StreamType>
108 109 110 111 112 113
  void memory_copy(DstPlace dst_place,
                   void* dst,
                   SrcPlace src_place,
                   const void* src,
                   size_t count,
                   StreamType stream = 0);
114 115

#if defined(PADDLE_WITH_CUDA)
116
  template <typename Sgd>
117 118 119 120 121
  void push_sparse_multi_node(int num,
                              KeyType* d_keys,
                              GradType* d_grads,
                              size_t len,
                              Sgd& sgd);  // NOLINT
122 123

  template <typename Sgd>
124 125 126 127
  void update_one_table(int num,
                        KeyType* d_keys,
                        GradType* d_grads,
                        size_t len,
128
                        Sgd& sgd);  // NOLINT
129

130 131 132
  int gather_one_node_grad(int num,
                           KeyType* d_keys,
                           GradType* d_grads,
133 134
                           int len);

135 136 137
  int gather_multi_node_grad(int num,
                             KeyType* d_keys,
                             GradType* d_grads,
138 139 140 141 142 143 144 145 146
                             int len);

  void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                              const std::vector<ncclComm_t>& inter_comms,
                              int comm_size) {
    nccl_inner_comms_ = inner_comms;
    nccl_inter_comms_ = inter_comms;
    node_size_ = comm_size;
  }
Y
yaoxuefeng 已提交
147 148 149 150 151

  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
    multi_mf_dim_ = multi_mf_dim;
    max_mf_dim_ = max_mf_dim;
  }
152
#endif
153

154 155 156 157
  bool need_transfer(int send_id, int receive_id) {
    return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
  }

T
Thunderbrook 已提交
158 159
  // void dump_to_cpu(int index);

160 161
  int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }

162 163
  void end_pass();

164
  struct Node {
165 166
    ppStream in_stream;
    ppStream out_stream;
167 168 169
    char* key_storage;
    char* val_storage;
    int sync;
Y
yaoxuefeng 已提交
170 171
    size_t key_bytes_len;
    size_t val_bytes_len;
172
    int dev_num;
173 174 175 176 177 178
  };

  struct Path {
    std::vector<Node> nodes_;
  };

179 180 181 182 183 184
  struct CopyTask {
    Path* path;
    int step;
    CopyTask(Path* path_, int step_) : path(path_), step(step_) {}
  };

185 186 187 188 189 190 191
  struct LocalStorage {
    LocalStorage() {}
    void init(int size, int dev_id) {
      place_ = platform::CUDAPlace(dev_id);
      alloc(size, true);
    }

192
    void alloc(size_t size, bool force = false) {
193 194 195
      if (force || size > all_keys_mem->size()) {
        all_keys_mem.reset();
        all_grads_mem.reset();
196 197
        all_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
        all_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
198 199 200 201 202 203
        all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
        all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
      }
      if (force || size > local_keys_mem->size()) {
        local_keys_mem.reset();
        local_grads_mem.reset();
204 205
        local_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
        local_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
206 207 208 209 210
        local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
        local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
      }
    }

211
#if defined(PADDLE_WITH_CUDA)
212
    platform::CUDAPlace place_;
F
Fan Zhang 已提交
213

214 215 216
#elif defined(PADDLE_WITH_XPU_KP)
    platform::XPUPlace place_;
#endif
217 218
    std::shared_ptr<memory::Allocation> all_keys_mem;
    std::shared_ptr<memory::Allocation> all_grads_mem;
F
Fan Zhang 已提交
219

220 221 222 223 224 225 226 227 228
    KeyType* all_keys;
    GradType* all_grads;

    std::shared_ptr<memory::Allocation> local_keys_mem;
    std::shared_ptr<memory::Allocation> local_grads_mem;
    KeyType* local_keys;
    GradType* local_grads;
  };

229
  void init_path();
T
Thunderbrook 已提交
230

231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
  template <typename StreamType>
  void sync_stream(const StreamType& stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(stream));
#endif
  }

  template <typename StreamType>
  void create_stream(StreamType* stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(stream));
#endif
  }

  template <typename StreamType>
  void destroy_stream(StreamType stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(stream));
#endif
  }

T
Thunderbrook 已提交
258 259
  void create_storage(int start_index, int end_index, int keylen, int vallen);
  void destroy_storage(int start_index, int end_index);
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
  void walk_to_dest(int start_index,
                    int gpu_num,
                    int* h_left,
                    int* h_right,
                    KeyType* src_key,
                    GradType* src_val);
  void walk_to_dest(int start_index,
                    int gpu_num,
                    int* h_left,
                    int* h_right,
                    KeyType* src_key,
                    char* src_val,
                    size_t val_size);
  void walk_to_src(int start_index,
                   int gpu_num,
                   int* h_left,
                   int* h_right,
277
                   ValType* src_val);
278 279 280 281 282 283
  void walk_to_src(int start_index,
                   int gpu_num,
                   int* h_left,
                   int* h_right,
                   char* src_val,
                   size_t val_size);
T
Thunderbrook 已提交
284

S
seemingwang 已提交
285
 protected:
T
Thunderbrook 已提交
286
  using Table = HashTable<KeyType, ValType>;
Y
yaoxuefeng 已提交
287
  using PtrTable = HashTable<KeyType, ValType*>;
T
Thunderbrook 已提交
288
  std::vector<Table*> tables_;
Y
yaoxuefeng 已提交
289
  std::vector<PtrTable*> ptr_tables_;
T
Thunderbrook 已提交
290
  std::shared_ptr<HeterPsResource> resource_;
291
  std::vector<std::vector<Path>> path_;
S
seemingwang 已提交
292 293
  float load_factor_{0.75};
  int block_size_{256};
S
seemingwang 已提交
294
  std::unique_ptr<HeterCommKernel> heter_comm_kernel_;
S
seemingwang 已提交
295 296

 private:
S
seemingwang 已提交
297
  int topo_aware_{0};
298
  std::vector<LocalStorage> storage_;
Y
yaoxuefeng 已提交
299
  DynamicGradMerger merger_;
300
  int feanum_{1800 * 2048};
T
Thunderbrook 已提交
301
  int multi_node_{0};
302 303 304
  int node_size_;

#if defined(PADDLE_WITH_CUDA)
305 306
  std::vector<ncclComm_t> nccl_inner_comms_;
  std::vector<ncclComm_t> nccl_inter_comms_;
Y
yaoxuefeng 已提交
307 308
  int multi_mf_dim_{8};
  int max_mf_dim_ = 8;
T
Thunderbrook 已提交
309
  std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
310
#endif
T
Thunderbrook 已提交
311 312 313 314
};

}  // end namespace framework
}  // end namespace paddle

#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"

#endif