/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <thread>
#include <vector>

#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
21
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/timer.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
#include <xpu/runtime.h>

#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif

#include "paddle/fluid/framework/barrier.h"
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"

#ifdef PADDLE_WITH_HETERPS

namespace paddle {
namespace framework {

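// Round LEN up to the next multiple of ALIGNVAL; ALIGNVAL must be a power of
// two, e.g. TYPEALIGN(8, 13) == 16.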
#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))

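// HeterComm coordinates sparse-table pull/push across heterogeneous devices:
// KeyType/ValType are the feature key and stored value types, GradType the
// pushed gradient type, and GPUAccessor adapts the packed feature-value
// layout. A rough usage sketch (illustrative only; the resource setup and
// concrete accessor type below are assumptions, not defined in this header):
//
//   auto resource = std::make_shared<HeterPsResource>(dev_ids);
//   HeterComm<uint64_t, float, float, GPUAccessor> comm(capacity, resource);
//   comm.build_ps(0, h_keys, h_vals, len, chunk_size, stream_num);
//   comm.pull_sparse(0, d_keys, d_vals, len);       // gather values for keys
//   comm.push_sparse(0, d_keys, d_grads, len, sgd); // apply gradients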
template <typename KeyType,
          typename ValType,
          typename GradType,
          typename GPUAccessor>
class HeterComm {
  using HeterCommType = HeterComm<KeyType, ValType, GradType, GPUAccessor>;
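  // Bit flags passed to LocalStorage::alloc(); a set bit requests that the
  // corresponding buffer survive reallocation with its contents copied over.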
  static const int COPY_KEY = 0x01;
  static const int COPY_VAL = 0x02;
  static const int COPY_ALL = COPY_KEY | COPY_VAL;

 public:
  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
  HeterComm(size_t capacity,
            std::shared_ptr<HeterPsResource> resource,
            const GPUAccessor& gpu_accessor);
  virtual ~HeterComm();
  HeterComm(const HeterComm&) = delete;
  HeterComm& operator=(const HeterComm&) = delete;
  // reset table
  void reset_table(const int dev_id,
                   size_t capacity,
                   const OptimizerConfig& sgd_config,
                   const OptimizerConfig& embedx_config,
                   bool infer_mode);
  void set_mode(bool infer_mode);
  template <typename StreamType>
  size_t merge_keys(const int gpu_num,
                    const KeyType* d_keys,
                    const size_t& len,
                    KeyType* d_sorted_keys,
                    KeyType* d_merged_keys,
                    uint32_t* d_restore_idx,
                    StreamType stream);
  void dynamic_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          float* d_grads,
                          size_t len,
                          int& uniq_len,        // NOLINT
                          size_t& segment_len,  // NOLINT
                          bool enable_segment_merge_grad);
  void segment_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          float* d_grads,
                          const uint32_t* d_index,
                          size_t len,
                          const uint32_t* d_fea_num_info,
                          size_t uniq_len,
                          size_t& segment_len);  // NOLINT
  void build_ps(int num,
                KeyType* h_keys,
                ValType* h_vals,
                size_t len,
                size_t chunk_size,
                int stream_num,
                int offset = -1);
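  // Partition d_keys by destination device: d_idx_ptr receives the shuffled
  // key indices, left/right the begin/end offset of each device's segment.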
  void split_input_to_shard(KeyType* d_keys,
                            int* d_idx_ptr,
                            size_t len,
                            int* left,
                            int* right,
                            int gpu_num);
  void merge_grad(int gpu_num,
                  KeyType* d_keys,
                  GradType* d_grads,
                  size_t len,
                  int& uniq_len);  // NOLINT
  void dynamic_merge_grad(int gpu_num,
                          KeyType* d_keys,
                          float* d_grads,
                          size_t len,
                          int& uniq_len);  // NOLINT
  void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len);
  void build_ps(int num,
                KeyType* h_keys,
                char* pool,
                size_t len,
                size_t feature_value_size,
                size_t chunk_size,
                int stream_num);
  void dump();
  void show_one_table(int gpu_num);
  void show_table_collisions();
  int get_index_by_devid(int devid);

#if defined(PADDLE_WITH_CUDA)
  template <typename Sgd>
  void push_sparse(int num,
                   KeyType* d_keys,
                   float* d_grads,
                   size_t len,
                   Sgd& sgd);  // NOLINT
#elif defined(PADDLE_WITH_XPU_KP)
  void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len);
#endif

  void set_sparse_sgd(const OptimizerConfig& optimizer_config);
  void set_embedx_sgd(const OptimizerConfig& optimizer_config);

  int log2i(int x);

  template <typename DstPlace, typename SrcPlace, typename StreamType>
  void memory_copy(DstPlace dst_place,
                   void* dst,
                   SrcPlace src_place,
                   const void* src,
                   size_t count,
                   StreamType stream = 0);

#if defined(PADDLE_WITH_CUDA)
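  // Multi-node push path: gradients are first reduced across the cards of
  // this node via the inner NCCL communicators, then across nodes via the
  // inter communicators.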
  template <typename Sgd>
  void push_sparse_multi_node(int num,
                              KeyType* d_keys,
                              GradType* d_grads,
                              size_t len,
                              Sgd& sgd);  // NOLINT

  template <typename Sgd>
  void update_one_table(int num,
                        KeyType* d_keys,
                        GradType* d_grads,
                        size_t len,
                        Sgd& sgd);  // NOLINT

  int gather_one_node_grad(int num,
                           KeyType* d_keys,
                           GradType* d_grads,
                           int len);

  int gather_multi_node_grad(int num,
                             KeyType* d_keys,
                             GradType* d_grads,
                             int len);

  void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                              const std::vector<ncclComm_t>& inter_comms,
                              int comm_size,
                              int rank_id) {
    nccl_inner_comms_ = inner_comms;
    nccl_inter_comms_ = inter_comms;
    node_size_ = comm_size;
    rank_id_ = rank_id;
  }

  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
    multi_mf_dim_ = multi_mf_dim;
    max_mf_dim_ = max_mf_dim;
  }

#endif

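  // With the default eight cards, devices form two groups of four (e.g. two
  // PCIe/NVLink islands): a copy needs an intermediate transfer card only
  // when it crosses groups and the receiver is not the sender's fixed peer
  // (send_id + 4).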
  bool need_transfer(int send_id, int receive_id) {
    return ((send_id / 4 != receive_id / 4) &&
            (send_id + 4) % device_num_ != receive_id);
  }

  // void dump_to_cpu(int index);

  int get_transfer_devid(int send_id) { return (send_id + 4) % device_num_; }

  void end_pass();
#if defined(PADDLE_WITH_CUDA)
  // Deduplicate keys, producing the merged keys plus the
  // restore/sorted/offset/count index arrays.
  int dedup_keys_and_fillidx(const int gpu_id,
                             const int total_fea_num,
                             const KeyType* d_keys,   // input
                             KeyType* d_merged_keys,  // output
                             KeyType* d_sorted_keys,
                             uint32_t* d_restore_idx,
                             uint32_t* d_sorted_idx,
                             uint32_t* d_offset,
                             uint32_t* d_merged_cnts,
                             bool filter_zero,
                             cudaStream_t stream = 0);
#endif
  template <typename T, typename StreamType>
  void split_idx_to_shard(KeyType* d_keys,
                          T* d_idx_ptr,
                          size_t len,
                          T* left,
                          T* right,
                          int gpu_num,
                          StreamType stream);

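  // A Path is the chain of staging Nodes a key/value batch walks through
  // between two cards; CopyTask tracks one step of that walk (see
  // walk_to_dest / walk_to_src below).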
  struct Node {
    ppStream in_stream;
    ppStream out_stream;
    char* key_storage;
    char* val_storage;
    int sync;
    size_t key_bytes_len;
    size_t val_bytes_len;
    int dev_num;
  };

  struct Path {
    std::vector<Node> nodes_;
  };

  struct CopyTask {
    Path* path;
    int step;
    CopyTask(Path* path_, int step_) : path(path_), step(step_) {}
  };
  // Per-card resources for the intra-node (inner) all2all stage.
  struct InnerResource {
    uint32_t* d_idx = nullptr;
    size_t* h_part_sizes = nullptr;
    std::vector<size_t> h_offsets;
    uint32_t* d_offset_ptr = nullptr;

    KeyType* d_keys_parted = nullptr;
    char* d_vals_parted = nullptr;
    std::vector<KeyType*> d_remote_keys;
    std::vector<char*> d_remote_vals;
    KeyType* d_trans_keys = nullptr;
    char* d_trans_vals = nullptr;

    // Resize the per-GPU bookkeeping vectors.
    void resize(const int num_gpu) {
      h_offsets.resize(num_gpu);
      d_remote_keys.resize(num_gpu);
      d_remote_vals.resize(num_gpu);
    }
  };
  // Resources for partitioning sharded keys across nodes.
  struct ShardResource {
    uint32_t* d_local_idx_parted = nullptr;  // uint32_t for multisplit
    std::vector<size_t> h_local_part_sizes;
    std::vector<size_t> h_local_part_offsets;
    std::vector<size_t> h_remote_part_sizes;
    std::vector<size_t> h_remote_part_offsets;
    uint32_t* d_node_size_ptr = nullptr;
    std::vector<uint32_t> h_push_fea_sizes;
    // Grow the per-node bookkeeping; no-op if already sized.
    void resize_part_size(const int node_size) {
      if (h_local_part_sizes.size() >= static_cast<size_t>(node_size)) {
        return;
      }
      h_local_part_sizes.resize(node_size);
      h_local_part_offsets.resize(node_size + 1);
      h_remote_part_sizes.resize(node_size);
      h_remote_part_offsets.resize(node_size + 1);
      h_push_fea_sizes.resize(node_size * node_size);
    }
  };
  // Resources for pulling partitioned shard keys per device.
  struct PullResource {
    size_t h_recv_fea_num = 0;
    uint32_t* d_restore_keys_idx = nullptr;
  };

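  // Per-device scratch area staging keys/values (and gradients) for the
  // inner and inter all2all exchanges.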
  struct LocalStorage {
    LocalStorage() { sem_wait = std::make_unique<Semaphore>(); }
    void init(int device_num, int dev_id) {
      place_ = platform::CUDAPlace(dev_id);
      h_recv_offsets.resize(device_num);
      h_fea_sizes.resize(device_num);
    }
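    // Grow-only buffer cache: reuse the existing allocation when it is big
    // enough, otherwise reallocate and optionally copy the old contents over.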
    template <typename T>
    T* alloc_cache(const size_t& len,
                   std::shared_ptr<memory::Allocation>& alloc,  // NOLINT
                   bool need_copy = false) {
      size_t need_mem = len * sizeof(T);
      if (alloc.get() == nullptr) {
        alloc = memory::Alloc(place_, need_mem);
      } else if (need_mem > alloc->size()) {
        if (need_copy) {
          std::shared_ptr<memory::Allocation> tmp =
              memory::Alloc(place_, need_mem);
          cudaMemcpy(tmp->ptr(),
                     alloc->ptr(),
                     alloc->size(),
                     cudaMemcpyDeviceToDevice);
          alloc.reset();
          alloc = tmp;
        } else {
          alloc.reset();
          alloc = memory::Alloc(place_, need_mem);
        }
      }
      return reinterpret_cast<T*>(alloc->ptr());
    }
    void alloc(const size_t& len,
               const size_t& value_bytes = sizeof(GradType),
               const int copy_mode = 0) {
      all_keys =
          alloc_cache<KeyType>(len, all_keys_mem, (copy_mode & COPY_KEY));
      all_grads = alloc_cache<char>(
          len * value_bytes, all_grads_mem, (copy_mode & COPY_VAL));
      local_keys =
          alloc_cache<KeyType>(len, local_keys_mem, (copy_mode & COPY_KEY));
      local_grads = alloc_cache<char>(
          len * value_bytes, local_grads_mem, (copy_mode & COPY_VAL));
      d_merged_keys = all_keys;
      d_merged_push_keys = local_keys;
      d_merged_vals = all_grads;
      d_merged_push_vals = local_grads;
    }
    void init_pull(const size_t& len) {
      pull_res.h_recv_fea_num = len;
      pull_res.d_restore_keys_idx = alloc_cache<uint32_t>(len, local_pull_idx);
    }
    void init_shard(const size_t& len, const size_t& node_size) {
      shard_res.d_local_idx_parted =
          alloc_cache<uint32_t>(len, local_shard_idx);
      shard_res.d_node_size_ptr =
          alloc_cache<uint32_t>(node_size * node_size, d_node_size_buf);
      shard_res.resize_part_size(node_size);
    }
    void init_inner(const size_t& len, const int& device_num) {
      inner_res.d_idx = alloc_cache<uint32_t>(len, local_inner_idx);
      inner_res.d_offset_ptr =
          alloc_cache<uint32_t>(device_num * 2, inner_offset);
      inner_res.resize(device_num);
    }
    void init_trans(const size_t& fea_num, const size_t& value_bytes) {
      d_merged_trans_keys = alloc_cache<KeyType>(fea_num * 2, trans_keys_buff);
      d_merged_push_trans_keys = &d_merged_trans_keys[fea_num];
      d_merged_trans_vals =
          alloc_cache<char>(fea_num * 2 * value_bytes, trans_vals_buff);
      d_merged_push_trans_vals = &d_merged_trans_vals[fea_num * value_bytes];
    }

#if defined(PADDLE_WITH_CUDA)
    platform::CUDAPlace place_;

#elif defined(PADDLE_WITH_XPU_KP)
    platform::XPUPlace place_;
#endif
    std::shared_ptr<memory::Allocation> all_keys_mem = nullptr;
    std::shared_ptr<memory::Allocation> all_grads_mem = nullptr;

    KeyType* all_keys;
    char* all_grads;

    std::shared_ptr<memory::Allocation> local_keys_mem = nullptr;
    std::shared_ptr<memory::Allocation> local_grads_mem = nullptr;
    KeyType* local_keys;
    char* local_grads;

    // all2all
    std::shared_ptr<memory::Allocation> local_inner_idx = nullptr;
    std::shared_ptr<memory::Allocation> local_pull_idx = nullptr;
    std::shared_ptr<memory::Allocation> local_shard_idx = nullptr;
    std::shared_ptr<memory::Allocation> inner_offset = nullptr;
    std::shared_ptr<memory::Allocation> d_node_size_buf = nullptr;

    InnerResource inner_res;
    ShardResource shard_res;
    PullResource pull_res;

    KeyType* d_merged_keys = nullptr;
    char* d_merged_vals = nullptr;
    KeyType* d_merged_push_keys = nullptr;
    char* d_merged_push_vals = nullptr;
    std::vector<size_t> h_recv_offsets;
    std::vector<size_t> h_fea_sizes;
    // inner trans comm and stream buffer
    size_t h_trans_size;
    size_t h_trans_offset;

    // node trans comm and stream buffer
    std::unique_ptr<Semaphore> sem_wait;
    std::shared_ptr<memory::Allocation> trans_keys_buff = nullptr;
    std::shared_ptr<memory::Allocation> trans_vals_buff = nullptr;
    KeyType* d_merged_trans_keys = nullptr;
    char* d_merged_trans_vals = nullptr;
    KeyType* d_merged_push_trans_keys = nullptr;
    char* d_merged_push_trans_vals = nullptr;

    platform::Timer all2all_span_;
    platform::Timer inner_span_;
    platform::Timer inner_barrier_;
    platform::Timer node_span_;
    platform::Timer node_barrier_;
  };

  void init_path();

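  // Thin wrappers hiding the CUDA / XPU stream API differences.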
  template <typename StreamType>
  void sync_stream(const StreamType& stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(stream));
#endif
  }

  template <typename StreamType>
  void create_stream(StreamType* stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(stream));
#endif
  }

  template <typename StreamType>
  void destroy_stream(StreamType stream) {
#if defined(PADDLE_WITH_CUDA)
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#elif defined(PADDLE_WITH_XPU_KP)
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(stream));
#endif
  }

  void create_storage(int start_index,
                      int end_index,
                      size_t keylen,
                      size_t vallen);
  void create_tmp_storage(void*& dest,  // NOLINT
                          int start_index,
                          int end_index,
                          size_t vallen);
  void destroy_tmp_storage(void*& p, int start_index, int end_index);  // NOLINT
  void destroy_storage(int start_index, int end_index);
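  // Copy each card's key/value segment [h_left[i], h_right[i]] from the
  // source card to its destination along the precomputed paths; walk_to_src
  // walks the results back from the destinations.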
  void walk_to_dest(int start_index,
                    int gpu_num,
                    int* h_left,
                    int* h_right,
                    KeyType* src_key,
                    GradType* src_val);
  void walk_to_dest(int start_index,
                    int gpu_num,
                    int* h_left,
                    int* h_right,
                    KeyType* src_key,
                    char* src_val,
                    size_t val_size);
  void walk_to_src(int start_index,
                   int gpu_num,
                   int* h_left,
                   int* h_right,
                   ValType* src_val);
  void walk_to_src(int start_index,
                   int gpu_num,
                   int* h_left,
                   int* h_right,
                   char* src_val,
                   size_t val_size);

 protected:
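  // Pull/push internals: the *_all2all methods implement the cross-node
  // exchange, the *_inner / *_p2p methods the intra-node stage.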
  void pull_merge_sparse(const int gpu_id,
                         KeyType* d_keys,
                         float* d_vals,
                         size_t len);
  void pull_normal_sparse(const int gpu_id,
                          KeyType* d_keys,
                          float* d_vals,
                          size_t len);
  void pull_one_table(const int gpu_id,
                      KeyType* d_keys,
                      float* d_vals,
                      const size_t& len,
                      const cudaStream_t& stream);

  // node all2all pull
  void pull_sparse_all2all(const int& gpu_id,
                           KeyType* d_keys,
                           float* d_vals,
                           const size_t& len);

  template <typename Sgd>
  void push_normal_sparse(int num,
                          KeyType* d_keys,
                          float* d_grads,
                          size_t len,
                          Sgd& sgd);  // NOLINT

  void shard_inner_keys(const size_t& total_fea_num,
                        const KeyType* d_keys,
                        const int& gpu_id,
                        const int& gpu_num,
                        InnerResource* res,
                        const cudaStream_t& stream);
  void gather_inner_keys_p2p(const size_t& total_fea_num,
                             const KeyType* d_keys,
                             InnerResource& res,  // NOLINT
                             const int& gpu_id,
                             const int& gpu_num,
                             const int& trans_id,
                             const cudaStream_t& stream);
  size_t gather_inter_keys_by_copy(const int& gpu_id,
                                   const size_t& fea_size,
                                   const KeyType* d_keys,
                                   const cudaStream_t& stream);
  void partition_shard_keys(const int& gpu_id,
                            const size_t& total_fea_num,
                            const KeyType* d_keys,
                            uint32_t* d_idx_parted,
                            KeyType* d_keys_parted,
                            size_t* h_part_sizes,
                            const int& shard_num,
                            const cudaStream_t& stream);
  size_t send_data_by_all2all(const int& gpu_id,
                              const int& nccl_node_size,
                              const int& nccl_rank_id,
                              const int& value_bytes,
                              const size_t* h_send_part_sizes,
                              const size_t* h_send_part_offsets,
                              const size_t* h_recv_part_sizes,
                              const size_t* h_recv_part_offsets,
                              const char* d_send_buff,
                              char* d_rev_buff,
                              const cudaStream_t& stream);
  size_t gather_sparse_keys_by_all2all(const int& gpu_id,
                                       const size_t& fea_size,
                                       const KeyType* d_in_keys,
                                       KeyType* d_out_keys,
                                       KeyType* d_tmp_keys,
                                       const cudaStream_t& stream);
  void scatter_sparse_vals_by_all2all(const int& gpu_id,
                                      const size_t& fea_size,
                                      const char* d_in_vals,
                                      void* d_out_vals,
                                      const size_t& value_bytes,
                                      void* d_tmp_vals,
                                      const cudaStream_t& stream);
  void scatter_inner_vals_p2p(const size_t& total_fea_num,
                              void* d_out_vals,
                              InnerResource& res,  // NOLINT
                              const int& gpu_id,
                              const int& gpu_num,
                              const int& trans_id,
                              const size_t& value_bytes,
                              const cudaStream_t& stream);
  void scatter_inter_vals_by_copy(const int& gpu_id,
                                  const size_t& fea_size,
                                  const char* d_in_vals,
                                  void* d_out_vals,
                                  const size_t& value_bytes,
                                  const cudaStream_t& stream);
  void gather_inner_data_p2p(const size_t& total_fea_num,
                             const KeyType* d_keys,
                             const void* d_vals,
                             InnerResource& res,  // NOLINT
                             const int& gpu_id,
                             const int& gpu_num,
                             const int& trans_id,
                             const size_t& value_bytes,
                             const cudaStream_t& stream);
  template <typename Sgd>
  void push_sparse_all2all(const int& gpu_id,
                           KeyType* d_keys,
                           float* d_grads,
                           const size_t& len,
                           Sgd& sgd);  // NOLINT
  size_t merge_grad(const int& gpu_id,
                    const size_t& len,
                    const KeyType* d_in_keys,
                    KeyType* d_out_keys,
                    const void* d_in_grads,
                    void* d_out_grads,
                    const cudaStream_t& stream);
  size_t gather_inter_gradient_by_copy(const int& gpu_id,
                                       const size_t& push_size,
                                       KeyType* d_keys,
                                       void* d_push_vals,
                                       const size_t& value_bytes,
                                       const cudaStream_t& stream);
  size_t gather_sparse_gradient_by_all2all(const int& gpu_id,
                                           const size_t& push_size,
                                           const KeyType* d_keys,
                                           const char* d_push_vals,
                                           const size_t& value_bytes,
                                           KeyType* d_out_keys,
                                           KeyType* d_tmp_keys,
                                           char* d_out_vals,
                                           char* d_tmp_vals,
                                           const cudaStream_t& stream);
  size_t send_keys_by_all2all_trans(const int& gpu_id,
                                    const int& rank_id,
                                    const int& node_size,
                                    const size_t& fea_size,
                                    const KeyType* d_in_keys,
                                    KeyType* d_out_keys,
                                    const cudaStream_t& stream);
  size_t send_vals_by_all2all_trans(const int& gpu_id,
                                    const int& rank_id,
                                    const int& node_size,
                                    const char* d_in_vals,
                                    char* d_out_vals,
                                    const size_t& value_bytes,
                                    const cudaStream_t& stream);
  size_t send_gradient_by_all2all_trans(const int& gpu_id,
                                        const int& rank_id,
                                        const int& node_size,
                                        const size_t& fea_size,
                                        const KeyType* d_keys,
                                        const char* d_push_vals,
                                        const size_t& value_bytes,
                                        KeyType* d_out_keys,
                                        char* d_out_vals,
                                        const cudaStream_t& stream);
  // debug time
  void print_debug_time(const int& gpu_id, bool force = false);

  using Table = HashTable<KeyType, ValType>;
  using PtrTable = HashTable<KeyType, float*>;
  std::vector<Table*> tables_;
  std::vector<PtrTable*> ptr_tables_;
  std::shared_ptr<HeterPsResource> resource_;
  std::vector<std::vector<Path>> path_;
  float load_factor_{0.75};
  int block_size_{256};
  std::unique_ptr<HeterCommKernel> heter_comm_kernel_;

  GPUAccessor gpu_accessor_;

  int topo_aware_{0};
  std::vector<LocalStorage> storage_;
  DynamicGradMerger merger_;
  int device_num_ = 8;
  int multi_node_{0};
  int rank_id_ = 0;
  int node_size_ = 1;
  // inner sync barrier
  Barrier barrier_;
  size_t val_type_size_;
  size_t pull_type_size_;
  size_t grad_type_size_;
  size_t max_type_size_;
  bool enable_gpu_direct_access_ = false;
  // set compress bound
  float max_value_bound_ = 10.0;
  float max_grad_bound_ = 10.0;

#if defined(PADDLE_WITH_CUDA)
  GpuRDMAChecker* rdma_checker_ = nullptr;
  std::vector<ncclComm_t> nccl_inner_comms_;
  std::vector<ncclComm_t> nccl_inter_comms_;
  int multi_mf_dim_{8};
  int max_mf_dim_ = 8;
  std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
#endif
  int64_t start_time_ = 0;
};

}  // end namespace framework
}  // end namespace paddle

#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"

#endif