data_feed.cu 112.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)

#include "paddle/fluid/framework/data_feed.h"
D
danleifeng 已提交
21 22 23 24 25
#include <thrust/device_ptr.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>
#include <sstream>
#include "cub/cub.cuh"
P
pangengzheng 已提交
26
#if defined(PADDLE_WITH_PSCORE) && defined(PADDLE_WITH_GPU_GRAPH)
D
danleifeng 已提交
27
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
L
lxsbupt 已提交
28
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h"
D
danleifeng 已提交
29
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
P
pangengzheng 已提交
30
#endif
L
lxsbupt 已提交
31 32
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
33
#include "paddle/fluid/framework/io/fs.h"
L
lxsbupt 已提交
34 35
#include "paddle/phi/kernels/gpu/graph_reindex_funcs.h"
#include "paddle/phi/kernels/graph_reindex_kernel.h"
D
danleifeng 已提交
36 37

DECLARE_bool(enable_opt_get_features);
L
lxsbupt 已提交
38 39 40
DECLARE_bool(graph_metapath_split_opt);
DECLARE_int32(gpugraph_storage_mode);
DECLARE_double(gpugraph_hbm_table_load_factor);
41 42 43 44 45 46 47 48

namespace paddle {
namespace framework {

// Grid-stride loop over [0, n): each thread starts at its global 1-D index
// (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total number of
// threads in the grid, so every index in [0, n) is visited exactly once for any
// 1-D grid size.
// NOTE(review): `i` is a 32-bit int, so n close to INT_MAX can overflow the
// increment; callers appear to pass int-sized counts.
#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

L
lxsbupt 已提交
49 50 51 52 53
// Logs every field of a buffer-state object at VLOG level 2.
// `state` must be a pointer expression; it is evaluated once per field.
// The do { } while (0) wrapper makes `DEBUG_STATE(s);` a single statement, so
// it is safe inside an unbraced if/else (the previous trailing `;` form broke
// that and produced an empty statement at every use).
#define DEBUG_STATE(state)                                                   \
  do {                                                                       \
    VLOG(2) << "left: " << (state)->left << " right: " << (state)->right     \
            << " central_word: " << (state)->central_word                    \
            << " step: " << (state)->step << " cursor: " << (state)->cursor  \
            << " len: " << (state)->len << " row_num: " << (state)->row_num; \
  } while (0)
54 55 56 57 58 59
// Block size shared by the kernel launches in this file.
const int CUDA_NUM_THREADS = 512;
// Number of CUDA_NUM_THREADS-wide blocks needed to cover N threads, i.e.
// ceil(N / CUDA_NUM_THREADS); yields 0 for N == 0.
inline int GET_BLOCKS(const int N) {
  const int padded = N + CUDA_NUM_THREADS - 1;
  return padded / CUDA_NUM_THREADS;
}
L
lxsbupt 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276

/**
 * @brief Writes the identity permutation: idx[i] = i for every i in [0, len).
 *
 * Launch: any 1-D grid. The loop strides by the total thread count, so a grid
 * smaller than `len` still covers every element. The starting index is widened
 * to size_t before the multiply, so `blockIdx.x * blockDim.x` cannot wrap in
 * 32-bit unsigned arithmetic even though `len` is a size_t.
 */
template <typename T>
__global__ void fill_idx(T *idx, size_t len) {
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < len;
       i += stride) {
    idx[i] = i;
  }
}

/**
 * @brief Radix-sorts (key, value) pairs with cub on `stream`.
 *
 * Sorts the first `len` keys of `in_keys` in ascending order over all
 * 8 * sizeof(K) key bits. The sorted keys go to `out_keys`, and `in_vals` is
 * permuted into `out_vals` alongside them.
 *
 * @param d_buf_ cub temporary storage; replaced by a larger allocation on
 *               `place_` (bound to `stream`) when it is missing or too small.
 */
template <typename K, typename V>
void cub_sort_pairs(int len,
                    const K *in_keys,
                    K *out_keys,
                    const V *in_vals,
                    V *out_vals,
                    cudaStream_t stream,
                    std::shared_ptr<phi::Allocation> &d_buf_,  // NOLINT
                    const paddle::platform::Place &place_) {
  const int begin_bit = 0;
  const int end_bit = static_cast<int>(sizeof(K) * 8);
  size_t temp_storage_bytes = 0;

  // cub's two-phase protocol: a null workspace only reports the bytes needed,
  // a real workspace performs the sort.
  auto radix_sort = [&](void *workspace) {
    return cub::DeviceRadixSort::SortPairs(workspace,
                                           temp_storage_bytes,
                                           in_keys,
                                           out_keys,
                                           in_vals,
                                           out_vals,
                                           len,
                                           begin_bit,
                                           end_bit,
                                           stream,
                                           false);
  };

  CUDA_CHECK(radix_sort(nullptr));
  const bool need_grow = !d_buf_ || d_buf_->size() < temp_storage_bytes;
  if (need_grow) {
    d_buf_ = memory::AllocShared(
        place_,
        temp_storage_bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
  }
  CUDA_CHECK(radix_sort(d_buf_->ptr()));
}

/**
 * @brief Run-length encodes `N` keys from `in_keys` with cub on `stream`.
 *
 * Each maximal run of equal consecutive keys produces one output entry. The
 * key goes to `out_keys` and the run length to `out_sizes`. The number of runs
 * is written to the device scalar `d_out_len`.
 *
 * @param d_buf_ cub temporary storage; replaced by a larger allocation on
 *               `place_` (bound to `stream`) when it is missing or too small.
 */
template <typename K, typename V, typename TNum>
void cub_runlength_encode(int N,
                          const K *in_keys,
                          K *out_keys,
                          V *out_sizes,
                          TNum *d_out_len,
                          cudaStream_t stream,
                          std::shared_ptr<phi::Allocation> &d_buf_,  // NOLINT
                          const paddle::platform::Place &place_) {
  size_t temp_storage_bytes = 0;

  // Null workspace -> size query only; real workspace -> do the encoding.
  auto encode = [&](void *workspace) {
    return cub::DeviceRunLengthEncode::Encode(workspace,
                                              temp_storage_bytes,
                                              in_keys,
                                              out_keys,
                                              out_sizes,
                                              d_out_len,
                                              N,
                                              stream);
  };

  CUDA_CHECK(encode(nullptr));
  const bool need_grow = !d_buf_ || d_buf_->size() < temp_storage_bytes;
  if (need_grow) {
    d_buf_ = memory::AllocShared(
        place_,
        temp_storage_bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
  }
  CUDA_CHECK(encode(d_buf_->ptr()));
}

/**
 * @brief Exclusive prefix sum with cub on `stream`: out[i] = in[0] + ... +
 *        in[i - 1] for i in [0, N), with out[0] = 0.
 *
 * @param d_buf_ cub temporary storage; replaced by a larger allocation on
 *               `place_` (bound to `stream`) when it is missing or too small.
 */
template <typename K>
void cub_exclusivesum(int N,
                      const K *in,
                      K *out,
                      cudaStream_t stream,
                      std::shared_ptr<phi::Allocation> &d_buf_,  // NOLINT
                      const paddle::platform::Place &place_) {
  size_t temp_storage_bytes = 0;

  // Null workspace -> size query only; real workspace -> run the scan.
  auto scan = [&](void *workspace) {
    return cub::DeviceScan::ExclusiveSum(
        workspace, temp_storage_bytes, in, out, N, stream);
  };

  CUDA_CHECK(scan(nullptr));
  const bool need_grow = !d_buf_ || d_buf_->size() < temp_storage_bytes;
  if (need_grow) {
    d_buf_ = memory::AllocShared(
        place_,
        temp_storage_bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
  }
  CUDA_CHECK(scan(d_buf_->ptr()));
}

/**
 * @brief Scatters each merged group's index back to its members' original
 *        positions, one group per loop iteration.
 *
 * Group g (g < N) covers the sorted positions
 * [d_offset[g], d_offset[g] + d_merged_cnts[g]). For every such position p,
 * d_restore_idx[d_sorted_idx[p]] is set to g.
 */
template <typename T>
__global__ void kernel_fill_restore_idx(size_t N,
                                        const T *d_sorted_idx,
                                        const T *d_offset,
                                        const T *d_merged_cnts,
                                        T *d_restore_idx) {
  CUDA_KERNEL_LOOP(group, N) {
    const T begin = d_offset[group];
    const T count = d_merged_cnts[group];
    const T *members = d_sorted_idx + begin;
    for (size_t k = 0; k < count; ++k) {
      d_restore_idx[members[k]] = group;
    }
  }
}

/**
 * @brief Returns the index of the merged group that sorted position `pos`
 *        belongs to.
 *
 * `d_offset` holds the start position of each of the `merge_num` groups (an
 * exclusive prefix sum of the group sizes: d_offset[0] == 0, non-decreasing).
 * The result is the largest g in [0, merge_num) with d_offset[g] <= pos.
 *
 * Only d_offset[1 .. merge_num - 1] is read. In particular, when
 * merge_num == 1 nothing past the single valid entry is touched. Returns 0
 * when merge_num <= 1.
 */
template <typename T>
__host__ __device__ inline int find_merged_group_by_offset(size_t pos,
                                                           const T *d_offset,
                                                           size_t merge_num) {
  int low = 0;
  int high = static_cast<int>(merge_num) - 1;
  while (low < high) {
    // Upper midpoint so that `low = mid` always makes progress.
    const int mid = low + (high - low + 1) / 2;
    if (d_offset[mid] <= pos) {
      low = mid;
    } else {
      high = mid - 1;
    }
  }
  return low;
}

/**
 * @brief Fills the inverse index by binary-searching each sorted position's
 *        group. One loop iteration per sorted position, so work is
 *        independent of group sizes.
 *
 * For every sorted position i in [0, N), d_restore_idx[d_sorted_idx[i]]
 * receives the group containing i (see find_merged_group_by_offset).
 * Launch: any 1-D grid (grid-stride loop).
 *
 * The previous version read d_offset[1] unconditionally. When all keys
 * collapse into a single group (merge_num == 1), that entry lies past the
 * valid prefix sum and could be stale, mislabelling keys as group 1.
 */
template <typename T>
__global__ void kernel_fill_restore_idx_by_search(size_t N,
                                                  const T *d_sorted_idx,
                                                  size_t merge_num,
                                                  const T *d_offset,
                                                  T *d_restore_idx) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    d_restore_idx[d_sorted_idx[i]] =
        find_merged_group_by_offset(i, d_offset, merge_num);
  }
}

// For unique node and inverse id.
/**
 * @brief Deduplicates `d_keys` and builds the inverse index.
 *
 * For the `total_nodes_num` keys in `d_keys`, this:
 *   1. radix-sorts the keys together with their original positions
 *      (into d_sorted_keys / d_sorted_idx);
 *   2. run-length encodes the sorted keys into the unique keys
 *      (d_merged_keys) and their multiplicities (d_merged_cnts);
 *   3. exclusive-scans the multiplicities into group start offsets
 *      (d_offset);
 *   4. writes, for every original position p, the index of its unique key
 *      into d_restore_idx[p], so d_merged_keys[d_restore_idx[p]] == d_keys[p].
 *
 * All work is issued on `stream`, and the stream is synchronized before
 * returning. `d_buf_` is shared cub scratch space and may be reallocated on
 * `place_`.
 * NOTE(review): output buffers are assumed to hold at least total_nodes_num
 * elements; only the first <return value> entries of d_merged_keys,
 * d_merged_cnts and d_offset are meaningful.
 *
 * @return number of unique keys; 0 for empty input (no kernels launched).
 */
int dedup_keys_and_fillidx(int total_nodes_num,
                           const uint64_t *d_keys,   // input keys
                           uint64_t *d_merged_keys,  // output: unique keys
                           uint64_t *d_sorted_keys,  // output: sorted keys
                           uint32_t *d_restore_idx,  // output: inverse index
                           uint32_t *d_sorted_idx,
                           uint32_t *d_offset,
                           uint32_t *d_merged_cnts,
                           cudaStream_t stream,
                           std::shared_ptr<phi::Allocation> &d_buf_,  // NOLINT
                           const paddle::platform::Place &place_) {
  if (total_nodes_num <= 0) {
    // Nothing to deduplicate. Returning early also avoids 0-block kernel
    // launches, which CUDA rejects as an invalid launch configuration.
    CUDA_CHECK(cudaStreamSynchronize(stream));
    return 0;
  }

  int merged_size = 0;  // Final num
  // Original positions 0..n-1, plus one trailing slot that receives the
  // run count from the run-length encode.
  auto d_index_in =
      memory::Alloc(place_,
                    sizeof(uint32_t) * (total_nodes_num + 1),
                    phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
  uint32_t *d_index_in_ptr = reinterpret_cast<uint32_t *>(d_index_in->ptr());
  int *d_merged_size =
      reinterpret_cast<int *>(&d_index_in_ptr[total_nodes_num]);
  fill_idx<<<GET_BLOCKS(total_nodes_num), CUDA_NUM_THREADS, 0, stream>>>(
      d_index_in_ptr, total_nodes_num);
  CUDA_CHECK(cudaGetLastError());
  cub_sort_pairs(total_nodes_num,
                 d_keys,
                 d_sorted_keys,
                 d_index_in_ptr,
                 d_sorted_idx,
                 stream,
                 d_buf_,
                 place_);
  cub_runlength_encode(total_nodes_num,
                       d_sorted_keys,
                       d_merged_keys,
                       d_merged_cnts,
                       d_merged_size,
                       stream,
                       d_buf_,
                       place_);
  // The unique count is needed on the host to size the following launches.
  CUDA_CHECK(cudaMemcpyAsync(&merged_size,
                             d_merged_size,
                             sizeof(int),
                             cudaMemcpyDeviceToHost,
                             stream));
  CUDA_CHECK(cudaStreamSynchronize(stream));
  cub_exclusivesum(
      merged_size, d_merged_cnts, d_offset, stream, d_buf_, place_);

  if (total_nodes_num < merged_size * 2) {
    // Low duplication: one thread per group scatters its few members.
    kernel_fill_restore_idx<<<GET_BLOCKS(merged_size),
                              CUDA_NUM_THREADS,
                              0,
                              stream>>>(
        merged_size, d_sorted_idx, d_offset, d_merged_cnts, d_restore_idx);
  } else {
    // High duplication: one thread per key binary-searches its group, so
    // large groups do not serialize on a single thread.
    kernel_fill_restore_idx_by_search<<<GET_BLOCKS(total_nodes_num),
                                        CUDA_NUM_THREADS,
                                        0,
                                        stream>>>(
        total_nodes_num, d_sorted_idx, merged_size, d_offset, d_restore_idx);
  }
  CUDA_CHECK(cudaGetLastError());
  CUDA_CHECK(cudaStreamSynchronize(stream));
  return merged_size;
}

277
// fill slot values
278 279 280 281 282 283 284 285
// Builds per-slot running offsets of value counts across instances.
//
// Row `slot_idx` of slot_value_offsets has ins_num + 1 entries. Entry 0 is 0,
// and entry k + 1 adds the number of values instance k holds for that slot.
// Value counts come from the uint64 or float per-instance offset table,
// selected by used_slots[slot_idx].is_uint64_value; each table has
// slot_size + 1 columns per instance. One loop iteration per used slot.
__global__ void FillSlotValueOffsetKernel(const int ins_num,
                                          const int used_slot_num,
                                          size_t *slot_value_offsets,
                                          const int *uint64_offsets,
                                          const int uint64_slot_size,
                                          const int *float_offsets,
                                          const int float_slot_size,
                                          const UsedSlotGpuType *used_slots) {
  const int row_len = ins_num + 1;

  CUDA_KERNEL_LOOP(slot_idx, used_slot_num) {
    const UsedSlotGpuType &slot = used_slots[slot_idx];
    // Both value types share the same offset-table layout; only the table and
    // its column count differ.
    const int *offsets = slot.is_uint64_value ? uint64_offsets : float_offsets;
    const int stride =
        (slot.is_uint64_value ? uint64_slot_size : float_slot_size) + 1;

    size_t *out = slot_value_offsets + slot_idx * row_len;
    size_t running = 0;
    out[0] = running;
    for (int ins = 0; ins < ins_num; ++ins) {
      const int pos = ins * stride + slot.slot_value_idx;
      const int num = offsets[pos + 1] - offsets[pos];
      PADDLE_ENFORCE(num >= 0, "The number of slot size must be ge 0.");
      running += num;
      out[ins + 1] = running;
    }
  }
}

// Computes per-slot value offsets (see FillSlotValueOffsetKernel) on the GPU
// stream of this feed's device context, then blocks until they are ready.
//
// The launch is skipped when there are no used slots, because a 0-block grid
// is an invalid launch configuration. The stream is still synchronized either
// way, so callers keep the "stream is idle on return" behaviour. The
// synchronize result is checked, so asynchronous kernel failures surface here
// instead of being ignored.
void SlotRecordInMemoryDataFeed::FillSlotValueOffset(
    const int ins_num,
    const int used_slot_num,
    size_t *slot_value_offsets,
    const int *uint64_offsets,
    const int uint64_slot_size,
    const int *float_offsets,
    const int float_slot_size,
    const UsedSlotGpuType *used_slots) {
  auto stream =
      dynamic_cast<phi::GPUContext *>(
          paddle::platform::DeviceContextPool::Instance().Get(this->place_))
          ->stream();
  if (used_slot_num > 0) {
    FillSlotValueOffsetKernel<<<GET_BLOCKS(used_slot_num),
                                CUDA_NUM_THREADS,
                                0,
                                stream>>>(ins_num,
                                          used_slot_num,
                                          slot_value_offsets,
                                          uint64_offsets,
                                          uint64_slot_size,
                                          float_offsets,
                                          float_slot_size,
                                          used_slots);
  }
  CUDA_CHECK(cudaStreamSynchronize(stream));
}

342 343 344 345 346 347 348 349 350 351 352 353 354
// Copies every (slot, instance) value run from the packed uint64/float feature
// arrays into the per-slot destination buffers.
//
// Work item i maps to slot i / ins_num and instance i % ins_num.
// dest[slot] is treated as a uint64_t* or float* according to
// used_slots[slot].is_uint64_value. The instance's run starts at
// slot_value_offsets[slot * (ins_num + 1) + ins], the same layout that
// FillSlotValueOffsetKernel writes. Source runs are located through the
// per-instance offset tables (slot_size + 1 columns per instance) plus the
// per-instance base `*_ins_lens[ins]`.
//
// value_offset keeps the size_t type of slot_value_offsets. It was previously
// narrowed to uint32_t, which truncated destination offsets of 2^32 or more.
__global__ void CopyForTensorKernel(const int used_slot_num,
                                    const int ins_num,
                                    void **dest,
                                    const size_t *slot_value_offsets,
                                    const uint64_t *uint64_feas,
                                    const int *uint64_offsets,
                                    const int *uint64_ins_lens,
                                    const int uint64_slot_size,
                                    const float *float_feas,
                                    const int *float_offsets,
                                    const int *float_ins_lens,
                                    const int float_slot_size,
                                    const UsedSlotGpuType *used_slots) {
  int col_num = ins_num + 1;
  int uint64_cols = uint64_slot_size + 1;
  int float_cols = float_slot_size + 1;

  CUDA_KERNEL_LOOP(i, ins_num * used_slot_num) {
    int slot_idx = i / ins_num;
    int ins_idx = i % ins_num;

    const size_t value_offset =
        slot_value_offsets[slot_idx * col_num + ins_idx];
    auto &info = used_slots[slot_idx];
    if (info.is_uint64_value) {
      uint64_t *up = reinterpret_cast<uint64_t *>(dest[slot_idx]);
      int index = info.slot_value_idx + uint64_cols * ins_idx;
      int old_off = uint64_offsets[index];
      int num = uint64_offsets[index + 1] - old_off;
      PADDLE_ENFORCE(num >= 0, "The number of slot size must be ge 0.");
      int uint64_value_offset = uint64_ins_lens[ins_idx];
      for (int k = 0; k < num; ++k) {
        up[k + value_offset] = uint64_feas[k + old_off + uint64_value_offset];
      }
    } else {
      float *fp = reinterpret_cast<float *>(dest[slot_idx]);
      int index = info.slot_value_idx + float_cols * ins_idx;
      int old_off = float_offsets[index];
      int num = float_offsets[index + 1] - old_off;
      PADDLE_ENFORCE(num >= 0, "The number of slot size must be ge 0.");
      int float_value_offset = float_ins_lens[ins_idx];
      for (int k = 0; k < num; ++k) {
        fp[k + value_offset] = float_feas[k + old_off + float_value_offset];
      }
    }
  }
}

// Copies all slot values into the per-slot destination tensors (see
// CopyForTensorKernel) on this feed's device-context stream, then blocks until
// the copy has finished.
//
// The launch is skipped when there is no work (no slots or no instances),
// because a 0-block grid is an invalid launch configuration. The stream is
// still synchronized, and the result is checked so asynchronous kernel
// failures surface here.
void SlotRecordInMemoryDataFeed::CopyForTensor(
    const int ins_num,
    const int used_slot_num,
    void **dest,
    const size_t *slot_value_offsets,
    const uint64_t *uint64_feas,
    const int *uint64_offsets,
    const int *uint64_ins_lens,
    const int uint64_slot_size,
    const float *float_feas,
    const int *float_offsets,
    const int *float_ins_lens,
    const int float_slot_size,
    const UsedSlotGpuType *used_slots) {
  auto stream =
      dynamic_cast<phi::GPUContext *>(
          paddle::platform::DeviceContextPool::Instance().Get(this->place_))
          ->stream();

  if (used_slot_num > 0 && ins_num > 0) {
    CopyForTensorKernel<<<GET_BLOCKS(used_slot_num * ins_num),
                          CUDA_NUM_THREADS,
                          0,
                          stream>>>(used_slot_num,
                                    ins_num,
                                    dest,
                                    slot_value_offsets,
                                    uint64_feas,
                                    uint64_offsets,
                                    uint64_ins_lens,
                                    uint64_slot_size,
                                    float_feas,
                                    float_offsets,
                                    float_ins_lens,
                                    float_slot_size,
                                    used_slots);
  }
  CUDA_CHECK(cudaStreamSynchronize(stream));
}

D
danleifeng 已提交
427 428 429 430 431 432 433 434 435 436 437 438 439
// Sets tensor[0 .. len) to 1. Used in this file to fill the show and click
// tensors.
__global__ void GraphFillCVMKernel(int64_t *tensor, int len) {
  CUDA_KERNEL_LOOP(pos, len) {
    tensor[pos] = 1;
  }
}

// Writes every source key twice, back to back:
// dist_tensor[2 * n] = dist_tensor[2 * n + 1] = src_tensor[n] for n in
// [0, len). The uint64 keys are stored as int64. dist_tensor must hold
// 2 * len elements.
__global__ void CopyDuplicateKeys(int64_t *dist_tensor,
                                  uint64_t *src_tensor,
                                  int len) {
  CUDA_KERNEL_LOOP(n, len) {
    int64_t *pair = dist_tensor + 2 * n;
    pair[0] = src_tensor[n];
    pair[1] = src_tensor[n];
  }
}

P
pangengzheng 已提交
440
#if defined(PADDLE_WITH_PSCORE) && defined(PADDLE_WITH_GPU_GRAPH)
D
danleifeng 已提交
441 442
// Advances `state` to the next unit of work and returns its length.
// The advances are tried in order (next step, then next central word, then
// next batch), stopping at the first that succeeds. Returns 0 when none does.
int GraphDataGenerator::AcquireInstance(BufState *state) {
  const bool advanced = state->GetNextStep() ||
                        state->GetNextCentrolWord() ||
                        state->GetNextBatch();
  if (!advanced) {
    return 0;
  }
  DEBUG_STATE(state);
  return state->len;
}

// Builds (walk[src], walk[src + step]) id pairs and appends them to id_tensor.
//
// One thread per entry of `row` (idx < len), where
// src = row[idx] * col_num + central_word. A pair is dropped when either
// endpoint id is 0, or when its endpoint node types match one of the
// consecutive (type, type) byte pairs in excluded_train_pair. Surviving pairs
// are staged in shared memory. Each block then reserves a contiguous range of
// pair slots in id_tensor with a single atomicAdd on *fill_ins_num (counted in
// pairs). Because of the atomics, pair order is not deterministic.
// Assumes blockDim.x <= CUDA_NUM_THREADS, the capacity of the staging buffer.
__global__ void GraphFillIdKernel(uint64_t *id_tensor,
                                  int *fill_ins_num,
                                  uint64_t *walk,
                                  uint8_t *walk_ntype,
                                  int *row,
                                  int central_word,
                                  int step,
                                  int len,
                                  int col_num,
                                  uint8_t *excluded_train_pair,
                                  int excluded_train_pair_len) {
  __shared__ uint64_t local_key[CUDA_NUM_THREADS * 2];
  __shared__ int local_num;   // pairs staged by this block
  __shared__ int global_num;  // first pair slot reserved for this block

  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.x == 0) {
    local_num = 0;
  }
  __syncthreads();

  if (idx < len) {
    const int src = row[idx] * col_num + central_word;
    const int dst = src + step;
    if (walk[src] != 0 && walk[dst] != 0) {
      bool excluded = false;
      for (int p = 0; !excluded && p < excluded_train_pair_len; p += 2) {
        excluded = walk_ntype[src] == excluded_train_pair[p] &&
                   walk_ntype[dst] == excluded_train_pair[p + 1];
      }
      if (!excluded) {
        const size_t slot = atomicAdd(&local_num, 1);
        local_key[2 * slot] = walk[src];
        local_key[2 * slot + 1] = walk[dst];
      }
    }
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    global_num = atomicAdd(fill_ins_num, local_num);
  }
  __syncthreads();

  if (threadIdx.x < local_num) {
    const auto out = global_num * 2 + 2 * threadIdx.x;
    id_tensor[out] = local_key[2 * threadIdx.x];
    id_tensor[out + 1] = local_key[2 * threadIdx.x + 1];
  }
}

// Scatters node feature ids into the per-slot output buffers.
//
// Work item idx covers feature column idx / total_ins for instance
// idx % total_ins. feature_buf holds fea_num_per_node ids per instance.
// id_tensor[s] stores the device address of slot s's output buffer as a
// uint64. Inside that buffer, instance i owns slot_feature_num_map[s] entries
// starting at i * slot_feature_num_map[s], and the feature is placed at
// fea_offset_map[fea_idx] within them. `slot_num` is not used here.
__global__ void GraphFillSlotKernel(uint64_t *id_tensor,
                                    uint64_t *feature_buf,
                                    int len,
                                    int total_ins,
                                    int slot_num,
                                    int *slot_feature_num_map,
                                    int fea_num_per_node,
                                    int *actual_slot_id_map,
                                    int *fea_offset_map) {
  CUDA_KERNEL_LOOP(idx, len) {
    const int fea_idx = idx / total_ins;
    const int ins_idx = idx % total_ins;
    const int slot_id = actual_slot_id_map[fea_idx];
    uint64_t *slot_out = reinterpret_cast<uint64_t *>(id_tensor[slot_id]);
    const int out_pos =
        ins_idx * slot_feature_num_map[slot_id] + fea_offset_map[fea_idx];
    slot_out[out_pos] = feature_buf[ins_idx * fea_num_per_node + fea_idx];
  }
}

// For every slot s and instance i (idx = s * total_ins + i), writes
// i * slot_feature_num_map[s] into entry i of the buffer whose device address
// is stored in id_tensor[s]. These are presumably the slot's LoD offsets, per
// the kernel name.
__global__ void GraphFillSlotLodKernelOpt(uint64_t *id_tensor,
                                          int len,
                                          int total_ins,
                                          int *slot_feature_num_map) {
  CUDA_KERNEL_LOOP(idx, len) {
    const int slot_idx = idx / total_ins;
    const int ins_idx = idx % total_ins;
    uint64_t *lod = reinterpret_cast<uint64_t *>(id_tensor[slot_idx]);
    lod[ins_idx] = ins_idx * slot_feature_num_map[slot_idx];
  }
}

// Fills id_tensor[0 .. len) with the sequence 0, 1, ..., len - 1.
__global__ void GraphFillSlotLodKernel(int64_t *id_tensor, int len) {
  CUDA_KERNEL_LOOP(pos, len) {
    id_tensor[pos] = pos;
  }
}

L
lxsbupt 已提交
548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565
// fill sage neighbor results
// Compacts fixed-width neighbor samples into a dense edge list.
//
// Row r of `vals` has `sample_size` slots, of which the first
// actual_sample_size[r] are valid. Those are copied to actual_vals starting at
// cumsum_actual_sample_size[r], and the matching actual_vals_dst entries are
// set to r % mod. One thread per row; expects at least `len` threads.
__global__ void FillActualNeighbors(int64_t *vals,
                                    int64_t *actual_vals,
                                    int64_t *actual_vals_dst,
                                    int *actual_sample_size,
                                    int *cumsum_actual_sample_size,
                                    int sample_size,
                                    int len,
                                    int mod) {
  const size_t row_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (row_id >= len) {
    return;
  }
  const int out_base = cumsum_actual_sample_size[row_id];
  const int in_base = sample_size * row_id;
  const int dst_id = row_id % mod;
  for (int j = 0; j < actual_sample_size[row_id]; ++j) {
    actual_vals[out_base + j] = vals[in_base + j];
    actual_vals_dst[out_base + j] = dst_id;
  }
}
D
danleifeng 已提交
568

L
lxsbupt 已提交
569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597
// Fills the id (feed_vec_[0]), show (feed_vec_[1]) and clk (feed_vec_[2])
// tensors for `total_instance` entries on train_stream_.
//
// Training: the ids are the last total_instance entries of the filled part of
// d_ins_buf_ (ins_buf_pair_len_ pairs). Inference: total_instance / 2 keys
// are taken from d_device_keys_[cursor] starting at infer_node_start_, which
// then advances, and each key is written twice. Show and clk are all ones.
// Always returns 0.
int GraphDataGenerator::FillIdShowClkTensor(int total_instance,
                                            bool gpu_graph_training,
                                            size_t cursor) {
  id_tensor_ptr_ =
      feed_vec_[0]->mutable_data<int64_t>({total_instance, 1}, this->place_);
  show_tensor_ptr_ =
      feed_vec_[1]->mutable_data<int64_t>({total_instance}, this->place_);
  clk_tensor_ptr_ =
      feed_vec_[2]->mutable_data<int64_t>({total_instance}, this->place_);

  if (!gpu_graph_training) {
    const int node_num = total_instance / 2;
    uint64_t *d_type_keys =
        reinterpret_cast<uint64_t *>(d_device_keys_[cursor]->ptr()) +
        infer_node_start_;
    infer_node_start_ += node_num;
    CopyDuplicateKeys<<<GET_BLOCKS(node_num),
                        CUDA_NUM_THREADS,
                        0,
                        train_stream_>>>(id_tensor_ptr_, d_type_keys, node_num);
  } else {
    uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
    uint64_t *ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance;
    cudaMemcpyAsync(id_tensor_ptr_,
                    ins_cursor,
                    sizeof(uint64_t) * total_instance,
                    cudaMemcpyDeviceToDevice,
                    train_stream_);
  }

  const int cvm_blocks = GET_BLOCKS(total_instance);
  GraphFillCVMKernel<<<cvm_blocks, CUDA_NUM_THREADS, 0, train_stream_>>>(
      show_tensor_ptr_, total_instance);
  GraphFillCVMKernel<<<cvm_blocks, CUDA_NUM_THREADS, 0, train_stream_>>>(
      clk_tensor_ptr_, total_instance);
  return 0;
}
D
danleifeng 已提交
609

L
lxsbupt 已提交
610 611 612 613 614 615 616 617 618 619 620 621
// Fills the sage-mode feed tensors for batch `index` on train_stream_.
//
// feed_vec_ layout used here:
//   [0] ids, [1] show, [2] clk                   (uniq_instance entries)
//   then 2 * slot_num_ slot tensors (skipped here),
//   then 5 tensors per entry of samples_: num_nodes, next_num_nodes,
//        edges_src, edges_dst, edges_split,
//   then the inverse index (total_instance entries) and, if get_degree_,
//   the degree tensor (uniq_instance * edge_to_id_len_ entries).
// Always returns 0.
//
// Changes vs. the previous version:
//  * The per-sample output pointers are plain locals instead of
//    variable-length arrays (a non-standard C++ extension); each pointer is
//    only used within its own iteration.
//  * The per-batch metadata is bound by const reference rather than
//    deep-copied on every call.
int GraphDataGenerator::FillGraphIdShowClkTensor(int uniq_instance,
                                                 int total_instance,
                                                 int index) {
  id_tensor_ptr_ =
      feed_vec_[0]->mutable_data<int64_t>({uniq_instance, 1}, this->place_);
  show_tensor_ptr_ =
      feed_vec_[1]->mutable_data<int64_t>({uniq_instance}, this->place_);
  clk_tensor_ptr_ =
      feed_vec_[2]->mutable_data<int64_t>({uniq_instance}, this->place_);
  int index_offset = 3 + slot_num_ * 2 + 5 * samples_.size();
  index_tensor_ptr_ = feed_vec_[index_offset]->mutable_data<int>(
      {total_instance}, this->place_);
  if (get_degree_) {
    degree_tensor_ptr_ = feed_vec_[index_offset + 1]->mutable_data<int>(
        {uniq_instance * edge_to_id_len_}, this->place_);
  }

  int len_samples = samples_.size();
  const auto &edges_split_num_for_graph = edges_split_num_vec_[index];
  const auto &graph_edges = graph_edges_vec_[index];
  for (int i = 0; i < len_samples; i++) {
    int offset = 3 + 2 * slot_num_ + 5 * i;
    // Layout: [edges_split_num (edge_to_id_len_ entries), next_num_nodes,
    //          num_nodes, neighbor_len]
    const std::vector<int> &edges_split_num = edges_split_num_for_graph[i];
    int neighbor_len = edges_split_num[edge_to_id_len_ + 2];

    int *num_nodes_tensor_ptr =
        feed_vec_[offset]->mutable_data<int>({1}, this->place_);
    int *next_num_nodes_tensor_ptr =
        feed_vec_[offset + 1]->mutable_data<int>({1}, this->place_);
    int64_t *edges_src_tensor_ptr =
        feed_vec_[offset + 2]->mutable_data<int64_t>({neighbor_len, 1},
                                                     this->place_);
    int64_t *edges_dst_tensor_ptr =
        feed_vec_[offset + 3]->mutable_data<int64_t>({neighbor_len, 1},
                                                     this->place_);
    int *edges_split_tensor_ptr = feed_vec_[offset + 4]->mutable_data<int>(
        {edge_to_id_len_}, this->place_);

    cudaMemcpyAsync(next_num_nodes_tensor_ptr,
                    edges_split_num.data() + edge_to_id_len_,
                    sizeof(int),
                    cudaMemcpyHostToDevice,
                    train_stream_);
    cudaMemcpyAsync(num_nodes_tensor_ptr,
                    edges_split_num.data() + edge_to_id_len_ + 1,
                    sizeof(int),
                    cudaMemcpyHostToDevice,
                    train_stream_);
    cudaMemcpyAsync(edges_split_tensor_ptr,
                    edges_split_num.data(),
                    sizeof(int) * edge_to_id_len_,
                    cudaMemcpyHostToDevice,
                    train_stream_);
    // graph_edges holds (src, dst) edge buffers for each sample.
    cudaMemcpyAsync(edges_src_tensor_ptr,
                    graph_edges[i * 2]->ptr(),
                    sizeof(int64_t) * neighbor_len,
                    cudaMemcpyDeviceToDevice,
                    train_stream_);
    cudaMemcpyAsync(edges_dst_tensor_ptr,
                    graph_edges[i * 2 + 1]->ptr(),
                    sizeof(int64_t) * neighbor_len,
                    cudaMemcpyDeviceToDevice,
                    train_stream_);
  }

  cudaMemcpyAsync(id_tensor_ptr_,
                  final_sage_nodes_vec_[index]->ptr(),
                  sizeof(int64_t) * uniq_instance,
                  cudaMemcpyDeviceToDevice,
                  train_stream_);
  cudaMemcpyAsync(index_tensor_ptr_,
                  inverse_vec_[index]->ptr(),
                  sizeof(int) * total_instance,
                  cudaMemcpyDeviceToDevice,
                  train_stream_);
  if (get_degree_) {
    cudaMemcpyAsync(degree_tensor_ptr_,
                    node_degree_vec_[index]->ptr(),
                    sizeof(int) * uniq_instance * edge_to_id_len_,
                    cudaMemcpyDeviceToDevice,
                    train_stream_);
  }
  GraphFillCVMKernel<<<GET_BLOCKS(uniq_instance),
                       CUDA_NUM_THREADS,
                       0,
                       train_stream_>>>(show_tensor_ptr_, uniq_instance);
  GraphFillCVMKernel<<<GET_BLOCKS(uniq_instance),
                       CUDA_NUM_THREADS,
                       0,
                       train_stream_>>>(clk_tensor_ptr_, uniq_instance);
  return 0;
}

// Chooses the node ids whose slot features should be gathered for the current
// batch and forwards them to FillSlotFeature.
//
// - Training: the ids are the last `total_instance` entries of the instance
//   pair buffer (d_ins_buf_, which holds 2 * ins_buf_pair_len_ ids).
// - Inference: feed_vec_[0] is (re)acquired as a {total_instance, 1} int64
//   tensor and its storage is used as the id cursor.
// - Sage mode: the ids stored in `final_sage_nodes` are used instead of the
//   cursor computed above.
//
// Returns whatever FillSlotFeature returns.
int GraphDataGenerator::FillGraphSlotFeature(
    int total_instance,
    bool gpu_graph_training,
    std::shared_ptr<phi::Allocation> final_sage_nodes) {
  uint64_t *ins_cursor = nullptr;
  if (!gpu_graph_training) {
    id_tensor_ptr_ =
        feed_vec_[0]->mutable_data<int64_t>({total_instance, 1}, this->place_);
    ins_cursor = reinterpret_cast<uint64_t *>(id_tensor_ptr_);
  } else {
    uint64_t *pair_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
    // Point at the final `total_instance` ids of the pair buffer.
    ins_cursor = pair_buf + ins_buf_pair_len_ * 2 - total_instance;
  }

  uint64_t *feature_keys =
      sage_mode_ ? reinterpret_cast<uint64_t *>(final_sage_nodes->ptr())
                 : ins_cursor;
  return FillSlotFeature(feature_keys, total_instance);
}

// Builds id pairs from the current window of the walk buffer and appends them
// to the instance pair buffer (d_ins_buf_).
//
// GraphFillIdKernel writes the new pairs starting at
// ins_buf + 2 * ins_buf_pair_len_, driven by buf_state_ (cursor, central word,
// window step and length). It reports the number of pairs it produced through
// the device counter d_pair_num_. That count is copied back, so this call
// blocks on `stream`.
//
// Returns the updated pair count (ins_buf_pair_len_).
int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
  uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
  // Node-type data is only handed to the kernel when excluded train pairs are
  // configured; otherwise both pointers stay NULL.
  uint8_t *walk_ntype = NULL;
  uint8_t *excluded_train_pair = NULL;
  if (excluded_train_pair_len_ > 0) {
    walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_->ptr());
    excluded_train_pair =
        reinterpret_cast<uint8_t *>(d_excluded_train_pair_->ptr());
  }
  uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
  int *random_row = reinterpret_cast<int *>(d_random_row_->ptr());
  int *d_pair_num = reinterpret_cast<int *>(d_pair_num_->ptr());
  CUDA_CHECK(cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream));
  int len = buf_state_.len;
  // make pair
  GraphFillIdKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, stream>>>(
      ins_buf + ins_buf_pair_len_ * 2,
      d_pair_num,
      walk,
      walk_ntype,
      random_row + buf_state_.cursor,
      buf_state_.central_word,
      window_step_[buf_state_.step],
      len,
      walk_len_,
      excluded_train_pair,
      excluded_train_pair_len_);
  int h_pair_num = 0;
  CUDA_CHECK(cudaMemcpyAsync(
      &h_pair_num, d_pair_num, sizeof(int), cudaMemcpyDeviceToHost, stream));
  CUDA_CHECK(cudaStreamSynchronize(stream));
  ins_buf_pair_len_ += h_pair_num;

  if (debug_mode_) {
    // Heap-allocate the host copy. The pair buffer can be large, and a
    // runtime-sized stack array is both non-standard C++ and a stack-overflow
    // risk.
    std::vector<uint64_t> h_ins_buf(static_cast<size_t>(ins_buf_pair_len_) *
                                    2);
    CUDA_CHECK(cudaMemcpy(h_ins_buf.data(),
                          ins_buf,
                          h_ins_buf.size() * sizeof(uint64_t),
                          cudaMemcpyDeviceToHost));
    VLOG(2) << "h_pair_num = " << h_pair_num
            << ", ins_buf_pair_len = " << ins_buf_pair_len_;
    for (int xx = 0; xx < ins_buf_pair_len_; xx++) {
      VLOG(2) << "h_ins_buf: " << h_ins_buf[xx * 2] << ", "
              << h_ins_buf[xx * 2 + 1];
    }
  }
  return ins_buf_pair_len_;
}

L
lxsbupt 已提交
782 783 784 785 786 787 788 789 790 791 792 793 794 795 796
// Tops up the instance pair buffer using `stream`.
//
// Returns:
//   batch_size_  if the buffer already holds at least one full batch;
//   -1           if AcquireInstance yields no more walk instances;
//   otherwise    the pair count returned by MakeInsPair.
int GraphDataGenerator::FillInsBuf(cudaStream_t stream) {
  const bool already_full = ins_buf_pair_len_ >= batch_size_;
  if (already_full) return batch_size_;

  int acquired = AcquireInstance(&buf_state_);
  VLOG(2) << "total_ins: " << acquired;
  buf_state_.Debug();

  return acquired == 0 ? -1 : MakeInsPair(stream);
}

D
danleifeng 已提交
797 798 799 800 801
// Produces the next batch into feed_vec_.
//
// Fills the id/show/clk tensors, then (when slot_num_ > 0) the slot feature
// tensors. It then attaches a single-level LoD [0, rows] to feed_vec_[0] and
// to every slot tensor.
//
// - Sage mode (train or infer): consumes the precomputed batch
//   sage_batch_count_ and advances the counter.
// - Training otherwise: refills the pair buffer until it holds batch_size_
//   pairs or the walk source runs dry. It then emits up to batch_size_ pairs
//   (2 ids each) and drops them from the buffer afterwards.
// - Inference otherwise: emits the remaining infer nodes, capped at
//   batch_size_ (times 2 ids).
//
// Returns 0 when there is nothing left to emit, 1 after a batch was written.
int GraphDataGenerator::GenerateBatch() {
  platform::CUDADeviceGuard guard(gpuid_);
  int total_instance = 0;

  if (sage_mode_) {
    // Train and infer take the same path in sage mode.
    if (sage_batch_count_ == sage_batch_num_) {
      return 0;
    }
    FillGraphIdShowClkTensor(uniq_instance_vec_[sage_batch_count_],
                             total_instance_vec_[sage_batch_count_],
                             sage_batch_count_);
  } else if (gpu_graph_training_) {
    // train
    while (ins_buf_pair_len_ < batch_size_) {
      if (FillInsBuf(train_stream_) != -1) continue;
      // Walk source exhausted: stop if nothing is buffered, else flush a
      // partial batch.
      if (ins_buf_pair_len_ == 0) return 0;
      break;
    }
    total_instance =
        2 * (batch_size_ <= ins_buf_pair_len_ ? batch_size_ : ins_buf_pair_len_);
    VLOG(2) << "total_instance: " << total_instance
            << ", ins_buf_pair_len = " << ins_buf_pair_len_;
    FillIdShowClkTensor(total_instance, gpu_graph_training_);
  } else {
    // infer
    if (infer_node_start_ + batch_size_ <= infer_node_end_) {
      total_instance = batch_size_;
    } else {
      total_instance = infer_node_end_ - infer_node_start_;
    }
    VLOG(1) << "in graph_data generator:batch_size = " << batch_size_
            << " instance = " << total_instance;
    total_instance *= 2;
    if (total_instance == 0) {
      return 0;
    }
    FillIdShowClkTensor(total_instance, gpu_graph_training_, cursor_);
  }

  if (slot_num_ > 0) {
    if (sage_mode_) {
      FillGraphSlotFeature(uniq_instance_vec_[sage_batch_count_],
                           gpu_graph_training_,
                           final_sage_nodes_vec_[sage_batch_count_]);
    } else {
      FillGraphSlotFeature(total_instance, gpu_graph_training_);
    }
  }

  // One LoD level shared by the id tensor and all slot tensors.
  offset_.clear();
  offset_.push_back(0);
  if (sage_mode_) {
    offset_.push_back(uniq_instance_vec_[sage_batch_count_]);
    ++sage_batch_count_;
  } else {
    offset_.push_back(total_instance);
  }
  LoD lod{offset_};
  feed_vec_[0]->set_lod(lod);
  for (int slot = 0; slot < slot_num_; ++slot) {
    feed_vec_[3 + 2 * slot]->set_lod(lod);
  }

  cudaStreamSynchronize(train_stream_);
  // Only the non-sage training path consumes pairs from the buffer.
  if (gpu_graph_training_ && !sage_mode_) {
    ins_buf_pair_len_ -= total_instance / 2;
  }
  return 1;
}

// Compacts this round's sampled neighbors into a dense key array.
//
// One source row per loop iteration (indexed via CUDA_KERNEL_LOOP over len).
// For the k-th of the actual_sample_size[idx] samples of row idx, with
// out = prefix_sum[idx] + k:
//   sample_keys[out]       = neighbors[idx * cur_degree + k]
//   tmp_sampleidx2row[out] = sampleidx2row[idx] + k
// `neighbors` holds cur_degree slots per source row; only the first
// actual_sample_size[idx] of them are read.
__global__ void GraphFillSampleKeysKernel(uint64_t *neighbors,
                                          uint64_t *sample_keys,
                                          int *prefix_sum,
                                          int *sampleidx2row,
                                          int *tmp_sampleidx2row,
                                          int *actual_sample_size,
                                          int cur_degree,
                                          int len) {
  CUDA_KERNEL_LOOP(idx, len) {
    const int sample_cnt = actual_sample_size[idx];
    const int first_row = sampleidx2row[idx];
    const uint64_t *row_neighbors = neighbors + idx * cur_degree;
    size_t out = prefix_sum[idx];
    for (int k = 0; k < sample_cnt; ++k, ++out) {
      sample_keys[out] = row_neighbors[k];
      tmp_sampleidx2row[out] = first_row + k;
    }
  }
}

// Writes column `step` of the walk matrix from this round's samples.
//
// One source row i per loop iteration (CUDA_KERNEL_LOOP over len). Its k-th
// sample goes to walk row sampleidx2row[d_prefix_sum[i] + k], and `walk` is
// addressed row-major with `col_size` columns. When walk_ntype is non-NULL,
// edge_dst_id is stored at the same position. `id_cnt` is not read.
__global__ void GraphDoWalkKernel(uint64_t *neighbors,
                                  uint64_t *walk,
                                  uint8_t *walk_ntype,
                                  int *d_prefix_sum,
                                  int *actual_sample_size,
                                  int cur_degree,
                                  int step,
                                  int len,
                                  int *id_cnt,
                                  int *sampleidx2row,
                                  int col_size,
                                  uint8_t edge_dst_id) {
  CUDA_KERNEL_LOOP(i, len) {
    const uint64_t *row_neighbors = neighbors + i * cur_degree;
    for (int k = 0; k < actual_sample_size[i]; ++k) {
      const size_t walk_row = sampleidx2row[d_prefix_sum[i] + k];
      const size_t pos = walk_row * col_size + step;
      walk[pos] = row_neighbors[k];
      if (walk_ntype != NULL) walk_ntype[pos] = edge_dst_id;
    }
  }
}

// Seeds the walk matrix from the first sampling round.
//
// One start key idx per loop iteration (CUDA_KERNEL_LOOP over len). For its
// k-th sample, with row = prefix_sum[idx] + k:
//   sample_keys[row]   = sampled neighbor
//   sampleidx2row[row] = row
//   walk[row][0]       = keys[idx]
//   walk[row][1]       = sampled neighbor
// `walk` is row-major with `col_size` columns. `neighbors` holds walk_degree
// slots per key. When walk_ntype is non-NULL, edge_src_id and edge_dst_id are
// written to the same two positions.
__global__ void GraphFillFirstStepKernel(int *prefix_sum,
                                         int *sampleidx2row,
                                         uint64_t *walk,
                                         uint8_t *walk_ntype,
                                         uint64_t *keys,
                                         uint8_t edge_src_id,
                                         uint8_t edge_dst_id,
                                         int len,
                                         int walk_degree,
                                         int col_size,
                                         int *actual_sample_size,
                                         uint64_t *neighbors,
                                         uint64_t *sample_keys) {
  CUDA_KERNEL_LOOP(idx, len) {
    const uint64_t *row_neighbors = neighbors + idx * walk_degree;
    for (int k = 0; k < actual_sample_size[idx]; ++k) {
      const size_t row = prefix_sum[idx] + k;
      const size_t base = col_size * row;
      const uint64_t sampled = row_neighbors[k];
      sample_keys[row] = sampled;
      sampleidx2row[row] = row;

      walk[base] = keys[idx];
      walk[base + 1] = sampled;
      if (walk_ntype != NULL) {
        walk_ntype[base] = edge_src_id;
        walk_ntype[base + 1] = edge_dst_id;
      }
    }
  }
}

L
lxsbupt 已提交
959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075
// Per-node slot histogram plus its exclusive prefix sum over slots.
//
// One node per thread. The node index is blockIdx.x * blockDim.y + threadIdx.y,
// i.e. the launch uses the y dimension of the block. For node i, with its
// features at slot_list[slot_size_prefix[i] .. + slot_size_list[i]):
//   each_ins_slot_num[i * slot_num + s]             += count of slot s
//   each_ins_slot_num_inner_prefix[i * slot_num + s] = sum of counts of slots < s
// each_ins_slot_num is accumulated with +=, so it must be zero-initialized.
__global__ void get_each_ins_info(uint8_t *slot_list,
                                  uint32_t *slot_size_list,
                                  uint32_t *slot_size_prefix,
                                  uint32_t *each_ins_slot_num,
                                  uint32_t *each_ins_slot_num_inner_prefix,
                                  size_t key_num,
                                  int slot_num) {
  const size_t node = blockIdx.x * blockDim.y + threadIdx.y;
  if (node >= key_num) return;

  const size_t base = node * slot_num;
  uint32_t *counts = each_ins_slot_num + base;
  uint32_t *prefix = each_ins_slot_num_inner_prefix + base;
  const uint32_t first = slot_size_prefix[node];

  for (int f = 0; f < slot_size_list[node]; ++f) {
    counts[slot_list[first + f]] += 1;
  }

  prefix[0] = 0;
  for (int s = 1; s < slot_num; ++s) {
    prefix[s] = prefix[s - 1] + counts[s - 1];
  }
}

// Transposes node-major slot counts into one array per slot:
//   d_ins_slot_num_vector_ptr[s][i] = d_each_ins_slot_num_ptr[i * slot_num + s]
// One node per thread, indexed as blockIdx.x * blockDim.y + threadIdx.y.
__global__ void fill_slot_num(uint32_t *d_each_ins_slot_num_ptr,
                              uint64_t **d_ins_slot_num_vector_ptr,
                              size_t key_num,
                              int slot_num) {
  const size_t node = blockIdx.x * blockDim.y + threadIdx.y;
  if (node < key_num) {
    const uint32_t *node_counts = d_each_ins_slot_num_ptr + node * slot_num;
    for (int s = 0; s < slot_num; ++s) {
      d_ins_slot_num_vector_ptr[s][node] = node_counts[s];
    }
  }
}

// Copies the features of one slot into that slot's output tensor.
//
// One node per thread, indexed as blockIdx.x * blockDim.y + threadIdx.y. Node
// i contributes ins_slot_num[i] features. They are read from feature_list
// starting at feature_size_prefixsum[i] + each_ins_slot_num_inner_prefix[
// slot_num * i + slot] and written to slot_tensor starting at
// slot_lod_tensor[i].
__global__ void fill_slot_tensor(uint64_t *feature_list,
                                 uint32_t *feature_size_prefixsum,
                                 uint32_t *each_ins_slot_num_inner_prefix,
                                 uint64_t *ins_slot_num,
                                 int64_t *slot_lod_tensor,
                                 int64_t *slot_tensor,
                                 int slot,
                                 int slot_num,
                                 size_t node_num) {
  const size_t node = blockIdx.x * blockDim.y + threadIdx.y;
  if (node >= node_num) return;

  const size_t write_at = slot_lod_tensor[node];
  const size_t read_at = feature_size_prefixsum[node] +
                         each_ins_slot_num_inner_prefix[slot_num * node + slot];
  const uint64_t count = ins_slot_num[node];
  for (uint64_t n = 0; n < count; ++n) {
    slot_tensor[write_at + n] = feature_list[read_at + n];
  }
}

// Counts runs of equal adjacent keys in d_in[0, len) and adds the count to
// *unique_num. The result equals the number of distinct keys only if d_in is
// sorted (presumably done by the caller — TODO confirm).
//
// 1-D launch; thread i inspects d_in[i]. A run is counted at its last element
// (d_in[i] != d_in[i + 1], or i is the final index). Each block first
// accumulates into a shared counter, then issues a single global atomicAdd.
// *unique_num is only added to, so the caller must initialize it.
__global__ void GetUniqueFeaNum(uint64_t *d_in,
                                uint64_t *unique_num,
                                size_t len) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  __shared__ uint64_t local_num;
  if (threadIdx.x == 0) {
    local_num = 0;
  }
  __syncthreads();

  // Compare with `i + 1` instead of `len - 1`: len is unsigned, so `len - 1`
  // wraps to SIZE_MAX when len == 0, and every thread would then read d_in
  // out of bounds.
  if (i + 1 < len) {
    if (d_in[i] != d_in[i + 1]) {
      atomicAdd(&local_num, 1);
    }
  } else if (i + 1 == len) {
    atomicAdd(&local_num, 1);
  }

  __syncthreads();
  if (threadIdx.x == 0) {
    atomicAdd(unique_num, local_num);
  }
}

// Writes the last key of every run of equal adjacent keys in d_in[0, len) to
// d_out. These are the distinct keys if d_in is sorted (presumably ensured by
// the caller — TODO confirm).
//
// 1-D launch; requires blockDim.x <= CUDA_NUM_THREADS, because each block
// stages its keys in a shared array of that size. Each block reserves a
// contiguous range of d_out by atomically bumping *unique_num, which the
// caller must initialize.
//
// Output order is not deterministic: slots within a block and block ranges
// are both claimed with atomicAdd.
__global__ void UniqueFeature(uint64_t *d_in,
                              uint64_t *d_out,
                              uint64_t *unique_num,
                              size_t len) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  __shared__ uint64_t local_key[CUDA_NUM_THREADS];
  __shared__ uint64_t local_num;
  __shared__ uint64_t global_num;
  if (threadIdx.x == 0) {
    local_num = 0;
  }
  __syncthreads();

  // Compare with `i + 1` instead of `len - 1`: len is unsigned, so `len - 1`
  // wraps to SIZE_MAX when len == 0, and every thread would then read d_in
  // out of bounds.
  bool is_run_end = false;
  if (i + 1 < len) {
    is_run_end = d_in[i] != d_in[i + 1];
  } else if (i + 1 == len) {
    is_run_end = true;
  }
  if (is_run_end) {
    size_t dst = atomicAdd(&local_num, 1);
    local_key[dst] = d_in[i];
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    global_num = atomicAdd(unique_num, local_num);
  }
  __syncthreads();

  if (threadIdx.x < local_num) {
    d_out[global_num + threadIdx.x] = local_key[threadIdx.x];
  }
}
D
danleifeng 已提交
1076 1077
// Writes one round of neighbor-sampling results into the walk buffer.
//
// d_start_ids : the `len` keys that were sampled from. Read only when
//               step == 1.
// etype_id    : index into edge_to_node_map_. The packed value supplies
//               edge_src_id (bits >> 32) and edge_dst_id (low byte), which are
//               written into walk_ntype when it is non-NULL.
// walk        : row-major walk matrix with walk_len_ columns.
// sample_res  : per-key sample counts (actual_sample_size) and neighbors (val).
//               val holds cur_degree slots per key, or walk_degree_ slots when
//               step == 1.
// step        : column to write. step == 1 seeds columns 0 and 1 and resets
//               the row mapping to identity.
// len_per_row : forwarded to GraphDoWalkKernel (its id_cnt argument).
//
// Side effects: stores this round's sampled neighbors in d_sample_keys_, and
// writes the new key->walk-row mapping into d_sampleidx2rows_[1 - cur]. It then
// flips cur_sampleidx2row_. Blocks on sample_stream_.
// NOTE(review): d_prefix_sum[0] is assumed to already be 0; only [1, len] is
// written here — TODO confirm where it is initialized.
void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids,
                                     int etype_id,
                                     uint64_t *walk,
                                     uint8_t *walk_ntype,
                                     int len,
                                     NeighborSampleResult &sample_res,
                                     int cur_degree,
                                     int step,
                                     int *len_per_row) {
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  uint64_t node_id = gpu_graph_ptr->edge_to_node_map_[etype_id];
  uint8_t edge_src_id = static_cast<uint8_t>(node_id >> 32);
  uint8_t edge_dst_id = static_cast<uint8_t>(node_id);
  size_t temp_storage_bytes = 0;
  int *d_actual_sample_size = sample_res.actual_sample_size;
  uint64_t *d_neighbors = sample_res.val;
  int *d_prefix_sum = reinterpret_cast<int *>(d_prefix_sum_->ptr());
  uint64_t *d_sample_keys = reinterpret_cast<uint64_t *>(d_sample_keys_->ptr());
  int *d_sampleidx2row =
      reinterpret_cast<int *>(d_sampleidx2rows_[cur_sampleidx2row_]->ptr());
  int *d_tmp_sampleidx2row =
      reinterpret_cast<int *>(d_sampleidx2rows_[1 - cur_sampleidx2row_]->ptr());

  // d_prefix_sum[k + 1] = total samples of keys 0..k, i.e. the output offset
  // of key k + 1.
  CUDA_CHECK(cub::DeviceScan::InclusiveSum(NULL,
                                           temp_storage_bytes,
                                           d_actual_sample_size,
                                           d_prefix_sum + 1,
                                           len,
                                           sample_stream_));
  auto d_temp_storage = memory::Alloc(
      place_,
      temp_storage_bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));

  CUDA_CHECK(cub::DeviceScan::InclusiveSum(d_temp_storage->ptr(),
                                           temp_storage_bytes,
                                           d_actual_sample_size,
                                           d_prefix_sum + 1,
                                           len,
                                           sample_stream_));

  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  if (step == 1) {
    GraphFillFirstStepKernel<<<GET_BLOCKS(len),
                               CUDA_NUM_THREADS,
                               0,
                               sample_stream_>>>(d_prefix_sum,
                                                 d_tmp_sampleidx2row,
                                                 walk,
                                                 walk_ntype,
                                                 d_start_ids,
                                                 edge_src_id,
                                                 edge_dst_id,
                                                 len,
                                                 walk_degree_,
                                                 walk_len_,
                                                 d_actual_sample_size,
                                                 d_neighbors,
                                                 d_sample_keys);

  } else {
    GraphFillSampleKeysKernel<<<GET_BLOCKS(len),
                                CUDA_NUM_THREADS,
                                0,
                                sample_stream_>>>(d_neighbors,
                                                  d_sample_keys,
                                                  d_prefix_sum,
                                                  d_sampleidx2row,
                                                  d_tmp_sampleidx2row,
                                                  d_actual_sample_size,
                                                  cur_degree,
                                                  len);

    GraphDoWalkKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, sample_stream_>>>(
        d_neighbors,
        walk,
        walk_ntype,
        d_prefix_sum,
        d_actual_sample_size,
        cur_degree,
        step,
        len,
        len_per_row,
        d_tmp_sampleidx2row,
        walk_len_,
        edge_dst_id);
  }
  // Drain sample_stream_ before any host readback below. The debug copies use
  // the blocking default-stream cudaMemcpy, which is not ordered after work on
  // sample_stream_ if that stream was created non-blocking.
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  if (debug_mode_) {
    size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_;
    std::vector<int> h_prefix_sum(len + 1);
    std::vector<int> h_offset2idx(once_max_sample_keynum);
    CUDA_CHECK(cudaMemcpy(h_offset2idx.data(),
                          d_tmp_sampleidx2row,
                          once_max_sample_keynum * sizeof(int),
                          cudaMemcpyDeviceToHost));

    CUDA_CHECK(cudaMemcpy(h_prefix_sum.data(),
                          d_prefix_sum,
                          (len + 1) * sizeof(int),
                          cudaMemcpyDeviceToHost));
    for (size_t xx = 0; xx < once_max_sample_keynum; xx++) {
      VLOG(2) << "h_offset2idx[" << xx << "]: " << h_offset2idx[xx];
    }
    for (int xx = 0; xx < len + 1; xx++) {
      VLOG(2) << "h_prefix_sum[" << xx << "]: " << h_prefix_sum[xx];
    }
  }
  cur_sampleidx2row_ = 1 - cur_sampleidx2row_;
}

L
lxsbupt 已提交
1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453
// Builds the per-slot LoD feature tensors for `key_num` nodes read from d_walk.
//
// For every slot i this fills:
//   feed_vec_[3 + 2 * i + 1] : int64 LoD offsets, shape {key_num + 1}
//   feed_vec_[3 + 2 * i]     : int64 feature ids, shape {features in slot i, 1}
// from the feature/slot lists returned by
// GraphGpuWrapper::get_feature_info_of_nodes.
//
// A slot with no features gets a single zero id and a LoD whose last entry is
// 1, so it is never emitted as an empty tensor.
//
// In sage mode, d_feature_size_list_buf_ and d_feature_size_prefixsum_buf_ are
// grown on demand to hold key_num + 1 uint32 values. Otherwise they are assumed
// to be allocated elsewhere (TODO confirm). All device work runs on
// train_stream_, which is synchronized before returning. Returns 0.
int GraphDataGenerator::FillSlotFeature(uint64_t *d_walk, size_t key_num) {
  platform::CUDADeviceGuard guard(gpuid_);
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  std::shared_ptr<phi::Allocation> d_feature_list;
  std::shared_ptr<phi::Allocation> d_slot_list;

  if (sage_mode_) {
    size_t temp_storage_bytes = (key_num + 1) * sizeof(uint32_t);
    if (d_feature_size_list_buf_ == NULL ||
        d_feature_size_list_buf_->size() < temp_storage_bytes) {
      d_feature_size_list_buf_ =
          memory::AllocShared(this->place_, temp_storage_bytes);
    }
    if (d_feature_size_prefixsum_buf_ == NULL ||
        d_feature_size_prefixsum_buf_->size() < temp_storage_bytes) {
      d_feature_size_prefixsum_buf_ =
          memory::AllocShared(this->place_, temp_storage_bytes);
    }
  }

  uint32_t *d_feature_size_list_ptr =
      reinterpret_cast<uint32_t *>(d_feature_size_list_buf_->ptr());
  uint32_t *d_feature_size_prefixsum_ptr =
      reinterpret_cast<uint32_t *>(d_feature_size_prefixsum_buf_->ptr());
  int fea_num =
      gpu_graph_ptr->get_feature_info_of_nodes(gpuid_,
                                               d_walk,
                                               key_num,
                                               d_feature_size_list_ptr,
                                               d_feature_size_prefixsum_ptr,
                                               d_feature_list,
                                               d_slot_list);
  // Device pointers into each slot's feed tensors. std::vector replaces the
  // former runtime-sized stack arrays, which are a non-standard C++ extension.
  std::vector<int64_t *> slot_tensor_ptrs(slot_num_, nullptr);
  std::vector<int64_t *> slot_lod_tensor_ptrs(slot_num_, nullptr);
  if (fea_num == 0) {
    // No features at all: every slot gets one zero id and LoD [0, ..., 0, 1].
    int64_t default_lod = 1;
    for (int i = 0; i < slot_num_; ++i) {
      slot_lod_tensor_ptrs[i] = feed_vec_[3 + 2 * i + 1]->mutable_data<int64_t>(
          {(long)key_num + 1}, this->place_);  // NOLINT
      slot_tensor_ptrs[i] =
          feed_vec_[3 + 2 * i]->mutable_data<int64_t>({1, 1}, this->place_);
      CUDA_CHECK(cudaMemsetAsync(
          slot_tensor_ptrs[i], 0, sizeof(int64_t), train_stream_));
      CUDA_CHECK(cudaMemsetAsync(slot_lod_tensor_ptrs[i],
                                 0,
                                 sizeof(int64_t) * key_num,
                                 train_stream_));
      CUDA_CHECK(cudaMemcpyAsync(
          reinterpret_cast<char *>(slot_lod_tensor_ptrs[i] + key_num),
          &default_lod,
          sizeof(int64_t),
          cudaMemcpyHostToDevice,
          train_stream_));
    }
    // default_lod lives on the stack: sync before it goes out of scope.
    CUDA_CHECK(cudaStreamSynchronize(train_stream_));
    return 0;
  }

  uint64_t *d_feature_list_ptr =
      reinterpret_cast<uint64_t *>(d_feature_list->ptr());
  uint8_t *d_slot_list_ptr = reinterpret_cast<uint8_t *>(d_slot_list->ptr());

  // Node-major [key_num][slot_num_] per-slot counts and their per-node
  // exclusive prefix sums.
  std::shared_ptr<phi::Allocation> d_each_ins_slot_num_inner_prefix =
      memory::AllocShared(place_, (slot_num_ * key_num) * sizeof(uint32_t));
  std::shared_ptr<phi::Allocation> d_each_ins_slot_num =
      memory::AllocShared(place_, (slot_num_ * key_num) * sizeof(uint32_t));
  uint32_t *d_each_ins_slot_num_ptr =
      reinterpret_cast<uint32_t *>(d_each_ins_slot_num->ptr());
  uint32_t *d_each_ins_slot_num_inner_prefix_ptr =
      reinterpret_cast<uint32_t *>(d_each_ins_slot_num_inner_prefix->ptr());
  CUDA_CHECK(cudaMemsetAsync(d_each_ins_slot_num_ptr,
                             0,
                             slot_num_ * key_num * sizeof(uint32_t),
                             train_stream_));

  // One node per thread along the block's y dimension (the kernels below index
  // with blockIdx.x * blockDim.y + threadIdx.y). Ceil-div avoids the unsigned
  // wrap that `(key_num - 1) / N + 1` has at key_num == 0.
  const size_t kNodesPerBlock = 256;
  dim3 grid((key_num + kNodesPerBlock - 1) / kNodesPerBlock);
  dim3 block(1, kNodesPerBlock);

  get_each_ins_info<<<grid, block, 0, train_stream_>>>(
      d_slot_list_ptr,
      d_feature_size_list_ptr,
      d_feature_size_prefixsum_ptr,
      d_each_ins_slot_num_ptr,
      d_each_ins_slot_num_inner_prefix_ptr,
      key_num,
      slot_num_);

  // One device array of per-node counts per slot, plus a device array of
  // those pointers for fill_slot_num.
  std::vector<std::shared_ptr<phi::Allocation>> ins_slot_num(slot_num_,
                                                             nullptr);
  std::vector<uint64_t *> ins_slot_num_vector(slot_num_, NULL);
  std::shared_ptr<phi::Allocation> d_ins_slot_num_vector =
      memory::AllocShared(place_, (slot_num_) * sizeof(uint64_t *));
  uint64_t **d_ins_slot_num_vector_ptr =
      reinterpret_cast<uint64_t **>(d_ins_slot_num_vector->ptr());
  for (int i = 0; i < slot_num_; i++) {
    ins_slot_num[i] = memory::AllocShared(place_, key_num * sizeof(uint64_t));
    ins_slot_num_vector[i] =
        reinterpret_cast<uint64_t *>(ins_slot_num[i]->ptr());
  }
  CUDA_CHECK(
      cudaMemcpyAsync(reinterpret_cast<char *>(d_ins_slot_num_vector_ptr),
                      ins_slot_num_vector.data(),
                      sizeof(uint64_t *) * slot_num_,
                      cudaMemcpyHostToDevice,
                      train_stream_));
  fill_slot_num<<<grid, block, 0, train_stream_>>>(
      d_each_ins_slot_num_ptr, d_ins_slot_num_vector_ptr, key_num, slot_num_);
  CUDA_CHECK(cudaStreamSynchronize(train_stream_));

  for (int i = 0; i < slot_num_; ++i) {
    slot_lod_tensor_ptrs[i] = feed_vec_[3 + 2 * i + 1]->mutable_data<int64_t>(
        {(long)key_num + 1}, this->place_);  // NOLINT
  }
  // LoD of slot i = [0, inclusive scan of per-node counts]. All scans have the
  // same length and types, so one temp buffer sized for slot 0 serves all.
  size_t temp_storage_bytes = 0;
  CUDA_CHECK(cub::DeviceScan::InclusiveSum(NULL,
                                           temp_storage_bytes,
                                           ins_slot_num_vector[0],
                                           slot_lod_tensor_ptrs[0] + 1,
                                           key_num,
                                           train_stream_));
  CUDA_CHECK(cudaStreamSynchronize(train_stream_));
  auto d_temp_storage = memory::Alloc(
      this->place_,
      temp_storage_bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(train_stream_)));
  std::vector<int64_t> each_slot_fea_num(slot_num_, 0);
  for (int i = 0; i < slot_num_; ++i) {
    CUDA_CHECK(cudaMemsetAsync(
        slot_lod_tensor_ptrs[i], 0, sizeof(uint64_t), train_stream_));
    CUDA_CHECK(cub::DeviceScan::InclusiveSum(d_temp_storage->ptr(),
                                             temp_storage_bytes,
                                             ins_slot_num_vector[i],
                                             slot_lod_tensor_ptrs[i] + 1,
                                             key_num,
                                             train_stream_));
    // The last LoD entry is the total feature count of this slot.
    CUDA_CHECK(cudaMemcpyAsync(&each_slot_fea_num[i],
                               slot_lod_tensor_ptrs[i] + key_num,
                               sizeof(uint64_t),
                               cudaMemcpyDeviceToHost,
                               train_stream_));
  }
  CUDA_CHECK(cudaStreamSynchronize(train_stream_));
  for (int i = 0; i < slot_num_; ++i) {
    slot_tensor_ptrs[i] = feed_vec_[3 + 2 * i]->mutable_data<int64_t>(
        {each_slot_fea_num[i], 1}, this->place_);
  }
  int64_t default_lod = 1;
  for (int i = 0; i < slot_num_; ++i) {
    fill_slot_tensor<<<grid, block, 0, train_stream_>>>(
        d_feature_list_ptr,
        d_feature_size_prefixsum_ptr,
        d_each_ins_slot_num_inner_prefix_ptr,
        ins_slot_num_vector[i],
        slot_lod_tensor_ptrs[i],
        slot_tensor_ptrs[i],
        i,
        slot_num_,
        key_num);
    // trick for empty tensor: give a feature-less slot one zero id and a LoD
    // ending in 1.
    if (each_slot_fea_num[i] == 0) {
      slot_tensor_ptrs[i] =
          feed_vec_[3 + 2 * i]->mutable_data<int64_t>({1, 1}, this->place_);
      CUDA_CHECK(cudaMemsetAsync(
          slot_tensor_ptrs[i], 0, sizeof(uint64_t), train_stream_));
      CUDA_CHECK(cudaMemcpyAsync(
          reinterpret_cast<char *>(slot_lod_tensor_ptrs[i] + key_num),
          &default_lod,
          sizeof(int64_t),
          cudaMemcpyHostToDevice,
          train_stream_));
    }
  }
  CUDA_CHECK(cudaStreamSynchronize(train_stream_));

  if (debug_mode_) {
    std::vector<uint32_t> h_feature_size_list(key_num, 0);
    std::vector<uint32_t> h_feature_size_list_prefixsum(key_num, 0);
    std::vector<uint64_t> node_list(key_num, 0);
    std::vector<uint64_t> h_feature_list(fea_num, 0);
    std::vector<uint8_t> h_slot_list(fea_num, 0);

    CUDA_CHECK(
        cudaMemcpyAsync(reinterpret_cast<char *>(h_feature_size_list.data()),
                        d_feature_size_list_ptr,
                        sizeof(uint32_t) * key_num,
                        cudaMemcpyDeviceToHost,
                        train_stream_));
    CUDA_CHECK(cudaMemcpyAsync(
        reinterpret_cast<char *>(h_feature_size_list_prefixsum.data()),
        d_feature_size_prefixsum_ptr,
        sizeof(uint32_t) * key_num,
        cudaMemcpyDeviceToHost,
        train_stream_));
    CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast<char *>(node_list.data()),
                               d_walk,
                               sizeof(uint64_t) * key_num,
                               cudaMemcpyDeviceToHost,
                               train_stream_));

    CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast<char *>(h_feature_list.data()),
                               d_feature_list_ptr,
                               sizeof(uint64_t) * fea_num,
                               cudaMemcpyDeviceToHost,
                               train_stream_));
    CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast<char *>(h_slot_list.data()),
                               d_slot_list_ptr,
                               sizeof(uint8_t) * fea_num,
                               cudaMemcpyDeviceToHost,
                               train_stream_));

    CUDA_CHECK(cudaStreamSynchronize(train_stream_));
    for (size_t i = 0; i < key_num; i++) {
      std::stringstream ss;
      ss << "node_id: " << node_list[i]
         << " fea_num: " << h_feature_size_list[i] << " offset "
         << h_feature_size_list_prefixsum[i] << " slot: ";
      for (uint32_t j = 0; j < h_feature_size_list[i]; j++) {
        ss << int(h_slot_list[h_feature_size_list_prefixsum[i] + j]) << " : "
           << h_feature_list[h_feature_size_list_prefixsum[i] + j] << "  ";
      }
      VLOG(0) << ss.str();
    }
    // Guard: indexing [key_num - 1] is out of range when key_num == 0.
    if (key_num > 0) {
      VLOG(0) << "all fea_num is " << fea_num << " calc fea_num is "
              << h_feature_size_list[key_num - 1] +
                     h_feature_size_list_prefixsum[key_num - 1];
    }
    for (int i = 0; i < slot_num_; ++i) {
      std::vector<int64_t> h_slot_lod_tensor(key_num + 1, 0);
      CUDA_CHECK(
          cudaMemcpyAsync(reinterpret_cast<char *>(h_slot_lod_tensor.data()),
                          slot_lod_tensor_ptrs[i],
                          sizeof(int64_t) * (key_num + 1),
                          cudaMemcpyDeviceToHost,
                          train_stream_));
      CUDA_CHECK(cudaStreamSynchronize(train_stream_));
      std::stringstream ss_lod;
      std::stringstream ss_tensor;
      ss_lod << " slot " << i << " lod is [";
      for (size_t j = 0; j < key_num + 1; j++) {
        ss_lod << h_slot_lod_tensor[j] << ",";
      }
      ss_lod << "]";
      std::vector<int64_t> h_slot_tensor(h_slot_lod_tensor[key_num], 0);
      CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast<char *>(h_slot_tensor.data()),
                                 slot_tensor_ptrs[i],
                                 sizeof(int64_t) * h_slot_lod_tensor[key_num],
                                 cudaMemcpyDeviceToHost,
                                 train_stream_));
      CUDA_CHECK(cudaStreamSynchronize(train_stream_));

      ss_tensor << " tensor is [ ";
      for (int64_t j = 0; j < h_slot_lod_tensor[key_num]; j++) {
        ss_tensor << h_slot_tensor[j] << ",";
      }
      ss_tensor << "]";
      VLOG(0) << ss_lod.str() << "  " << ss_tensor.str();
    }
  }

  return 0;
}

D
danleifeng 已提交
1454 1455 1456 1457 1458 1459 1460
// Looks up the features of `key_num` node ids stored in `d_walk` by
// delegating to GraphGpuWrapper::get_feature_of_nodes on this generator's
// GPU; the result is written through `d_feature`. Returns the wrapper's
// return code unchanged.
int GraphDataGenerator::FillFeatureBuf(uint64_t *d_walk,
                                       uint64_t *d_feature,
                                       size_t key_num) {
  platform::CUDADeviceGuard guard(gpuid_);
  int *slot_feature_num_map =
      reinterpret_cast<int *>(d_slot_feature_num_map_->ptr());
  return GraphGpuWrapper::GetInstance()->get_feature_of_nodes(
      gpuid_,
      d_walk,
      d_feature,
      key_num,
      slot_num_,
      slot_feature_num_map,
      fea_num_per_node_);
}

// Allocation-based convenience overload: fetches features for the first
// `buf_size_` node ids held in `d_walk` into `d_feature`. It forwards to the
// raw-pointer overload, which also selects this generator's GPU.
int GraphDataGenerator::FillFeatureBuf(
    std::shared_ptr<phi::Allocation> d_walk,
    std::shared_ptr<phi::Allocation> d_feature) {
  uint64_t *walk_ptr = reinterpret_cast<uint64_t *>(d_walk->ptr());
  uint64_t *feature_ptr = reinterpret_cast<uint64_t *>(d_feature->ptr());
  return FillFeatureBuf(walk_ptr, feature_ptr, buf_size_);
}

L
lxsbupt 已提交
1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861
// Inserts `len` device keys from `d_keys` into table_, using the device
// counter in `d_uniq_node_num` as the table's unique-node count.
//
// deepwalk mode (training, !sage_mode_): returns 1 without inserting when the
// table would reach its capacity, otherwise inserts and returns 0.
// sage mode: when capacity would be reached, the current unique nodes are
// first copied out (CopyUniqueNodes), the table and counter are cleared, and
// the keys are then inserted; the return value is always 0 and carries no
// meaning for callers.
//
// Used only when the storage mode is NOT WHOLE_HBM.
int GraphDataGenerator::InsertTable(
    const uint64_t *d_keys,
    uint64_t len,
    std::shared_ptr<phi::Allocation> d_uniq_node_num) {
  uint64_t h_uniq_node_num = 0;
  uint64_t *d_uniq_node_num_ptr =
      reinterpret_cast<uint64_t *>(d_uniq_node_num->ptr());
  CUDA_CHECK(cudaMemcpyAsync(&h_uniq_node_num,
                             d_uniq_node_num_ptr,
                             sizeof(uint64_t),
                             cudaMemcpyDeviceToHost,
                             sample_stream_));
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  bool need_flush = false;
  if (gpu_graph_training_) {
    VLOG(2) << "table capacity: " << train_table_cap_ << ", " << h_uniq_node_num
            << " used";
    if (h_uniq_node_num + len >= train_table_cap_) {
      if (!sage_mode_) {
        // deepwalk: report failure and leave the table untouched.
        return 1;
      }
      need_flush = true;
    }
  } else {
    // Inference path; used only for sage_mode.
    need_flush = h_uniq_node_num + len >= infer_table_cap_;
  }

  if (need_flush) {
    // Save the unique nodes collected so far before the table is reset.
    uint64_t copy_len = CopyUniqueNodes();
    copy_unique_len_ += copy_len;
    table_->clear(sample_stream_);
    CUDA_CHECK(cudaMemsetAsync(
        d_uniq_node_num_ptr, 0, sizeof(uint64_t), sample_stream_));
  }

  table_->insert(d_keys, len, d_uniq_node_num_ptr, sample_stream_);
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
  return 0;
}

// Samples up to `sample_size` neighbors for each of the `len` nodes in
// `uniq_nodes` across all edge types, then compacts the samples with
// FillActualNeighbors.
//
// Outputs:
//   edges_split_num  resized to edge_to_id_len_; entry i is the cumulative
//                    sample count over the first (i + 1) * len count entries
//                    (i.e. through edge type i, assuming the counts are laid
//                    out edge-type-major — TODO confirm).
//   *neighbor_len    total number of sampled neighbors.
// Returns {final_sample_val, final_sample_val_dst} as filled by
// FillActualNeighbors (presumably neighbor ids and matching destination
// indices — verify against the kernel).
std::vector<std::shared_ptr<phi::Allocation>>
GraphDataGenerator::SampleNeighbors(int64_t *uniq_nodes,
                                    int len,
                                    int sample_size,
                                    std::vector<int> &edges_split_num,
                                    int64_t *neighbor_len) {
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  auto sample_res = gpu_graph_ptr->graph_neighbor_sample_all_edge_type(
      gpuid_,
      edge_to_id_len_,
      reinterpret_cast<uint64_t *>(uniq_nodes),
      sample_size,
      len,
      edge_type_graph_);

  int *all_sample_count_ptr =
      reinterpret_cast<int *>(sample_res.actual_sample_size_mem->ptr());
  // One sample count per (edge type, node) pair.
  const int count_num = len * edge_to_id_len_;

  // cumsum[0] = 0 and cumsum[k + 1] = sum of the first k + 1 counts.
  auto cumsum_actual_sample_size = memory::Alloc(
      place_,
      (count_num + 1) * sizeof(int),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  int *cumsum_actual_sample_size_ptr =
      reinterpret_cast<int *>(cumsum_actual_sample_size->ptr());
  CUDA_CHECK(cudaMemsetAsync(cumsum_actual_sample_size_ptr,
                             0,
                             (count_num + 1) * sizeof(int),
                             sample_stream_));

  size_t temp_storage_bytes = 0;
  CUDA_CHECK(cub::DeviceScan::InclusiveSum(NULL,
                                           temp_storage_bytes,
                                           all_sample_count_ptr,
                                           cumsum_actual_sample_size_ptr + 1,
                                           count_num,
                                           sample_stream_));
  auto d_temp_storage = memory::Alloc(
      place_,
      temp_storage_bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  CUDA_CHECK(cub::DeviceScan::InclusiveSum(d_temp_storage->ptr(),
                                           temp_storage_bytes,
                                           all_sample_count_ptr,
                                           cumsum_actual_sample_size_ptr + 1,
                                           count_num,
                                           sample_stream_));

  // The D2H copies below are stream-ordered after the scan, so a single
  // synchronize after them is enough.
  edges_split_num.resize(edge_to_id_len_);
  for (int i = 0; i < edge_to_id_len_; i++) {
    CUDA_CHECK(cudaMemcpyAsync(edges_split_num.data() + i,
                               cumsum_actual_sample_size_ptr + (i + 1) * len,
                               sizeof(int),
                               cudaMemcpyDeviceToHost,
                               sample_stream_));
  }
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  // Guard against edge_to_id_len_ == 0, where there is no last entry.
  int all_sample_size = edges_split_num.empty() ? 0 : edges_split_num.back();
  auto final_sample_val = memory::AllocShared(
      place_,
      all_sample_size * sizeof(int64_t),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  auto final_sample_val_dst = memory::AllocShared(
      place_,
      all_sample_size * sizeof(int64_t),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  int64_t *final_sample_val_ptr =
      reinterpret_cast<int64_t *>(final_sample_val->ptr());
  int64_t *final_sample_val_dst_ptr =
      reinterpret_cast<int64_t *>(final_sample_val_dst->ptr());
  int64_t *all_sample_val_ptr =
      reinterpret_cast<int64_t *>(sample_res.val_mem->ptr());
  // Skip the launch for an empty problem: it would request an empty grid.
  if (count_num > 0) {
    FillActualNeighbors<<<GET_BLOCKS(count_num),
                          CUDA_NUM_THREADS,
                          0,
                          sample_stream_>>>(all_sample_val_ptr,
                                            final_sample_val_ptr,
                                            final_sample_val_dst_ptr,
                                            all_sample_count_ptr,
                                            cumsum_actual_sample_size_ptr,
                                            sample_size,
                                            count_num,
                                            len);
    CUDA_CHECK(cudaGetLastError());
  }
  *neighbor_len = all_sample_size;
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  return {final_sample_val, final_sample_val_dst};
}

// Deduplicates the `num_input` ids in `input` using the caller-provided hash
// table (keys / values / key_index, each with `len_hashtable` slots).
// Returns a new device allocation holding the unique ids; their count is
// written to *final_nodes_len. Synchronizes sample_stream_ before returning.
std::shared_ptr<phi::Allocation> GraphDataGenerator::FillReindexHashTable(
    int64_t *input,
    int num_input,
    int64_t len_hashtable,
    int64_t *keys,
    int *values,
    int *key_index,
    int *final_nodes_len) {
  // Kernels are launched only for non-empty input: GET_BLOCKS(0) presumably
  // yields a zero-block grid, which CUDA rejects as an invalid configuration.
  const bool has_input = num_input > 0;
  if (has_input) {
    phi::BuildHashTable<int64_t>
        <<<GET_BLOCKS(num_input), CUDA_NUM_THREADS, 0, sample_stream_>>>(
            input, num_input, len_hashtable, keys, key_index);
    CUDA_CHECK(cudaGetLastError());
  }

  // Get item index count.
  auto item_count = memory::Alloc(
      place_,
      (num_input + 1) * sizeof(int),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  int *item_count_ptr = reinterpret_cast<int *>(item_count->ptr());
  CUDA_CHECK(cudaMemsetAsync(
      item_count_ptr, 0, sizeof(int) * (num_input + 1), sample_stream_));
  if (has_input) {
    phi::GetItemIndexCount<int64_t>
        <<<GET_BLOCKS(num_input), CUDA_NUM_THREADS, 0, sample_stream_>>>(
            input, item_count_ptr, num_input, len_hashtable, keys, key_index);
    CUDA_CHECK(cudaGetLastError());
  }

  // In-place exclusive scan; entry num_input then holds the total count.
  size_t temp_storage_bytes = 0;
  CUDA_CHECK(cub::DeviceScan::ExclusiveSum(NULL,
                                           temp_storage_bytes,
                                           item_count_ptr,
                                           item_count_ptr,
                                           num_input + 1,
                                           sample_stream_));
  auto d_temp_storage = memory::Alloc(
      place_,
      temp_storage_bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  CUDA_CHECK(cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(),
                                           temp_storage_bytes,
                                           item_count_ptr,
                                           item_count_ptr,
                                           num_input + 1,
                                           sample_stream_));

  int total_unique_items = 0;
  CUDA_CHECK(cudaMemcpyAsync(&total_unique_items,
                             item_count_ptr + num_input,
                             sizeof(int),
                             cudaMemcpyDeviceToHost,
                             sample_stream_));
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  auto unique_items = memory::AllocShared(
      place_,
      total_unique_items * sizeof(int64_t),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  int64_t *unique_items_ptr = reinterpret_cast<int64_t *>(unique_items->ptr());
  *final_nodes_len = total_unique_items;

  // Get unique items
  if (has_input) {
    phi::FillUniqueItems<int64_t>
        <<<GET_BLOCKS(num_input), CUDA_NUM_THREADS, 0, sample_stream_>>>(
            input,
            num_input,
            len_hashtable,
            unique_items_ptr,
            item_count_ptr,
            keys,
            values,
            key_index);
    CUDA_CHECK(cudaGetLastError());
  }
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
  return unique_items;
}

// Concatenates `center_nodes` (node_len ids) and `reindex_src_data`
// (neighbor_len ids) and deduplicates them through the generator's reindex
// hash table (FillReindexHashTable). ReindexSrcOutput is then run over
// reindex_src_data in place (presumably mapping each id to its position in
// the deduplicated set — TODO confirm). Returns the deduplicated node set;
// its size is written to *final_nodes_len.
std::shared_ptr<phi::Allocation> GraphDataGenerator::GetReindexResult(
    int64_t *reindex_src_data,
    int64_t *center_nodes,
    int *final_nodes_len,
    int node_len,
    int64_t neighbor_len) {
  // Reset reindex table
  int64_t *d_reindex_table_key_ptr =
      reinterpret_cast<int64_t *>(d_reindex_table_key_->ptr());
  int *d_reindex_table_value_ptr =
      reinterpret_cast<int *>(d_reindex_table_value_->ptr());
  int *d_reindex_table_index_ptr =
      reinterpret_cast<int *>(d_reindex_table_index_->ptr());

  // Fill table with -1 (all-ones bytes, which is -1 for int64_t and int).
  CUDA_CHECK(cudaMemsetAsync(d_reindex_table_key_ptr,
                             -1,
                             reindex_table_size_ * sizeof(int64_t),
                             sample_stream_));
  CUDA_CHECK(cudaMemsetAsync(d_reindex_table_value_ptr,
                             -1,
                             reindex_table_size_ * sizeof(int),
                             sample_stream_));
  CUDA_CHECK(cudaMemsetAsync(d_reindex_table_index_ptr,
                             -1,
                             reindex_table_size_ * sizeof(int),
                             sample_stream_));

  auto all_nodes = memory::AllocShared(
      place_,
      (node_len + neighbor_len) * sizeof(int64_t),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  int64_t *all_nodes_data = reinterpret_cast<int64_t *>(all_nodes->ptr());

  CUDA_CHECK(cudaMemcpyAsync(all_nodes_data,
                             center_nodes,
                             sizeof(int64_t) * node_len,
                             cudaMemcpyDeviceToDevice,
                             sample_stream_));
  CUDA_CHECK(cudaMemcpyAsync(all_nodes_data + node_len,
                             reindex_src_data,
                             sizeof(int64_t) * neighbor_len,
                             cudaMemcpyDeviceToDevice,
                             sample_stream_));

  // No host synchronize needed here: FillReindexHashTable enqueues its work
  // on the same sample_stream_, after these copies.
  auto final_nodes = FillReindexHashTable(all_nodes_data,
                                          node_len + neighbor_len,
                                          reindex_table_size_,
                                          d_reindex_table_key_ptr,
                                          d_reindex_table_value_ptr,
                                          d_reindex_table_index_ptr,
                                          final_nodes_len);

  // Skip the launch when there are no neighbors (it would be an empty grid).
  if (neighbor_len > 0) {
    phi::ReindexSrcOutput<int64_t>
        <<<GET_BLOCKS(neighbor_len), CUDA_NUM_THREADS, 0, sample_stream_>>>(
            reindex_src_data,
            neighbor_len,
            reindex_table_size_,
            d_reindex_table_key_ptr,
            d_reindex_table_value_ptr);
    CUDA_CHECK(cudaGetLastError());
  }
  return final_nodes;
}

// Builds a multi-hop sampled sub-graph for `len` node ids in `node_ids`.
//
// The input is first deduplicated (dedup_keys_and_fillidx), writing the
// inverse mapping into `inverse`. Then, for every entry of samples_, the
// current node set is expanded by SampleNeighbors and reindexed by
// GetReindexResult; the reindexed set becomes the input of the next hop.
// Per-hop edges and split counts are appended to graph_edges_vec_ and
// edges_split_num_vec_. Returns the node set of the last hop and writes its
// size to *final_len (assumes samples_ is non-empty).
std::shared_ptr<phi::Allocation> GraphDataGenerator::GenerateSampleGraph(
    uint64_t *node_ids,
    int len,
    int *final_len,
    std::shared_ptr<phi::Allocation> &inverse) {
  VLOG(2) << "Get Unique Nodes";

  auto uniq_nodes = memory::Alloc(
      place_,
      len * sizeof(uint64_t),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  int *inverse_ptr = reinterpret_cast<int *>(inverse->ptr());
  int64_t *uniq_nodes_data = reinterpret_cast<int64_t *>(uniq_nodes->ptr());
  int uniq_len = dedup_keys_and_fillidx(
      len,
      node_ids,
      reinterpret_cast<uint64_t *>(uniq_nodes_data),
      reinterpret_cast<uint64_t *>(d_sorted_keys_->ptr()),
      reinterpret_cast<uint32_t *>(inverse_ptr),
      reinterpret_cast<uint32_t *>(d_sorted_idx_->ptr()),
      reinterpret_cast<uint32_t *>(d_offset_->ptr()),
      reinterpret_cast<uint32_t *>(d_merged_cnts_->ptr()),
      sample_stream_,
      d_buf_,
      place_);
  int len_samples = samples_.size();

  VLOG(2) << "Sample Neighbors and Reindex";
  std::vector<int> edges_split_num;
  std::vector<std::shared_ptr<phi::Allocation>> final_nodes_vec;
  std::vector<std::shared_ptr<phi::Allocation>> graph_edges;
  std::vector<std::vector<int>> edges_split_num_for_graph;
  std::vector<int> final_nodes_len_vec;

  // Hop 0 expands the deduplicated input; every later hop expands the node
  // set produced by the previous hop.
  int64_t *hop_input_nodes = uniq_nodes_data;
  int hop_input_len = uniq_len;
  for (int hop = 0; hop < len_samples; ++hop) {
    edges_split_num.clear();
    int64_t neighbors_len = 0;
    auto sample_results = SampleNeighbors(hop_input_nodes,
                                          hop_input_len,
                                          samples_[hop],
                                          edges_split_num,
                                          &neighbors_len);
    std::shared_ptr<phi::Allocation> neighbors = sample_results[0];
    std::shared_ptr<phi::Allocation> reindex_dst = sample_results[1];
    edges_split_num.push_back(hop_input_len);

    int64_t *reindex_src_data = reinterpret_cast<int64_t *>(neighbors->ptr());
    int hop_output_len = 0;
    auto hop_output_nodes = GetReindexResult(reindex_src_data,
                                             hop_input_nodes,
                                             &hop_output_len,
                                             hop_input_len,
                                             neighbors_len);
    final_nodes_vec.emplace_back(hop_output_nodes);
    final_nodes_len_vec.emplace_back(hop_output_len);

    // Per-hop layout: [per-edge-type split counts..., input node count,
    //                  output node count, neighbor count]
    edges_split_num.emplace_back(hop_output_len);
    edges_split_num.emplace_back(neighbors_len);
    graph_edges.emplace_back(neighbors);
    graph_edges.emplace_back(reindex_dst);
    edges_split_num_for_graph.emplace_back(edges_split_num);

    hop_input_nodes = reinterpret_cast<int64_t *>(hop_output_nodes->ptr());
    hop_input_len = hop_output_len;
  }
  graph_edges_vec_.emplace_back(graph_edges);
  edges_split_num_vec_.emplace_back(edges_split_num_for_graph);

  *final_len = final_nodes_len_vec.back();
  return final_nodes_vec.back();
}

1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877
// Allocates room for len * edge_to_id_len_ ints and calls
// GraphGpuWrapper::get_node_degree once per edge type in the wrapper's
// edge_to_id map, passing the allocation each time (presumably each call
// fills that edge type's degrees — verify the layout in get_node_degree).
std::shared_ptr<phi::Allocation> GraphDataGenerator::GetNodeDegree(
    uint64_t *node_ids, int len) {
  auto node_degree = memory::AllocShared(
      place_,
      len * edge_to_id_len_ * sizeof(int),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  // Iterate the wrapper's map by reference instead of copying it per call.
  for (const auto &iter : gpu_graph_ptr->edge_to_id) {
    int edge_idx = iter.second;
    gpu_graph_ptr->get_node_degree(
        gpuid_, edge_idx, node_ids, len, node_degree);
  }
  return node_degree;
}

L
lxsbupt 已提交
1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925
// Copies all keys currently stored in table_ to host_vec_, starting at offset
// copy_unique_len_ (host_vec_ is resized to copy_unique_len_ + count).
// Returns the number of keys copied. This is a no-op returning 0 when the
// storage mode is WHOLE_HBM. The caller is responsible for advancing
// copy_unique_len_.
uint64_t GraphDataGenerator::CopyUniqueNodes() {
  if (FLAGS_gpugraph_storage_mode == GpuGraphStorageMode::WHOLE_HBM) {
    return 0;
  }

  uint64_t h_uniq_node_num = 0;
  uint64_t *d_uniq_node_num =
      reinterpret_cast<uint64_t *>(d_uniq_node_num_->ptr());
  CUDA_CHECK(cudaMemcpyAsync(&h_uniq_node_num,
                             d_uniq_node_num,
                             sizeof(uint64_t),
                             cudaMemcpyDeviceToHost,
                             sample_stream_));
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  auto d_uniq_node = memory::AllocShared(
      place_,
      h_uniq_node_num * sizeof(uint64_t),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  uint64_t *d_uniq_node_ptr = reinterpret_cast<uint64_t *>(d_uniq_node->ptr());

  // Write cursor used by table_->get_keys; starts at 0.
  auto d_node_cursor = memory::AllocShared(
      place_,
      sizeof(uint64_t),
      phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  uint64_t *d_node_cursor_ptr =
      reinterpret_cast<uint64_t *>(d_node_cursor->ptr());
  CUDA_CHECK(
      cudaMemsetAsync(d_node_cursor_ptr, 0, sizeof(uint64_t), sample_stream_));
  table_->get_keys(d_uniq_node_ptr, d_node_cursor_ptr, sample_stream_);
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  host_vec_.resize(h_uniq_node_num + copy_unique_len_);
  CUDA_CHECK(cudaMemcpyAsync(host_vec_.data() + copy_unique_len_,
                             d_uniq_node_ptr,
                             sizeof(uint64_t) * h_uniq_node_num,
                             cudaMemcpyDeviceToHost,
                             sample_stream_));
  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
  return h_uniq_node_num;
}

// Produces this pass's samples on gpuid_.
//
// Training: fills the walk buffer (FillWalkBufMultiPath when
// FLAGS_graph_metapath_split_opt, else FillWalkBuf). In sage mode it then
// repeatedly takes up to batch_size_ instance pairs from d_ins_buf_ and
// builds one sampled sub-graph per batch.
// Inference: FillInferBuf selects the next key range; in sage mode each
// batch of keys is duplicated (CopyDuplicateKeys) and sampled the same way.
// Per-batch results are appended to final_sage_nodes_vec_, inverse_vec_,
// uniq_instance_vec_, total_instance_vec_ (and node_degree_vec_ when
// get_degree_), and sage_batch_num_ counts the batches.
void GraphDataGenerator::DoWalkandSage() {
  int device_id = place_.GetDeviceId();
  debug_gpu_memory_info(device_id, "DoWalkandSage start");
  platform::CUDADeviceGuard guard(gpuid_);

  // Samples one batch of `total_instance` device node ids and records every
  // per-batch result on the generator.
  auto process_sage_batch = [this](uint64_t *node_ids, int total_instance) {
    int uniq_instance = 0;
    auto inverse = memory::AllocShared(
        place_,
        total_instance * sizeof(int),
        phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
    auto final_sage_nodes =
        GenerateSampleGraph(node_ids, total_instance, &uniq_instance, inverse);
    uint64_t *final_sage_nodes_ptr =
        reinterpret_cast<uint64_t *>(final_sage_nodes->ptr());
    if (get_degree_) {
      auto node_degrees = GetNodeDegree(final_sage_nodes_ptr, uniq_instance);
      node_degree_vec_.emplace_back(node_degrees);
    }
    CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
    if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
      InsertTable(final_sage_nodes_ptr, uniq_instance, d_uniq_node_num_);
    }
    final_sage_nodes_vec_.emplace_back(final_sage_nodes);
    inverse_vec_.emplace_back(inverse);
    uniq_instance_vec_.emplace_back(uniq_instance);
    total_instance_vec_.emplace_back(total_instance);
    sage_batch_num_ += 1;
  };

  if (gpu_graph_training_) {
    // train
    bool train_flag;
    if (FLAGS_graph_metapath_split_opt) {
      train_flag = FillWalkBufMultiPath();
    } else {
      train_flag = FillWalkBuf();
    }

    if (sage_mode_) {
      sage_batch_num_ = 0;
      if (train_flag) {
        bool ins_pair_flag = true;
        while (ins_pair_flag) {
          // Top up the instance-pair buffer to at least one batch, if possible.
          while (ins_buf_pair_len_ < batch_size_) {
            int res = FillInsBuf(sample_stream_);
            if (res == -1) {
              if (ins_buf_pair_len_ == 0) {
                ins_pair_flag = false;
              }
              break;
            }
          }

          if (!ins_pair_flag) {
            break;
          }

          // Each pair contributes two node ids; batches are taken from the
          // tail of the buffer.
          int total_instance =
              ins_buf_pair_len_ < batch_size_ ? ins_buf_pair_len_ : batch_size_;
          total_instance *= 2;

          uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
          uint64_t *ins_cursor =
              ins_buf + ins_buf_pair_len_ * 2 - total_instance;
          process_sage_batch(ins_cursor, total_instance);
          ins_buf_pair_len_ -= total_instance / 2;
        }
        CopyUniqueNodes();
        VLOG(1) << "train sage_batch_num: " << sage_batch_num_;
      }
    }
  } else {
    // infer
    bool infer_flag = FillInferBuf();
    if (sage_mode_) {
      sage_batch_num_ = 0;
      if (infer_flag) {
        // Number of node ids in the next batch (two per remaining key,
        // capped at batch_size_ keys); 0 once the key range is exhausted.
        auto next_total_instance = [this]() {
          int total_instance =
              (infer_node_start_ + batch_size_ <= infer_node_end_)
                  ? batch_size_
                  : infer_node_end_ - infer_node_start_;
          total_instance *= 2;
          return total_instance;
        };

        int total_instance = next_total_instance();
        while (total_instance != 0) {
          uint64_t *d_type_keys =
              reinterpret_cast<uint64_t *>(d_device_keys_[cursor_]->ptr());
          d_type_keys += infer_node_start_;
          infer_node_start_ += total_instance / 2;
          auto node_buf = memory::AllocShared(
              place_,
              total_instance * sizeof(uint64_t),
              phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
          int64_t *node_buf_ptr = reinterpret_cast<int64_t *>(node_buf->ptr());
          CopyDuplicateKeys<<<GET_BLOCKS(total_instance / 2),
                              CUDA_NUM_THREADS,
                              0,
                              sample_stream_>>>(
              node_buf_ptr, d_type_keys, total_instance / 2);
          CUDA_CHECK(cudaGetLastError());
          process_sage_batch(reinterpret_cast<uint64_t *>(node_buf->ptr()),
                             total_instance);
          total_instance = next_total_instance();
        }

        CopyUniqueNodes();
        VLOG(1) << "infer sage_batch_num: " << sage_batch_num_;
      }
    }
  }
  debug_gpu_memory_info(device_id, "DoWalkandSage end");
}

// Releases the generator's device buffers and deletes table_. The
// sage-specific buffers are released only in sage mode.
void GraphDataGenerator::clear_gpu_mem() {
  d_len_per_row_.reset();
  d_sample_keys_.reset();
  d_prefix_sum_.reset();
  for (auto &sampleidx2rows : d_sampleidx2rows_) {
    sampleidx2rows.reset();
  }
  delete table_;
  // Null the pointer so a repeated call does not delete the table twice.
  table_ = nullptr;
  if (sage_mode_) {
    d_reindex_table_key_.reset();
    d_reindex_table_value_.reset();
    d_reindex_table_index_.reset();
    d_sorted_keys_.reset();
    d_sorted_idx_.reset();
    d_offset_.reset();
    d_merged_cnts_.reset();
  }
}

// Selects the next inference key range for this thread.
//
// Advances the shared infer cursor past an exhausted node type and, when
// infer_node_type_index_set_ is non-empty, past node types not in that set.
// It then claims up to infer_table_cap_ keys of the current type, updating
// infer_node_start_, infer_node_end_, cursor_, total_row_ and the shared
// global start offset. Outside sage mode, the claimed keys are also copied
// to host_vec_.
// Returns 0 when the cursor runs past the last node type while advancing;
// otherwise returns 1 (possibly with total_row_ == 0).
int GraphDataGenerator::FillInferBuf() {
  platform::CUDADeviceGuard guard(gpuid_);
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  auto &global_infer_node_type_start =
      gpu_graph_ptr->global_infer_node_type_start_[gpuid_];
  auto &infer_cursor = gpu_graph_ptr->infer_cursor_[thread_id_];
  total_row_ = 0;
  if (infer_cursor < h_device_keys_len_.size()) {
    // Move on once every key of the current node type has been claimed.
    if (global_infer_node_type_start[infer_cursor] >=
        h_device_keys_len_[infer_cursor]) {
      infer_cursor++;
      if (infer_cursor >= h_device_keys_len_.size()) {
        return 0;
      }
    }
    // Optional node-type filter: skip cursors not listed in the set.
    if (!infer_node_type_index_set_.empty()) {
      while (infer_cursor < h_device_keys_len_.size()) {
        if (infer_node_type_index_set_.find(infer_cursor) ==
            infer_node_type_index_set_.end()) {
          VLOG(2) << "Skip cursor[" << infer_cursor << "]";
          infer_cursor++;
          continue;
        } else {
          VLOG(2) << "Not skip cursor[" << infer_cursor << "]";
          break;
        }
      }
      if (infer_cursor >= h_device_keys_len_.size()) {
        return 0;
      }
    }

    // Claim at most infer_table_cap_ keys of the current node type.
    size_t device_key_size = h_device_keys_len_[infer_cursor];
    total_row_ =
        (global_infer_node_type_start[infer_cursor] + infer_table_cap_ <=
         device_key_size)
            ? infer_table_cap_
            : device_key_size - global_infer_node_type_start[infer_cursor];

    if (!sage_mode_) {
      uint64_t *d_type_keys =
          reinterpret_cast<uint64_t *>(d_device_keys_[infer_cursor]->ptr());
      host_vec_.resize(total_row_);
      CUDA_CHECK(cudaMemcpyAsync(
          host_vec_.data(),
          d_type_keys + global_infer_node_type_start[infer_cursor],
          sizeof(uint64_t) * total_row_,
          cudaMemcpyDeviceToHost,
          sample_stream_));
      CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
    }
    VLOG(1) << "cursor: " << infer_cursor
            << " start: " << global_infer_node_type_start[infer_cursor]
            << " num: " << total_row_;
    infer_node_start_ = global_infer_node_type_start[infer_cursor];
    global_infer_node_type_start[infer_cursor] += total_row_;
    infer_node_end_ = global_infer_node_type_start[infer_cursor];
    cursor_ = infer_cursor;
  }
  return 1;
}

// Resets this GPU's sampling state in the graph wrapper: clears its
// finish_node_type_ entry and zeroes every value of its node_type_start_ map.
void GraphDataGenerator::ClearSampleState() {
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  gpu_graph_ptr->finish_node_type_[gpuid_].clear();
  for (auto &type_and_start : gpu_graph_ptr->node_type_start_[gpuid_]) {
    type_and_start.second = 0;
  }
}

// Fills the walk buffer d_walk_ with random walks.
//
// Start keys come from the per-node-type key lists d_device_keys_. Start types
// are visited round-robin through the per-thread cursor
// (cursor_[thread_id_] % first_node_type_.size()), and walks from the chosen
// type follow meta_path_ at the same index. Per-type progress is kept in
// node_type_start_[gpuid_], so consecutive calls resume where the previous one
// stopped.
//
// One outer iteration ("round") takes up to once_sample_startid_len_ start
// keys. It samples walk_degree_ neighbours per key for step 1, then one
// neighbour per frontier key for each further step below walk_len_. It
// reserves walk_len_ buffer slots per produced row.
//
// Outside WHOLE_HBM storage mode (and when not in sage mode), every sampled
// key is inserted into table_. If the table fills up, the round is not
// counted, global progress is not advanced, and filling stops.
//
// Afterwards d_random_row_ holds a shuffled permutation of [0, total_row_).
// Sets epoch_finish_ once every start type has been exhausted.
// Returns 1 if at least one row was produced, 0 otherwise.
int GraphDataGenerator::FillWalkBuf() {
  platform::CUDADeviceGuard guard(gpuid_);
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();

  uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
  int *len_per_row = reinterpret_cast<int *>(d_len_per_row_->ptr());
  CUDA_CHECK(cudaMemsetAsync(
      walk, 0, buf_size_ * sizeof(uint64_t), sample_stream_));
  uint8_t *walk_ntype = nullptr;
  if (excluded_train_pair_len_ > 0) {
    walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_->ptr());
    CUDA_CHECK(cudaMemsetAsync(
        walk_ntype, 0, buf_size_ * sizeof(uint8_t), sample_stream_));
  }

  // Debug only: dumps the whole walk buffer. The copy is issued on
  // sample_stream_ and synchronized. A plain cudaMemcpy runs on the legacy
  // default stream, which is not ordered with respect to work queued on a
  // non-blocking stream.
  std::vector<uint64_t> h_walk(debug_mode_ ? static_cast<size_t>(buf_size_)
                                           : 0);
  auto dump_walk = [&]() {
    CUDA_CHECK(cudaMemcpyAsync(h_walk.data(),
                               walk,
                               buf_size_ * sizeof(uint64_t),
                               cudaMemcpyDeviceToHost,
                               sample_stream_));
    CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
    for (int xx = 0; xx < buf_size_; xx++) {
      VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx];
    }
  };

  int sample_times = 0;
  int i = 0;  // Next free offset in the walk buffer.
  int total_samples = 0;
  total_row_ = 0;

  // Global sampling state, shared across calls.
  auto &first_node_type = gpu_graph_ptr->first_node_type_;
  auto &meta_path = gpu_graph_ptr->meta_path_;
  auto &node_type_start = gpu_graph_ptr->node_type_start_[gpuid_];
  auto &finish_node_type = gpu_graph_ptr->finish_node_type_[gpuid_];
  auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index();
  auto &cursor = gpu_graph_ptr->cursor_[thread_id_];
  size_t node_type_len = first_node_type.size();

  // A new round only starts while there is room for a full round
  // (walk_degree_ * once_sample_startid_len_ rows of walk_len_ slots).
  int remain_size =
      buf_size_ - walk_degree_ * once_sample_startid_len_ * walk_len_;

  while (i <= remain_size) {
    int cur_node_idx = cursor % node_type_len;
    int node_type = first_node_type[cur_node_idx];
    auto &path = meta_path[cur_node_idx];
    size_t start = node_type_start[node_type];
    VLOG(2) << "cur_node_idx = " << cur_node_idx
            << " meta_path.size = " << meta_path.size();
    VLOG(2) << "choose start type: " << node_type;
    int type_index = type_to_index[node_type];
    size_t device_key_size = h_device_keys_len_[type_index];
    VLOG(2) << "type: " << node_type << " size: " << device_key_size
            << " start: " << start;
    uint64_t *d_type_keys =
        reinterpret_cast<uint64_t *>(d_device_keys_[type_index]->ptr());
    // Start keys for this round: a full chunk, or whatever is left of this
    // type's key list.
    int tmp_len = start + once_sample_startid_len_ > device_key_size
                      ? device_key_size - start
                      : once_sample_startid_len_;
    bool update = true;
    if (tmp_len == 0) {
      // This start type is exhausted; the epoch ends once every type is.
      finish_node_type.insert(node_type);
      if (finish_node_type.size() == node_type_start.size()) {
        cursor = 0;
        epoch_finish_ = true;
        break;
      }
      cursor += 1;
      continue;
    }

    VLOG(2) << "gpuid = " << gpuid_ << " path[0] = " << path[0];
    uint64_t *cur_walk = walk + i;
    uint8_t *cur_walk_ntype =
        excluded_train_pair_len_ > 0 ? walk_ntype + i : nullptr;

    // Step 1: walk_degree_ neighbours per start key along edge type path[0].
    NeighborSampleQuery q;
    q.initialize(gpuid_,
                 path[0],
                 (uint64_t)(d_type_keys + start),
                 walk_degree_,
                 tmp_len);
    auto sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false, true);

    int step = 1;
    VLOG(2) << "sample edge type: " << path[0] << " step: " << 1;
    jump_rows_ = sample_res.total_sample_size;
    total_samples += sample_res.total_sample_size;
    VLOG(2) << "i = " << i << " start = " << start << " tmp_len = " << tmp_len
            << " cursor = " << node_type << " cur_node_idx = " << cur_node_idx
            << " jump row: " << jump_rows_;
    VLOG(2) << "jump_row: " << jump_rows_;
    if (jump_rows_ == 0) {
      // No start key produced a sample: consume these keys and move on.
      node_type_start[node_type] = tmp_len + start;
      cursor += 1;
      continue;
    }

    if (!sage_mode_ &&
        FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
      // Table full: stop without touching the global sampling state.
      if (InsertTable(d_type_keys + start, tmp_len, d_uniq_node_num_) != 0) {
        VLOG(2) << "in step 0, insert key stage, table is full";
        break;
      }
      if (InsertTable(sample_res.actual_val,
                      sample_res.total_sample_size,
                      d_uniq_node_num_) != 0) {
        VLOG(2) << "in step 0, insert sample res stage, table is full";
        break;
      }
    }

    FillOneStep(d_type_keys + start,
                path[0],
                cur_walk,
                cur_walk_ntype,
                tmp_len,
                sample_res,
                walk_degree_,
                step,
                len_per_row);
    if (debug_mode_) {
      dump_walk();
    }
    VLOG(2) << "sample, step=" << step << " sample_keys=" << tmp_len
            << " sample_res_len=" << sample_res.total_sample_size;

    // Steps 2 .. walk_len_-1: one neighbour per frontier key, cycling through
    // the edge types of the metapath.
    step++;
    size_t path_len = path.size();
    for (; step < walk_len_; step++) {
      if (sample_res.total_sample_size == 0) {
        VLOG(2) << "sample finish, step=" << step;
        break;
      }
      // Keep a handle to the previous step's sampled keys (the query input)
      // while sample_res is overwritten below.
      auto sample_key_mem = sample_res.actual_val_mem;
      uint64_t *sample_keys_ptr =
          reinterpret_cast<uint64_t *>(sample_key_mem->ptr());
      int edge_type_id = path[(step - 1) % path_len];
      VLOG(2) << "sample edge type: " << edge_type_id << " step: " << step;
      q.initialize(gpuid_,
                   edge_type_id,
                   (uint64_t)sample_keys_ptr,
                   1,
                   sample_res.total_sample_size);
      int sample_key_len = sample_res.total_sample_size;
      sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false, true);
      total_samples += sample_res.total_sample_size;
      if (!sage_mode_ &&
          FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
        if (InsertTable(sample_res.actual_val,
                        sample_res.total_sample_size,
                        d_uniq_node_num_) != 0) {
          VLOG(2) << "in step: " << step << ", table is full";
          update = false;
          break;
        }
      }
      FillOneStep(d_type_keys + start,
                  edge_type_id,
                  cur_walk,
                  cur_walk_ntype,
                  sample_key_len,
                  sample_res,
                  1,
                  step,
                  len_per_row);
      if (debug_mode_) {
        dump_walk();
      }
      VLOG(2) << "sample, step=" << step << " sample_keys=" << sample_key_len
              << " sample_res_len=" << sample_res.total_sample_size;
    }

    if (!update) {
      // Table full mid-walk: drop this round without advancing global state.
      VLOG(2) << "table is full, not update stat!";
      break;
    }
    // Commit the round to the global sampling state.
    node_type_start[node_type] = tmp_len + start;
    i += jump_rows_ * walk_len_;
    total_row_ += jump_rows_;
    cursor += 1;
    sample_times++;
  }

  buf_state_.Reset(total_row_);
  int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());

  // Random row order for this buffer; the engine's next output seeds the
  // following call.
  thrust::random::default_random_engine engine(shuffle_seed_);
  const auto &exec_policy = thrust::cuda::par.on(sample_stream_);
  thrust::counting_iterator<int> cnt_iter(0);
  thrust::shuffle_copy(exec_policy,
                       cnt_iter,
                       cnt_iter + total_row_,
                       thrust::device_pointer_cast(d_random_row),
                       engine);

  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
  shuffle_seed_ = engine();

  if (debug_mode_) {
    std::vector<int> h_random_row(static_cast<size_t>(total_row_));
    CUDA_CHECK(cudaMemcpy(h_random_row.data(),
                          d_random_row,
                          total_row_ * sizeof(int),
                          cudaMemcpyDeviceToHost));
    for (int xx = 0; xx < total_row_; xx++) {
      VLOG(2) << "h_random_row[" << xx << "]: " << h_random_row[xx];
    }
  }

  if (!sage_mode_) {
    uint64_t h_uniq_node_num = CopyUniqueNodes();
    VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_
            << ", d_walk_offset:" << i << ", total_rows:" << total_row_
            << ", h_uniq_node_num:" << h_uniq_node_num
            << ", total_samples:" << total_samples;
  } else {
    VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_
            << ", d_walk_offset:" << i << ", total_rows:" << total_row_
            << ", total_samples:" << total_samples;
  }
  return total_row_ != 0;
}

// Metapath-split variant of FillWalkBuf.
//
// Start keys come from the single key list d_train_metapath_keys_. AllocResource
// sets this list up only when training with FLAGS_graph_metapath_split_opt.
// Every walk follows the parsed current metapath (cur_parse_metapath_).
// Progress through the key list is kept in cur_metapath_start_[gpuid_], so
// consecutive calls resume where the previous one stopped.
//
// Filling stops when the key list is exhausted, the buffer has no room for
// another full round, or table_ is full. A table-full round is not counted
// and does not advance global progress.
//
// Afterwards d_random_row_ holds a shuffled permutation of [0, total_row_).
// Returns 1 if at least one row was produced, 0 otherwise.
int GraphDataGenerator::FillWalkBufMultiPath() {
  platform::CUDADeviceGuard guard(gpuid_);
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();

  uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
  int *len_per_row = reinterpret_cast<int *>(d_len_per_row_->ptr());
  CUDA_CHECK(cudaMemsetAsync(
      walk, 0, buf_size_ * sizeof(uint64_t), sample_stream_));
  uint8_t *walk_ntype = nullptr;
  if (excluded_train_pair_len_ > 0) {
    walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_->ptr());
    // Clear stale node types from the previous fill, as FillWalkBuf does.
    CUDA_CHECK(cudaMemsetAsync(
        walk_ntype, 0, buf_size_ * sizeof(uint8_t), sample_stream_));
  }

  // Debug only: dumps the whole walk buffer. The copy is issued on
  // sample_stream_ and synchronized. A plain cudaMemcpy runs on the legacy
  // default stream, which is not ordered with respect to work queued on a
  // non-blocking stream.
  std::vector<uint64_t> h_walk(debug_mode_ ? static_cast<size_t>(buf_size_)
                                           : 0);
  auto dump_walk = [&]() {
    CUDA_CHECK(cudaMemcpyAsync(h_walk.data(),
                               walk,
                               buf_size_ * sizeof(uint64_t),
                               cudaMemcpyDeviceToHost,
                               sample_stream_));
    CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
    for (int xx = 0; xx < buf_size_; xx++) {
      VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx];
    }
  };

  int sample_times = 0;
  int i = 0;  // Next free offset in the walk buffer.
  int total_samples = 0;
  total_row_ = 0;

  // Global sampling state, shared across calls.
  auto &cur_metapath = gpu_graph_ptr->cur_metapath_;
  auto &path = gpu_graph_ptr->cur_parse_metapath_;
  auto &cur_metapath_start = gpu_graph_ptr->cur_metapath_start_[gpuid_];

  // Start node type = first field of cur_metapath_ split on "2". It is only
  // used for logging below, but the lookup must not dereference end().
  auto metapath_nodes =
      paddle::string::split_string<std::string>(cur_metapath, "2");
  PADDLE_ENFORCE_EQ(
      metapath_nodes.empty(),
      false,
      platform::errors::InvalidArgument("Metapath (%s) contains no node type.",
                                        cur_metapath));
  const std::string &first_node = metapath_nodes[0];
  auto it = gpu_graph_ptr->node_to_id.find(first_node);
  PADDLE_ENFORCE_NE(
      it,
      gpu_graph_ptr->node_to_id.end(),
      platform::errors::NotFound("(%s) is not found in node_to_id.",
                                 first_node));
  auto node_type = it->second;

  // A new round only starts while there is room for a full round
  // (walk_degree_ * once_sample_startid_len_ rows of walk_len_ slots).
  int remain_size =
      buf_size_ - walk_degree_ * once_sample_startid_len_ * walk_len_;

  while (i <= remain_size) {
    size_t start = cur_metapath_start;
    size_t device_key_size = h_train_metapath_keys_len_;
    VLOG(2) << "type: " << node_type << " size: " << device_key_size
            << " start: " << start;
    uint64_t *d_type_keys =
        reinterpret_cast<uint64_t *>(d_train_metapath_keys_->ptr());
    // Start keys for this round: a full chunk, or whatever is left.
    int tmp_len = start + once_sample_startid_len_ > device_key_size
                      ? device_key_size - start
                      : once_sample_startid_len_;
    bool update = true;
    if (tmp_len == 0) {
      break;  // Key list exhausted.
    }

    VLOG(2) << "gpuid = " << gpuid_ << " path[0] = " << path[0];
    uint64_t *cur_walk = walk + i;
    uint8_t *cur_walk_ntype =
        excluded_train_pair_len_ > 0 ? walk_ntype + i : nullptr;

    // Step 1: walk_degree_ neighbours per start key along edge type path[0].
    NeighborSampleQuery q;
    q.initialize(gpuid_,
                 path[0],
                 (uint64_t)(d_type_keys + start),
                 walk_degree_,
                 tmp_len);
    auto sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false, true);

    int step = 1;
    VLOG(2) << "sample edge type: " << path[0] << " step: " << 1;
    jump_rows_ = sample_res.total_sample_size;
    total_samples += sample_res.total_sample_size;
    VLOG(2) << "i = " << i << " start = " << start << " tmp_len = " << tmp_len
            << " jump row: " << jump_rows_;
    if (jump_rows_ == 0) {
      // No start key produced a sample: consume these keys and move on.
      cur_metapath_start = tmp_len + start;
      continue;
    }

    if (!sage_mode_ &&
        FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
      // Table full: stop without touching the global sampling state.
      if (InsertTable(d_type_keys + start, tmp_len, d_uniq_node_num_) != 0) {
        VLOG(2) << "in step 0, insert key stage, table is full";
        break;
      }
      if (InsertTable(sample_res.actual_val,
                      sample_res.total_sample_size,
                      d_uniq_node_num_) != 0) {
        VLOG(2) << "in step 0, insert sample res stage, table is full";
        break;
      }
    }

    FillOneStep(d_type_keys + start,
                path[0],
                cur_walk,
                cur_walk_ntype,
                tmp_len,
                sample_res,
                walk_degree_,
                step,
                len_per_row);
    if (debug_mode_) {
      dump_walk();
    }
    VLOG(2) << "sample, step=" << step << " sample_keys=" << tmp_len
            << " sample_res_len=" << sample_res.total_sample_size;

    // Steps 2 .. walk_len_-1: one neighbour per frontier key, cycling through
    // the edge types of the metapath.
    step++;
    size_t path_len = path.size();
    for (; step < walk_len_; step++) {
      if (sample_res.total_sample_size == 0) {
        VLOG(2) << "sample finish, step=" << step;
        break;
      }
      // Keep a handle to the previous step's sampled keys (the query input)
      // while sample_res is overwritten below.
      auto sample_key_mem = sample_res.actual_val_mem;
      uint64_t *sample_keys_ptr =
          reinterpret_cast<uint64_t *>(sample_key_mem->ptr());
      int edge_type_id = path[(step - 1) % path_len];
      VLOG(2) << "sample edge type: " << edge_type_id << " step: " << step;
      q.initialize(gpuid_,
                   edge_type_id,
                   (uint64_t)sample_keys_ptr,
                   1,
                   sample_res.total_sample_size);
      int sample_key_len = sample_res.total_sample_size;
      sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false, true);
      total_samples += sample_res.total_sample_size;
      if (!sage_mode_ &&
          FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
        if (InsertTable(sample_res.actual_val,
                        sample_res.total_sample_size,
                        d_uniq_node_num_) != 0) {
          VLOG(2) << "in step: " << step << ", table is full";
          update = false;
          break;
        }
      }
      FillOneStep(d_type_keys + start,
                  edge_type_id,
                  cur_walk,
                  cur_walk_ntype,
                  sample_key_len,
                  sample_res,
                  1,
                  step,
                  len_per_row);
      if (debug_mode_) {
        dump_walk();
      }
      VLOG(2) << "sample, step=" << step << " sample_keys=" << sample_key_len
              << " sample_res_len=" << sample_res.total_sample_size;
    }

    if (!update) {
      // Table full mid-walk: drop this round without advancing global state.
      VLOG(2) << "table is full, not update stat!";
      break;
    }
    // Commit the round to the global sampling state.
    cur_metapath_start = tmp_len + start;
    i += jump_rows_ * walk_len_;
    total_row_ += jump_rows_;
    sample_times++;
  }

  buf_state_.Reset(total_row_);
  int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());

  // Random row order for this buffer; the engine's next output seeds the
  // following call.
  thrust::random::default_random_engine engine(shuffle_seed_);
  const auto &exec_policy = thrust::cuda::par.on(sample_stream_);
  thrust::counting_iterator<int> cnt_iter(0);
  thrust::shuffle_copy(exec_policy,
                       cnt_iter,
                       cnt_iter + total_row_,
                       thrust::device_pointer_cast(d_random_row),
                       engine);

  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));
  shuffle_seed_ = engine();

  if (debug_mode_) {
    std::vector<int> h_random_row(static_cast<size_t>(total_row_));
    CUDA_CHECK(cudaMemcpy(h_random_row.data(),
                          d_random_row,
                          total_row_ * sizeof(int),
                          cudaMemcpyDeviceToHost));
    for (int xx = 0; xx < total_row_; xx++) {
      VLOG(2) << "h_random_row[" << xx << "]: " << h_random_row[xx];
    }
  }

  if (!sage_mode_) {
    uint64_t h_uniq_node_num = CopyUniqueNodes();
    VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_
            << ", d_walk_offset:" << i << ", total_rows:" << total_row_
            << ", h_uniq_node_num:" << h_uniq_node_num
            << ", total_samples:" << total_samples;
  } else {
    VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_
            << ", d_walk_offset:" << i << ", total_rows:" << total_row_
            << ", total_samples:" << total_samples;
  }

  return total_row_ != 0;
}

// Replaces the feed tensor pointers held by this generator (feed_vec_).
// The argument is taken by value and moved in, avoiding a second copy.
void GraphDataGenerator::SetFeedVec(std::vector<phi::DenseTensor *> feed_vec) {
  feed_vec_ = std::move(feed_vec);
}

// Binds this generator to the GPU mapped to `thread_id` and allocates the
// device buffers used by the walk/sampling pipeline on that GPU:
//   - table_ (outside WHOLE_HBM storage mode),
//   - the walk buffer and its node-type companion,
//   - per-round scratch buffers and the shuffled row index,
//   - the instance-pair buffer and slot tensor pointer arrays,
//   - in sage mode, the reindex hash table and sort buffers.
//
// It also collects the start-key lists: per metapath when training with
// FLAGS_graph_metapath_split_opt, per node type otherwise. For inference it
// records which node-type indices to infer (from infer_node_type_).
//
// `feed_vec` is only used to derive slot_num_; it is not stored here.
// Asynchronous work is issued on sample_stream_, which is synchronized before
// returning.
void GraphDataGenerator::AllocResource(
    int thread_id, std::vector<phi::DenseTensor *> feed_vec) {
  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  gpuid_ = gpu_graph_ptr->device_id_mapping[thread_id];
  thread_id_ = thread_id;
  place_ = platform::CUDAPlace(gpuid_);
  debug_gpu_memory_info(gpuid_, "AllocResource start");

  platform::CUDADeviceGuard guard(gpuid_);
  if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
    // Capacity is the configured cap scaled by 1 / load factor.
    if (gpu_graph_training_) {
      table_ = new HashTable<uint64_t, uint64_t>(
          train_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor);
    } else {
      table_ = new HashTable<uint64_t, uint64_t>(
          infer_table_cap_ / FLAGS_gpugraph_hbm_table_load_factor);
    }
  }
  VLOG(1) << "AllocResource gpuid " << gpuid_
          << " feed_vec.size: " << feed_vec.size()
          << " table cap: " << train_table_cap_;
  sample_stream_ = gpu_graph_ptr->get_local_stream(gpuid_);
  train_stream_ = dynamic_cast<phi::GPUContext *>(
                      platform::DeviceContextPool::Instance().Get(place_))
                      ->stream();

  // Allocates `bytes` of device memory on place_, ordered on sample_stream_.
  auto alloc_on_sample_stream = [this](size_t bytes) {
    return memory::AllocShared(
        place_,
        bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
  };

  // feed_vec starts with a fixed group of tensors: 3, or in sage mode 4 plus
  // 5 per entry of samples_. The remaining tensors come in pairs, one pair
  // per slot. Enforce the minimum size so the subtraction cannot underflow.
  const size_t fixed_feed_num =
      sage_mode_ ? 4 + samples_.size() * 5 : static_cast<size_t>(3);
  PADDLE_ENFORCE_GE(
      feed_vec.size(),
      fixed_feed_num,
      platform::errors::InvalidArgument(
          "feed_vec has %d tensors, but at least %d are required.",
          feed_vec.size(),
          fixed_feed_num));
  slot_num_ = (feed_vec.size() - fixed_feed_num) / 2;

  if (gpu_graph_training_ && FLAGS_graph_metapath_split_opt) {
    d_train_metapath_keys_ =
        gpu_graph_ptr->d_graph_train_total_keys_[thread_id];
    h_train_metapath_keys_len_ =
        gpu_graph_ptr->h_graph_train_keys_len_[thread_id];
    VLOG(2) << "h train metapaths key len: " << h_train_metapath_keys_len_;
  } else {
    auto &d_graph_all_type_keys = gpu_graph_ptr->d_graph_all_type_total_keys_;
    auto &h_graph_all_type_keys_len = gpu_graph_ptr->h_graph_all_type_keys_len_;

    for (size_t i = 0; i < d_graph_all_type_keys.size(); i++) {
      d_device_keys_.push_back(d_graph_all_type_keys[i][thread_id]);
      h_device_keys_len_.push_back(h_graph_all_type_keys_len[i][thread_id]);
    }
    VLOG(2) << "h_device_keys size: " << h_device_keys_len_.size();
  }

  size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_;
  d_prefix_sum_ =
      alloc_on_sample_stream((once_max_sample_keynum + 1) * sizeof(int));
  int *d_prefix_sum_ptr = reinterpret_cast<int *>(d_prefix_sum_->ptr());
  CUDA_CHECK(cudaMemsetAsync(d_prefix_sum_ptr,
                             0,
                             (once_max_sample_keynum + 1) * sizeof(int),
                             sample_stream_));
  cursor_ = 0;
  jump_rows_ = 0;
  d_uniq_node_num_ = alloc_on_sample_stream(sizeof(uint64_t));
  CUDA_CHECK(cudaMemsetAsync(
      d_uniq_node_num_->ptr(), 0, sizeof(uint64_t), sample_stream_));

  d_walk_ = alloc_on_sample_stream(buf_size_ * sizeof(uint64_t));
  CUDA_CHECK(cudaMemsetAsync(
      d_walk_->ptr(), 0, buf_size_ * sizeof(uint64_t), sample_stream_));

  excluded_train_pair_len_ = gpu_graph_ptr->excluded_train_pair_.size();
  if (excluded_train_pair_len_ > 0) {
    d_excluded_train_pair_ =
        alloc_on_sample_stream(excluded_train_pair_len_ * sizeof(uint8_t));
    CUDA_CHECK(cudaMemcpyAsync(d_excluded_train_pair_->ptr(),
                               gpu_graph_ptr->excluded_train_pair_.data(),
                               excluded_train_pair_len_ * sizeof(uint8_t),
                               cudaMemcpyHostToDevice,
                               sample_stream_));

    d_walk_ntype_ = alloc_on_sample_stream(buf_size_ * sizeof(uint8_t));
    CUDA_CHECK(cudaMemsetAsync(
        d_walk_ntype_->ptr(), 0, buf_size_ * sizeof(uint8_t), sample_stream_));
  }

  d_sample_keys_ =
      alloc_on_sample_stream(once_max_sample_keynum * sizeof(uint64_t));

  // Two equally sized index buffers; cur_sampleidx2row_ presumably selects
  // the active one — confirm at the use site.
  d_sampleidx2rows_.push_back(
      alloc_on_sample_stream(once_max_sample_keynum * sizeof(int)));
  d_sampleidx2rows_.push_back(
      alloc_on_sample_stream(once_max_sample_keynum * sizeof(int)));
  cur_sampleidx2row_ = 0;

  d_len_per_row_ = alloc_on_sample_stream(once_max_sample_keynum * sizeof(int));

  // Window offsets -window_ .. -1 followed by 1 .. window_ (0 excluded).
  for (int i = -window_; i < 0; i++) {
    window_step_.push_back(i);
  }
  for (int i = 0; i < window_; i++) {
    window_step_.push_back(i + 1);
  }
  buf_state_.Init(batch_size_, walk_len_, &window_step_);
  d_random_row_ = alloc_on_sample_stream(
      (once_sample_startid_len_ * walk_degree_ * repeat_time_) * sizeof(int));
  shuffle_seed_ = 0;

  ins_buf_pair_len_ = 0;
  if (!sage_mode_) {
    d_ins_buf_ =
        memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(uint64_t));
    d_pair_num_ = memory::AllocShared(place_, sizeof(int));
  } else {
    d_ins_buf_ =
        alloc_on_sample_stream((batch_size_ * 2 * 2) * sizeof(uint64_t));
    d_pair_num_ = alloc_on_sample_stream(sizeof(int));
  }

  d_slot_tensor_ptr_ =
      memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *));
  d_slot_lod_tensor_ptr_ =
      memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *));

  if (sage_mode_) {
    // Reindex table size: batch_size_ * 2, multiplied per sample layer by
    // (samples_[i] * edge_to_id_len_ + 1). The result is then set to
    // 2 * (smallest power of two greater than size / 2).
    reindex_table_size_ = batch_size_ * 2;
    for (size_t i = 0; i < samples_.size(); i++) {
      reindex_table_size_ *= (samples_[i] * edge_to_id_len_ + 1);
    }
    // Shift a 64-bit one: a plain `1 <<` is an int shift and overflows for
    // exponents >= 31.
    int64_t next_pow2 =
        static_cast<int64_t>(1)
        << static_cast<size_t>(1 + std::log2(reindex_table_size_ >> 1));
    reindex_table_size_ = next_pow2 << 1;

    d_reindex_table_key_ =
        alloc_on_sample_stream(reindex_table_size_ * sizeof(int64_t));
    d_reindex_table_value_ =
        alloc_on_sample_stream(reindex_table_size_ * sizeof(int));
    d_reindex_table_index_ =
        alloc_on_sample_stream(reindex_table_size_ * sizeof(int));
    edge_type_graph_ =
        gpu_graph_ptr->get_edge_type_graph(gpuid_, edge_to_id_len_);

    d_sorted_keys_ =
        alloc_on_sample_stream((batch_size_ * 2 * 2) * sizeof(uint64_t));
    d_sorted_idx_ =
        alloc_on_sample_stream((batch_size_ * 2 * 2) * sizeof(uint32_t));
    d_offset_ =
        alloc_on_sample_stream((batch_size_ * 2 * 2) * sizeof(uint32_t));
    d_merged_cnts_ =
        alloc_on_sample_stream((batch_size_ * 2 * 2) * sizeof(uint32_t));
  }

  // Parse infer_node_type_ (";"-separated node type names) into type indices.
  auto &type_to_index = gpu_graph_ptr->get_graph_type_to_index();
  if (!gpu_graph_training_) {
    auto node_types =
        paddle::string::split_string<std::string>(infer_node_type_, ";");
    const auto &node_to_id = gpu_graph_ptr->node_to_id;  // no map copy
    for (auto &type : node_types) {
      auto iter = node_to_id.find(type);
      PADDLE_ENFORCE_NE(
          iter,
          node_to_id.end(),
          platform::errors::NotFound("(%s) is not found in node_to_id.", type));
      int node_type = iter->second;
      int type_index = type_to_index[node_type];
      VLOG(2) << "add node[" << type
              << "] into infer_node_type, type_index(cursor)[" << type_index
              << "]";
      infer_node_type_index_set_.insert(type_index);
    }
    VLOG(2) << "infer_node_type_index_set_num: "
            << infer_node_type_index_set_.size();
  }

  CUDA_CHECK(cudaStreamSynchronize(sample_stream_));

  debug_gpu_memory_info(gpuid_, "AllocResource end");
}

void GraphDataGenerator::AllocTrainResource(int thread_id) {
  // Allocates (or clears) the per-batch feature-size buffers used when slot
  // features are configured. `thread_id` is not used by this routine.
  // Nothing to set up when no slots are configured.
  if (slot_num_ <= 0) {
    return;
  }
  platform::CUDADeviceGuard guard(gpuid_);
  if (sage_mode_) {
    // In sage mode these buffers are left unallocated.
    d_feature_size_list_buf_ = nullptr;
    d_feature_size_prefixsum_buf_ = nullptr;
    return;
  }
  // One uint32 entry per slot of a (batch_size_ * 2)-long list, plus one
  // extra trailing entry for the prefix-sum buffer.
  const auto list_len = batch_size_ * 2;
  d_feature_size_list_buf_ =
      memory::AllocShared(place_, list_len * sizeof(uint32_t));
  d_feature_size_prefixsum_buf_ =
      memory::AllocShared(place_, (list_len + 1) * sizeof(uint32_t));
}

void GraphDataGenerator::SetConfig(
    const paddle::framework::DataFeedDesc &data_feed_desc) {
  // Reads the graph sampling configuration from data_feed_desc into this
  // generator's member fields and forwards first_node_type / meta_path /
  // excluded_train_pair to the GPU graph wrapper via init_conf().
  //
  // Throws (from std::stoi) if any ';'-separated entry of the `samples`
  // config string is not an integer.
  auto graph_config = data_feed_desc.graph_config();
  walk_degree_ = graph_config.walk_degree();
  walk_len_ = graph_config.walk_len();
  window_ = graph_config.window();
  once_sample_startid_len_ = graph_config.once_sample_startid_len();
  debug_mode_ = graph_config.debug_mode();
  gpu_graph_training_ = graph_config.gpu_graph_training();
  // Debug and inference runs use the configured batch_size; training uses
  // once_sample_startid_len as the batch size.
  if (debug_mode_ || !gpu_graph_training_) {
    batch_size_ = graph_config.batch_size();
  } else {
    batch_size_ = once_sample_startid_len_;
  }
  repeat_time_ = graph_config.sample_times_one_chunk();
  // Widen the first factor so the whole product is evaluated in 64 bits;
  // if these config fields are 32-bit ints, the four-way product could
  // otherwise overflow int before being assigned to buf_size_.
  buf_size_ = static_cast<int64_t>(once_sample_startid_len_) * walk_len_ *
              walk_degree_ * repeat_time_;

  train_table_cap_ = graph_config.train_table_cap();
  infer_table_cap_ = graph_config.infer_table_cap();
  get_degree_ = graph_config.get_degree();
  epoch_finish_ = false;
  VLOG(1) << "Confirm GraphConfig, walk_degree : " << walk_degree_
          << ", walk_len : " << walk_len_ << ", window : " << window_
          << ", once_sample_startid_len : " << once_sample_startid_len_
          << ", sample_times_one_chunk : " << repeat_time_
          << ", batch_size: " << batch_size_
          << ", train_table_cap: " << train_table_cap_
          << ", infer_table_cap: " << infer_table_cap_;
  std::string first_node_type = graph_config.first_node_type();
  std::string meta_path = graph_config.meta_path();
  sage_mode_ = graph_config.sage_mode();
  std::string str_samples = graph_config.samples();

  auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
  debug_gpu_memory_info("init_conf start");
  gpu_graph_ptr->init_conf(
      first_node_type, meta_path, graph_config.excluded_train_pair());
  debug_gpu_memory_info("init_conf end");

  // Only the number of edge types is needed here, so read the size directly
  // instead of copying the whole edge_to_id map.
  edge_to_id_len_ = gpu_graph_ptr->edge_to_id.size();
  sage_batch_count_ = 0;
  // `samples` is a ';'-separated list of integer sample sizes.
  const auto sample_strs =
      paddle::string::split_string<std::string>(str_samples, ";");
  for (const auto &sample_str : sample_strs) {
    samples_.emplace_back(std::stoi(sample_str));
  }
  copy_unique_len_ = 0;

  // The infer node types are only read for inference (non-training) runs.
  if (!gpu_graph_training_) {
    infer_node_type_ = graph_config.infer_node_type();
  }
}
P
pangengzheng 已提交
2944
#endif
D
danleifeng 已提交
2945

2946 2947 2948 2949 2950 2951 2952 2953 2954 2955 2956 2957 2958 2959 2960 2961 2962 2963 2964 2965 2966 2967 2968 2969 2970 2971 2972 2973 2974 2975 2976 2977 2978 2979 2980 2981
void GraphDataGenerator::DumpWalkPath(std::string dump_path, size_t dump_rate) {
  // Appends walks from the device walk buffer (d_walk_) to `dump_path`.
  // Each walk is written on its own line as walk_len_ ids, each id followed
  // by '-'. Only the first buf_size_ / dump_rate entries of the buffer are
  // written, so dump_rate controls what fraction of the buffer is dumped.
  // Compiled only on Linux; elsewhere this is a no-op.
  //
  // dump_rate must lie in [1, 10000000] (inclusive, as the error messages
  // state); a value of 0 would also divide by zero below.
#ifdef _LINUX
  PADDLE_ENFORCE_LE(
      dump_rate,
      10000000,
      platform::errors::InvalidArgument(
          "dump_rate can't be large than 10000000. Please check the dump "
          "rate[1, 10000000]"));
  PADDLE_ENFORCE_GE(dump_rate,
                    1,
                    platform::errors::InvalidArgument(
                        "dump_rate can't be less than 1. Please check "
                        "the dump rate[1, 10000000]"));
  int err_no = 0;
  std::shared_ptr<FILE> fp = fs_open_append_write(dump_path, &err_no, "");
  // Host staging copy of the whole device walk buffer. Owned by a vector so
  // it is released on every exit path (a raw new[] here was never freed).
  std::vector<uint64_t> h_walk(buf_size_);
  uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
  // Blocking device-to-host copy, so h_walk is complete once this returns.
  cudaMemcpy(h_walk.data(),
             walk,
             buf_size_ * sizeof(uint64_t),
             cudaMemcpyDeviceToHost);
  VLOG(1) << "DumpWalkPath all buf_size_:" << buf_size_;
  const size_t dump_len = buf_size_ / dump_rate;
  std::string ss;
  size_t write_count = 0;
  // Step one walk (walk_len_ ids) at a time over the dumped prefix.
  for (size_t xx = 0; xx < dump_len; xx += walk_len_) {
    ss.clear();
    for (int yy = 0; yy < walk_len_; yy++) {
      ss += std::to_string(h_walk[xx + yy]) + "-";
    }
    write_count = fwrite_unlocked(ss.data(), 1, ss.length(), fp.get());
    if (write_count != ss.length()) {
      VLOG(1) << "dump walk path" << ss << " failed";
    }
    write_count = fwrite_unlocked("\n", 1, 1, fp.get());
  }
#endif
}

2982 2983 2984
}  // namespace framework
}  // namespace paddle
#endif