heter_comm.h
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <thread>
#include <vector>

#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/timer.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
#include <xpu/runtime.h>

#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif

#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"

#ifdef PADDLE_WITH_HETERPS

namespace paddle {
namespace framework {

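// Round LEN up to the next multiple of ALIGNVAL (ALIGNVAL must be a power of
// two).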
#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))

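// HeterComm is the device-side communication layer of the heterogeneous
// parameter server (HeterPS): it shards keys across devices, keeps one hash
// table per device, and moves keys, values and gradients between devices for
// the pull_sparse / push_sparse paths.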
template <typename KeyType,
          typename ValType,
          typename GradType,
          typename GPUAccessor>
class HeterComm {
 public:
  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
  HeterComm(size_t capacity,
            std::shared_ptr<HeterPsResource> resource,
            GPUAccessor& gpu_accessor);
  virtual ~HeterComm();
  HeterComm(const HeterComm&) = delete;
  HeterComm& operator=(const HeterComm&) = delete;

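  // Device-side helpers that de-duplicate keys and merge (or segment-merge)
  // gradients before they are applied to the tables.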
  void merge_keys(int gpu_num,
                  const KeyType* d_keys,
                  size_t len,
                  KeyType* d_sorted_keys,
                  KeyType* d_merged_keys,
                  uint32_t* d_restore_idx,
                  size_t& uniq_len);
  void dynamic_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          float* d_grads,
                          size_t len,
                          int& uniq_len,
                          size_t& segment_len,
                          bool enable_segment_merge_grad);
  void segment_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          float* d_grads,
                          const uint32_t* d_index,
                          size_t len,
                          const uint32_t* d_fea_num_info,
                          size_t uniq_len,
                          size_t& segment_len);
  void build_ps(int num,
                KeyType* h_keys,
                ValType* h_vals,
                size_t len,
                size_t chunk_size,
                int stream_num,
                int offset = -1);
  void split_input_to_shard(KeyType* d_keys,
                            int* d_idx_ptr,
                            size_t len,
                            int* left,
                            int* right,
                            int gpu_num);
  void merge_grad(int gpu_num,
                  KeyType* d_keys,
                  GradType* d_grads,
                  size_t len,
                  int& uniq_len);  // NOLINT
  void dynamic_merge_grad(
      int gpu_num, KeyType* d_keys, float* d_grads, size_t len, int& uniq_len);
  void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);
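  // Pool form of build_ps: values live in a raw memory pool whose entries are
  // feature_value_size bytes each.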
  void build_ps(int num,
                KeyType* h_keys,
                char* pool,
                size_t len,
                size_t feature_value_size,
                size_t chunk_size,
                int stream_num);
  void dump();
  void show_one_table(int gpu_num);
  void show_table_collisions();
  int get_index_by_devid(int devid);

#if defined(PADDLE_WITH_CUDA)
  template <typename Sgd>
  void push_sparse(int num,
                   KeyType* d_keys,
                   float* d_grads,
                   size_t len,
                   Sgd& sgd);  // NOLINT
#elif defined(PADDLE_WITH_XPU_KP)
  void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len);
#endif

  void set_sparse_sgd(const OptimizerConfig& optimizer_config);
  void set_embedx_sgd(const OptimizerConfig& optimizer_config);

  int log2i(int x);

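  // Copy `count` bytes from src to dst, dispatching on the source and
  // destination places and (optionally) an asynchronous stream.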
  template <typename DstPlace, typename SrcPlace, typename StreamType>
  void memory_copy(DstPlace dst_place,
                   void* dst,
                   SrcPlace src_place,
                   const void* src,
                   size_t count,
                   StreamType stream = 0);

#if defined(PADDLE_WITH_CUDA)
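  // Multi-node push: gradients are gathered within a node and then across
  // nodes (gather_one_node_grad / gather_multi_node_grad) before
  // update_one_table applies them to the local table.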
  template <typename Sgd>
  void push_sparse_multi_node(int num,
                              KeyType* d_keys,
                              GradType* d_grads,
                              size_t len,
                              Sgd& sgd);  // NOLINT

  template <typename Sgd>
  void update_one_table(int num,
                        KeyType* d_keys,
                        GradType* d_grads,
                        size_t len,
                        Sgd& sgd);  // NOLINT

  int gather_one_node_grad(int num,
                           KeyType* d_keys,
                           GradType* d_grads,
                           int len);

  int gather_multi_node_grad(int num,
                             KeyType* d_keys,
                             GradType* d_grads,
                             int len);

  void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                              const std::vector<ncclComm_t>& inter_comms,
                              int comm_size) {
    nccl_inner_comms_ = inner_comms;
    nccl_inter_comms_ = inter_comms;
    node_size_ = comm_size;
  }

  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
    multi_mf_dim_ = multi_mf_dim;
    max_mf_dim_ = max_mf_dim;
  }

#endif

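  // Topology helpers for an 8-GPU node split into two groups of four: a hop
  // through the paired GPU ((send_id + 4) % 8) is needed when source and
  // destination sit in different groups and are not directly paired.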
  bool need_transfer(int send_id, int receive_id) {
    return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
  }

  // void dump_to_cpu(int index);

  int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }

  void end_pass();
#if defined(PADDLE_WITH_CUDA)
  // dedup
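  // Sorts d_keys, writes the unique keys to d_merged_keys, and fills the
  // restore/sorted index and per-key count buffers.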
  int dedup_keys_and_fillidx(const int gpu_id,
                             const int total_fea_num,
                             const KeyType* d_keys,   // input
                             KeyType* d_merged_keys,  // output
                             KeyType* d_sorted_keys,
                             uint32_t* d_restore_idx,
                             uint32_t* d_sorted_idx,
                             uint32_t* d_offset,
                             uint32_t* d_merged_cnts,
                             bool filter_zero);
#endif

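  // A Node is one hop of a transfer path between two devices; a Path chains
  // the hops and a CopyTask tracks progress of a copy along a Path.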
  struct Node {
    ppStream in_stream;
    ppStream out_stream;
    char* key_storage;
    char* val_storage;
    int sync;
    size_t key_bytes_len;
    size_t val_bytes_len;
    int dev_num;
  };

  struct Path {
    std::vector<Node> nodes_;
  };

  struct CopyTask {
    Path* path;
    int step;
    CopyTask(Path* path_, int step_) : path(path_), step(step_) {}
  };

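  // Per-device scratch buffers used to stage keys and gradients during
  // communication; alloc() grows the buffers on demand.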
  struct LocalStorage {
    LocalStorage() {}
    void init(int size, int dev_id) {
      place_ = platform::CUDAPlace(dev_id);
      alloc(size, true);
    }

    void alloc(size_t size, bool force = false) {
      if (force || size > all_keys_mem->size()) {
        all_keys_mem.reset();
        all_grads_mem.reset();
        all_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
        all_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
        all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
        all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
      }
      if (force || size > local_keys_mem->size()) {
        local_keys_mem.reset();
        local_grads_mem.reset();
        local_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
        local_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
        local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
        local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
      }
    }

#if defined(PADDLE_WITH_CUDA)
    platform::CUDAPlace place_;

#elif defined(PADDLE_WITH_XPU_KP)
    platform::XPUPlace place_;
#endif
    std::shared_ptr<memory::Allocation> all_keys_mem;
    std::shared_ptr<memory::Allocation> all_grads_mem;

    KeyType* all_keys;
    GradType* all_grads;

    std::shared_ptr<memory::Allocation> local_keys_mem;
    std::shared_ptr<memory::Allocation> local_grads_mem;
    KeyType* local_keys;
    GradType* local_grads;
  };

  void init_path();

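  // Thin wrappers that dispatch stream creation, destruction and
  // synchronization to the CUDA or XPU runtime.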
  template <typename StreamType>
  void sync_stream(const StreamType& stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(stream));
#endif
  }

  template <typename StreamType>
  void create_stream(StreamType* stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(stream));
#endif
  }

  template <typename StreamType>
  void destroy_stream(StreamType stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(stream));
#endif
  }

  void create_storage(int start_index,
                      int end_index,
                      size_t keylen,
                      size_t vallen);
  void destroy_storage(int start_index, int end_index);
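  // Move staged data along the transfer paths: walk_to_dest pushes keys and
  // values toward the destination devices, walk_to_src copies results back.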
  void walk_to_dest(int start_index,
                    int gpu_num,
                    int* h_left,
                    int* h_right,
                    KeyType* src_key,
                    GradType* src_val);
  void walk_to_dest(int start_index,
                    int gpu_num,
                    int* h_left,
                    int* h_right,
                    KeyType* src_key,
                    char* src_val,
                    size_t val_size);
  void walk_to_src(int start_index,
                   int gpu_num,
                   int* h_left,
                   int* h_right,
                   ValType* src_val);
  void walk_to_src(int start_index,
                   int gpu_num,
                   int* h_left,
                   int* h_right,
                   char* src_val,
                   size_t val_size);

 protected:
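  // pull_merge_sparse de-duplicates keys before the table lookup;
  // pull_normal_sparse looks the keys up directly.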
  void pull_merge_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);
  void pull_normal_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);

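  // Table stores ValType values inline; PtrTable maps keys to raw float
  // pointers into an externally managed value buffer.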
  using Table = HashTable<KeyType, ValType>;
  using PtrTable = HashTable<KeyType, float*>;
  std::vector<Table*> tables_;
  std::vector<PtrTable*> ptr_tables_;
  std::shared_ptr<HeterPsResource> resource_;
  std::vector<std::vector<Path>> path_;
  float load_factor_{0.75};
  int block_size_{256};
  std::unique_ptr<HeterCommKernel> heter_comm_kernel_;

  GPUAccessor gpu_accessor_;

 private:
  int topo_aware_{0};
  std::vector<LocalStorage> storage_;
  DynamicGradMerger merger_;
  int feanum_{1800 * 2048};
  int multi_node_{0};
  int node_size_;

#if defined(PADDLE_WITH_CUDA)
  std::vector<ncclComm_t> nccl_inner_comms_;
  std::vector<ncclComm_t> nccl_inter_comms_;
  int multi_mf_dim_{8};
  int max_mf_dim_ = 8;
  std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
#endif
};

}  // end namespace framework
}  // end namespace paddle

#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"

#endif