/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <vector>
#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/timer.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
#include <xpu/runtime.h>

#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif

#include "paddle/fluid/framework/barrier.h"
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"

#ifdef PADDLE_WITH_HETERPS

namespace paddle {
namespace framework {

// Round LEN up to the next multiple of ALIGNVAL (a power of two), e.g.
// TYPEALIGN(8, 13) == 16.
#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))

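// HeterComm shards a HashTable-backed sparse parameter store across the
// devices in a HeterPsResource and moves keys, values, and gradients between
// them: build_ps loads shards, pull_sparse/push_sparse serve lookups and
// updates, and the NCCL paths extend both across nodes.
//
// A minimal usage sketch (h_keys/h_vals/d_keys/d_vals/d_grads, `sgd`, and the
// accessor type are hypothetical caller-side names, not defined here):
//
//   HeterComm<uint64_t, float, float, GPUAccessor> comm(capacity, resource);
//   comm.build_ps(0, h_keys, h_vals, len, 500000, 2);  // load shard of GPU 0
//   comm.pull_sparse(0, d_keys, d_vals, len);          // gather values
//   comm.push_sparse(0, d_keys, d_grads, len, sgd);    // apply gradients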
template <typename KeyType,
          typename ValType,
          typename GradType,
          typename GPUAccessor>
class HeterComm {
  using HeterCommType = HeterComm<KeyType, ValType, GradType, GPUAccessor>;
  static const int COPY_KEY = 0x01;
  static const int COPY_VAL = 0x02;
  static const int COPY_ALL = COPY_KEY | COPY_VAL;

 public:
  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
  HeterComm(size_t capacity,
            std::shared_ptr<HeterPsResource> resource,
            GPUAccessor& gpu_accessor);  // NOLINT
  virtual ~HeterComm();
  HeterComm(const HeterComm&) = delete;
  HeterComm& operator=(const HeterComm&) = delete;
  // reset table
  void reset_table(const int dev_id,
                   size_t capacity,
                   const OptimizerConfig& sgd_config,
                   const OptimizerConfig& embedx_config,
                   bool infer_mode);
  void set_mode(bool infer_mode);
  template <typename StreamType>
  size_t merge_keys(const int gpu_num,
                    const KeyType* d_keys,
                    const size_t& len,
                    KeyType* d_sorted_keys,
                    KeyType* d_merged_keys,
                    uint32_t* d_restore_idx,
                    StreamType stream);
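  // Merge gradients that share a key; with enable_segment_merge_grad, long
  // duplicate runs are first reduced segment by segment (segment_merge_grad)
  // before the final per-key merge.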
  void dynamic_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          float* d_grads,
                          size_t len,
                          int& uniq_len,        // NOLINT
                          size_t& segment_len,  // NOLINT
                          bool enable_segment_merge_grad);
  void segment_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          float* d_grads,
                          const uint32_t* d_index,
                          size_t len,
                          const uint32_t* d_fea_num_info,
                          size_t uniq_len,
                          size_t& segment_len);  // NOLINT
  void build_ps(int num,
                KeyType* h_keys,
                ValType* h_vals,
                size_t len,
                size_t chunk_size,
                int stream_num,
                int offset = -1);
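  // Partition d_keys by destination GPU: d_idx_ptr receives the permutation,
  // left/right the begin/end index of each device's range.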
  void split_input_to_shard(KeyType* d_keys,
                            int* d_idx_ptr,
                            size_t len,
                            int* left,
                            int* right,
                            int gpu_num);
  void merge_grad(int gpu_num,
                  KeyType* d_keys,
                  GradType* d_grads,
                  size_t len,
                  int& uniq_len);  // NOLINT
  void dynamic_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          float* d_grads,
                          size_t len,
                          int& uniq_len);  // NOLINT
  void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);
  void build_ps(int num,
                KeyType* h_keys,
                char* pool,
                size_t len,
                size_t feature_value_size,
                size_t chunk_size,
                int stream_num);
  void dump();
  void show_one_table(int gpu_num);
  void show_table_collisions();
  int get_index_by_devid(int devid);

#if defined(PADDLE_WITH_CUDA)
  template <typename Sgd>
  void push_sparse(int num,
                   KeyType* d_keys,
                   float* d_grads,
                   size_t len,
                   Sgd& sgd);  // NOLINT
#elif defined(PADDLE_WITH_XPU_KP)
  void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len);
#endif

  void set_sparse_sgd(const OptimizerConfig& optimizer_config);
  void set_embedx_sgd(const OptimizerConfig& optimizer_config);

  int log2i(int x);

  template <typename DstPlace, typename SrcPlace, typename StreamType>
  void memory_copy(DstPlace dst_place,
                   void* dst,
                   SrcPlace src_place,
                   const void* src,
                   size_t count,
                   StreamType stream = 0);

#if defined(PADDLE_WITH_CUDA)
  template <typename Sgd>
  void push_sparse_multi_node(int num,
                              KeyType* d_keys,
                              GradType* d_grads,
                              size_t len,
                              Sgd& sgd);  // NOLINT

  template <typename Sgd>
  void update_one_table(int num,
                        KeyType* d_keys,
                        GradType* d_grads,
                        size_t len,
                        Sgd& sgd);  // NOLINT

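  // Gather and merge gradients across the GPUs of one node, or across nodes,
  // via NCCL; each returns the length of the gathered buffer.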
  int gather_one_node_grad(int num,
                           KeyType* d_keys,
                           GradType* d_grads,
                           int len);

  int gather_multi_node_grad(int num,
                             KeyType* d_keys,
                             GradType* d_grads,
                             int len);

  void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                              const std::vector<ncclComm_t>& inter_comms,
                              int comm_size,
                              int rank_id) {
    nccl_inner_comms_ = inner_comms;
    nccl_inter_comms_ = inter_comms;
    node_size_ = comm_size;
    rank_id_ = rank_id;
  }

  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
    multi_mf_dim_ = multi_mf_dim;
    max_mf_dim_ = max_mf_dim;
  }

#endif

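  // Two-hop routing: devices are treated as groups of four (e.g. two NVLink
  // islands on an 8-GPU host). A copy whose endpoints sit in different groups
  // and whose target is not the direct peer (send_id + 4) % device_num_ is
  // relayed through that peer; see get_transfer_devid below.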
  bool need_transfer(int send_id, int receive_id) {
    return ((send_id / 4 != receive_id / 4) &&
            (send_id + 4) % device_num_ != receive_id);
  }

  // void dump_to_cpu(int index);

  int get_transfer_devid(int send_id) { return (send_id + 4) % device_num_; }

  void end_pass();
#if defined(PADDLE_WITH_CUDA)
  // Dedup: merge duplicate keys, fill the sorted/restore/offset/count index
  // mappings, and return the number of unique keys.
  int dedup_keys_and_fillidx(const int gpu_id,
                             const int total_fea_num,
                             const KeyType* d_keys,   // input
                             KeyType* d_merged_keys,  // output
                             KeyType* d_sorted_keys,
                             uint32_t* d_restore_idx,
                             uint32_t* d_sorted_idx,
                             uint32_t* d_offset,
                             uint32_t* d_merged_cnts,
                             bool filter_zero,
                             cudaStream_t stream = 0);
#endif
  template <typename T, typename StreamType>
  void split_idx_to_shard(KeyType* d_keys,
                          T* d_idx_ptr,
                          size_t len,
                          T* left,
                          T* right,
                          int gpu_num,
                          StreamType stream);

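  // One hop of an inter-device copy path: staging buffers plus the streams
  // that move data in and out. A Path chains the hops from a source GPU to a
  // destination, and a CopyTask tracks progress of one step along a Path.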
  struct Node {
    ppStream in_stream;
    ppStream out_stream;
    char* key_storage;
    char* val_storage;
    int sync;
    size_t key_bytes_len;
    size_t val_bytes_len;
    int dev_num;
  };

  struct Path {
    std::vector<Node> nodes_;
  };

  struct CopyTask {
    Path* path;
    int step;
    CopyTask(Path* path_, int step_) : path(path_), step(step_) {}
  };
  // Intra-node (inner card) all2all resources
  struct InnerResource {
    uint32_t* d_idx = nullptr;
    size_t* h_part_sizes = nullptr;
    std::vector<size_t> h_offsets;
    uint32_t* d_offset_ptr = nullptr;

    KeyType* d_keys_parted = nullptr;
    char* d_vals_parted = nullptr;
    std::vector<KeyType*> d_remote_keys;
    std::vector<char*> d_remote_vals;
    KeyType* d_trans_keys = nullptr;
    char* d_trans_vals = nullptr;

    // Resize per-GPU bookkeeping vectors
    void resize(const int num_gpu) {
      h_offsets.resize(num_gpu);
      d_remote_keys.resize(num_gpu);
      d_remote_vals.resize(num_gpu);
    }
  };
  // Resources for partitioning shard keys across nodes
  struct ShardResource {
    uint32_t* d_local_idx_parted = nullptr;  // uint32_t for multisplit
    std::vector<size_t> h_local_part_sizes;
    std::vector<size_t> h_local_part_offsets;
    std::vector<size_t> h_remote_part_sizes;
    std::vector<size_t> h_remote_part_offsets;
    uint32_t* d_node_size_ptr = nullptr;
    std::vector<uint32_t> h_push_fea_sizes;
    // Grow per-node partition bookkeeping
    void resize_part_size(const int node_size) {
      if (h_local_part_sizes.size() >= static_cast<size_t>(node_size)) {
        return;
      }
      h_local_part_sizes.resize(node_size);
      h_local_part_offsets.resize(node_size + 1);
      h_remote_part_sizes.resize(node_size);
      h_remote_part_offsets.resize(node_size + 1);
      h_push_fea_sizes.resize(node_size * node_size);
    }
  };
  // Resources for pulls: partition shard keys by device
  struct PullResource {
    size_t h_recv_fea_num = 0;
    uint32_t* d_restore_keys_idx = nullptr;
  };

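  // Per-device scratch space for the pull/push pipelines: lazily grown key
  // and gradient buffers plus the inner/shard/pull resources and timers used
  // by the all2all paths.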
  struct LocalStorage {
    LocalStorage() { sem_wait = std::make_unique<Semaphore>(); }
    void init(int device_num, int dev_id, phi::Stream stream) {
      place_ = platform::CUDAPlace(dev_id);
      h_recv_offsets.resize(device_num);
      h_fea_sizes.resize(device_num);
      stream_ = stream;
    }
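    // Return a cached device buffer with room for `len` elements of T,
    // growing it only when needed; with need_copy the old contents survive a
    // reallocation.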
    template <typename T>
    T* alloc_cache(const size_t& len,
                   std::shared_ptr<memory::Allocation>& alloc,  // NOLINT
                   bool need_copy = false) {
      size_t need_mem = len * sizeof(T);
      if (alloc.get() == nullptr) {
        alloc = memory::Alloc(place_, need_mem, stream_);
      } else if (need_mem > alloc->size()) {
        if (need_copy) {
          std::shared_ptr<memory::Allocation> tmp =
              memory::Alloc(place_, need_mem, stream_);
#if defined(PADDLE_WITH_CUDA)
          PADDLE_ENFORCE_GPU_SUCCESS(
              cudaMemcpyAsync(tmp->ptr(),  // output
                              alloc->ptr(),
                              alloc->size(),
                              cudaMemcpyDeviceToDevice,
                              reinterpret_cast<cudaStream_t>(stream_.id())));
#else
          memory::Copy(place_,
                       tmp->ptr(),
                       place_,
                       alloc->ptr(),
                       alloc->size(),
                       reinterpret_cast<void*>(stream_.id()));
#endif
          alloc.reset();
          alloc = tmp;
        } else {
          alloc.reset();
          alloc = memory::Alloc(place_, need_mem, stream_);
        }
      }
      return reinterpret_cast<T*>(alloc->ptr());
    }
    void alloc(const size_t& len,
               const size_t& value_bytes = sizeof(GradType),
               const int copy_mode = 0) {
      all_keys =
          alloc_cache<KeyType>(len, all_keys_mem, (copy_mode & COPY_KEY));
      all_grads = alloc_cache<char>(
          len * value_bytes, all_grads_mem, (copy_mode & COPY_VAL));
      local_keys =
          alloc_cache<KeyType>(len, local_keys_mem, (copy_mode & COPY_KEY));
      local_grads = alloc_cache<char>(
          len * value_bytes, local_grads_mem, (copy_mode & COPY_VAL));
      d_merged_keys = all_keys;
      d_merged_push_keys = local_keys;
      d_merged_vals = all_grads;
      d_merged_push_vals = local_grads;
    }
    void check(const size_t& len,
               const size_t& value_bytes = sizeof(GradType)) {
      CHECK_GE(all_keys_mem->size(), len * sizeof(KeyType));
      CHECK_GE(all_grads_mem->size(), len * value_bytes);
    }
    void init_pull(const size_t& len) {
      pull_res.h_recv_fea_num = len;
      pull_res.d_restore_keys_idx = alloc_cache<uint32_t>(len, local_pull_idx);
    }
    void init_shard(const size_t& len, const size_t& node_size) {
      shard_res.d_local_idx_parted =
          alloc_cache<uint32_t>(len, local_shard_idx);
      shard_res.d_node_size_ptr =
          alloc_cache<uint32_t>(node_size * node_size, d_node_size_buf);
      shard_res.resize_part_size(node_size);
    }
    void init_inner(const size_t& len, const int& device_num) {
      inner_res.d_idx = alloc_cache<uint32_t>(len, local_inner_idx);
      inner_res.d_offset_ptr =
          alloc_cache<uint32_t>(device_num * 2, inner_offset);
      inner_res.resize(device_num);
    }
    void init_trans(const size_t& fea_num, const size_t& value_bytes) {
      d_merged_trans_keys = alloc_cache<KeyType>(fea_num * 2, trans_keys_buff);
      d_merged_push_trans_keys = &d_merged_trans_keys[fea_num];
      d_merged_trans_vals =
          alloc_cache<char>(fea_num * 2 * value_bytes, trans_vals_buff);
      d_merged_push_trans_vals = &d_merged_trans_vals[fea_num * value_bytes];
    }

#if defined(PADDLE_WITH_CUDA)
    platform::CUDAPlace place_;

#elif defined(PADDLE_WITH_XPU_KP)
    platform::XPUPlace place_;
#endif
    phi::Stream stream_;
    std::shared_ptr<memory::Allocation> all_keys_mem = nullptr;
    std::shared_ptr<memory::Allocation> all_grads_mem = nullptr;

    KeyType* all_keys;
    char* all_grads;

    std::shared_ptr<memory::Allocation> local_keys_mem = nullptr;
    std::shared_ptr<memory::Allocation> local_grads_mem = nullptr;
    KeyType* local_keys;
    char* local_grads;

    // all2all
    std::shared_ptr<memory::Allocation> local_inner_idx = nullptr;
    std::shared_ptr<memory::Allocation> local_pull_idx = nullptr;
    std::shared_ptr<memory::Allocation> local_shard_idx = nullptr;
    std::shared_ptr<memory::Allocation> inner_offset = nullptr;
    std::shared_ptr<memory::Allocation> d_node_size_buf = nullptr;

    InnerResource inner_res;
    ShardResource shard_res;
    PullResource pull_res;

    KeyType* d_merged_keys = nullptr;
    char* d_merged_vals = nullptr;
    KeyType* d_merged_push_keys = nullptr;
    char* d_merged_push_vals = nullptr;
    std::vector<size_t> h_recv_offsets;
    std::vector<size_t> h_fea_sizes;
    // inner trans comm and stream buffer
    size_t h_trans_size;
    size_t h_trans_offset;

    // node trans comm and stream buffer
    std::unique_ptr<Semaphore> sem_wait;
    std::shared_ptr<memory::Allocation> trans_keys_buff = nullptr;
    std::shared_ptr<memory::Allocation> trans_vals_buff = nullptr;
    KeyType* d_merged_trans_keys = nullptr;
    char* d_merged_trans_vals = nullptr;
    KeyType* d_merged_push_trans_keys = nullptr;
    char* d_merged_push_trans_vals = nullptr;

    platform::Timer all2all_span_;
    platform::Timer inner_span_;
    platform::Timer inner_barrier_;
    platform::Timer node_span_;
    platform::Timer node_barrier_;
  };

  void init_path();

  template <typename StreamType>
  void sync_stream(const StreamType& stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(stream));
#endif
  }

  template <typename StreamType>
  void create_stream(StreamType* stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(stream));
#endif
  }

  template <typename StreamType>
  void destroy_stream(StreamType stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(stream));
#endif
  }

  void create_storage(int start_index,
                      int end_index,
                      size_t keylen,
                      size_t vallen);
  void create_tmp_storage(void*& dest,  // NOLINT
                          int start_index,
                          int end_index,
                          size_t vallen);
  void destroy_tmp_storage(void*& p, int start_index, int end_index);  // NOLINT
  void destroy_storage(int start_index, int end_index);
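  // Stage data hop by hop along path_: walk_to_dest pushes keys/values toward
  // each destination GPU and walk_to_src copies results back along the
  // reverse path.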
  void walk_to_dest(int start_index,
                    int gpu_num,
                    int* h_left,
                    int* h_right,
                    KeyType* src_key,
                    GradType* src_val);
  void walk_to_dest(int start_index,
                    int gpu_num,
                    int* h_left,
                    int* h_right,
                    KeyType* src_key,
                    char* src_val,
                    size_t val_size);
  void walk_to_src(int start_index,
                   int gpu_num,
                   int* h_left,
                   int* h_right,
                   ValType* src_val);
  void walk_to_src(int start_index,
                   int gpu_num,
                   int* h_left,
                   int* h_right,
                   char* src_val,
                   size_t val_size);

 protected:
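  // Broadly, the all2all pull path is: shard keys across local GPUs, gather
  // them per node, exchange across nodes with NCCL all2all, look up locally,
  // then scatter values back along the reverse route; the push path mirrors
  // it with merged gradients.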
  void pull_merge_sparse(const int gpu_id,
                         KeyType* d_keys,
                         float* d_vals,
                         size_t len);
  void pull_normal_sparse(const int gpu_id,
                          KeyType* d_keys,
                          float* d_vals,
                          size_t len);
  void pull_one_table(const int gpu_id,
                      KeyType* d_keys,
                      float* d_vals,
                      const size_t& len,
                      const cudaStream_t& stream);

  // node all2all pull
  void pull_sparse_all2all(const int& gpu_id,
                           KeyType* d_keys,
                           float* d_vals,
                           const size_t& len);

  template <typename Sgd>
  void push_normal_sparse(int num,
                          KeyType* d_keys,
                          float* d_grads,
                          size_t len,
                          Sgd& sgd);  // NOLINT

  void shard_inner_keys(const size_t& total_fea_num,
                        const KeyType* d_keys,
                        const int& gpu_id,
                        const int& gpu_num,
                        InnerResource* res,
                        const cudaStream_t& stream);
  void gather_inner_keys_p2p(const size_t& total_fea_num,
                             const KeyType* d_keys,
                             InnerResource& res,  // NOLINT
                             const int& gpu_id,
                             const int& gpu_num,
                             const int& trans_id,
                             const cudaStream_t& stream);
  size_t gather_inter_keys_by_copy(const int& gpu_id,
                                   const size_t& fea_size,
                                   const KeyType* d_keys,
                                   const cudaStream_t& stream);
  void partition_shard_keys(const int& gpu_id,
                            const size_t& total_fea_num,
                            const KeyType* d_keys,
                            uint32_t* d_idx_parted,
                            KeyType* d_keys_parted,
                            size_t* h_part_sizes,
                            const int& shard_num,
                            const cudaStream_t& stream);
  size_t send_data_by_all2all(const int& gpu_id,
                              const int& nccl_node_size,
                              const int& nccl_rank_id,
                              const int& value_bytes,
                              const size_t* h_send_part_sizes,
                              const size_t* h_send_part_offsets,
                              const size_t* h_recv_part_sizes,
                              const size_t* h_recv_part_offsets,
                              const char* d_send_buff,
                              char* d_rev_buff,
                              const cudaStream_t& stream);
  size_t gather_sparse_keys_by_all2all(const int& gpu_id,
                                       const size_t& fea_size,
                                       const KeyType* d_in_keys,
                                       const cudaStream_t& stream);
  void scatter_sparse_vals_by_all2all(const int& gpu_id,
                                      const size_t& fea_size,
                                      const char* d_in_vals,
                                      void* d_out_vals,
                                      const size_t& value_bytes,
                                      void* d_tmp_vals,
                                      const cudaStream_t& stream);
  void scatter_inner_vals_p2p(const size_t& total_fea_num,
                              void* d_out_vals,
                              InnerResource& res,  // NOLINT
                              const int& gpu_id,
                              const int& gpu_num,
                              const int& trans_id,
                              const size_t& value_bytes,
                              const cudaStream_t& stream);
  void scatter_inter_vals_by_copy(const int& gpu_id,
                                  const size_t& fea_size,
                                  const char* d_in_vals,
                                  void* d_out_vals,
                                  const size_t& value_bytes,
                                  const cudaStream_t& stream);
  void gather_inner_data_p2p(const size_t& total_fea_num,
                             const KeyType* d_keys,
                             const void* d_vals,
                             InnerResource& res,  // NOLINT
                             const int& gpu_id,
                             const int& gpu_num,
                             const int& trans_id,
                             const size_t& value_bytes,
                             const cudaStream_t& stream);
  template <typename Sgd>
  void push_sparse_all2all(const int& gpu_id,
                           KeyType* d_keys,
                           float* d_grads,
                           const size_t& len,
                           Sgd& sgd);  // NOLINT
  size_t merge_grad(const int& gpu_id,
                    const size_t& len,
                    const KeyType* d_in_keys,
                    KeyType* d_out_keys,
                    const void* d_in_grads,
                    void* d_out_grads,
                    const cudaStream_t& stream);
  size_t gather_inter_gradient_by_copy(const int& gpu_id,
                                       const size_t& push_size,
                                       KeyType* d_keys,
                                       void* d_push_vals,
                                       const size_t& value_bytes,
                                       const cudaStream_t& stream);
  size_t gather_sparse_gradient_by_all2all(const int& gpu_id,
                                           const size_t& push_size,
                                           const KeyType* d_keys,
                                           const char* d_push_vals,
                                           const size_t& value_bytes,
                                           KeyType* d_out_keys,
                                           KeyType* d_tmp_keys,
                                           char* d_out_vals,
                                           char* d_tmp_vals,
                                           const cudaStream_t& stream);
  size_t send_keys_by_all2all_trans(const int& gpu_id,
                                    const int& rank_id,
                                    const int& node_size,
                                    const size_t& fea_size,
                                    const KeyType* d_in_keys,
                                    KeyType* d_out_keys,
                                    const cudaStream_t& stream);
  size_t send_vals_by_all2all_trans(const int& gpu_id,
                                    const int& rank_id,
                                    const int& node_size,
                                    const char* d_in_vals,
                                    char* d_out_vals,
                                    const size_t& value_bytes,
                                    const cudaStream_t& stream);
  size_t send_gradient_by_all2all_trans(const int& gpu_id,
                                        const int& rank_id,
                                        const int& node_size,
                                        const size_t& fea_size,
                                        const KeyType* d_keys,
                                        const char* d_push_vals,
                                        const size_t& value_bytes,
                                        KeyType* d_out_keys,
                                        char* d_out_vals,
                                        const cudaStream_t& stream);
  // debug time
  void print_debug_time(const int& gpu_id, bool force = false);
  // alloc temp memory
  template <typename T, typename TPlace, typename StreamType>
  T* AllocCache(std::shared_ptr<memory::Allocation>* alloc,
                const TPlace& place,
                const size_t& byte_len,
                const StreamType& stream) {
    if (alloc->get() == nullptr || byte_len > (*alloc)->size()) {
      alloc->reset();
      auto id = phi::Stream(reinterpret_cast<phi::StreamId>(stream));
      *alloc = memory::Alloc(place, byte_len, id);
    }
    return reinterpret_cast<T*>((*alloc)->ptr());
  }

  using Table = HashTable<KeyType, ValType>;
  using PtrTable = HashTable<KeyType, float*>;
  std::vector<Table*> tables_;
  std::vector<PtrTable*> ptr_tables_;
  std::shared_ptr<HeterPsResource> resource_;
  std::vector<std::vector<Path>> path_;
  float load_factor_{0.75};
  int block_size_{256};
  std::unique_ptr<HeterCommKernel> heter_comm_kernel_;

  GPUAccessor gpu_accessor_;

 protected:
  int topo_aware_{0};
  std::vector<LocalStorage> storage_;
  DynamicGradMerger merger_;
  int device_num_ = 8;
  int multi_node_{0};
  int rank_id_ = 0;
  int node_size_ = 1;
  // inner sync barrier
  Barrier barrier_;
  size_t val_type_size_;
  size_t pull_type_size_;
  size_t grad_type_size_;
  size_t max_type_size_;
  bool enable_gpu_direct_access_ = false;
  // set compress bound
  float max_value_bound_ = 10.0;
  float max_grad_bound_ = 10.0;

#if defined(PADDLE_WITH_CUDA)
  GpuRDMAChecker* rdma_checker_ = nullptr;
  std::vector<ncclComm_t> nccl_inner_comms_;
  std::vector<ncclComm_t> nccl_inter_comms_;
  int multi_mf_dim_{8};
  int max_mf_dim_ = 8;
  std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
#endif
  int64_t start_time_ = 0;
};

}  // end namespace framework
}  // end namespace paddle

#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"

#endif