// cuda_managed_memory_test.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

#include "gtest/gtest.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace memory {

// Sets every `step`-th int of data[0, n) to 1: data[0], data[step],
// data[2 * step], ... for all indices strictly below n.
//
// Launch: any 1-D grid/block shape. Each thread starts at its global thread
// index and advances by gridDim.x * blockDim.x, so even a single block (as
// used by the tests in this file) covers every strided element. No shared
// memory is used. `data` must be device-accessible (the tests pass managed
// memory).
//
// A `step` of 0 writes nothing instead of looping forever on data[0].
__global__ void write_kernel(int* data, uint64_t n, uint64_t step) {
  if (step == 0) {
    return;
  }
  // Number of strided elements, ceil(n / step), computed without the
  // overflow-prone `n + step - 1`. Bounding the loop by this count also keeps
  // `i * step` below n, so the product can never wrap.
  const uint64_t num_elems = n / step + (n % step != 0 ? 1 : 0);
  // 64-bit indexing: the 32-bit product blockIdx.x * blockDim.x (and its
  // conversion to int) can overflow on very large grids.
  const uint64_t thread_num = static_cast<uint64_t>(gridDim.x) * blockDim.x;
  const uint64_t thread_id =
      static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (uint64_t i = thread_id; i < num_elems; i += thread_num) {
    data[i * step] = 1;
  }
}

// Atomically adds every `step`-th int of data[0, n) into *sum:
// *sum += data[0] + data[step] + data[2 * step] + ... (indices < n).
//
// Launch: any 1-D grid/block shape. Each thread starts at its global thread
// index and advances by gridDim.x * blockDim.x. The caller must initialize
// *sum before the launch; both pointers must be device-accessible (the tests
// pass managed memory). All threads update the same address with atomicAdd.
//
// A `step` of 0 adds nothing instead of looping forever on data[0].
__global__ void sum_kernel(int* data, uint64_t n, uint64_t step, int* sum) {
  if (step == 0) {
    return;
  }
  // ceil(n / step) strided elements; bounding by this count keeps
  // `i * step` < n, so the product cannot wrap.
  const uint64_t num_elems = n / step + (n % step != 0 ? 1 : 0);
  // 64-bit indexing avoids the 32-bit / int overflow of
  // blockIdx.x * blockDim.x on very large grids.
  const uint64_t thread_num = static_cast<uint64_t>(gridDim.x) * blockDim.x;
  const uint64_t thread_id =
      static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (uint64_t i = thread_id; i < num_elems; i += thread_num) {
    atomicAdd(sum, data[i * step]);
  }
}

// Host -> device: the managed buffer is first touched on the host (memset to
// 0), then written by a kernel, then read back on the host. Every element
// must be 1 afterwards. Skipped when device 0 does not support managed memory.
TEST(ManagedMemoryTest, H2DTest) {
  if (!platform::IsGPUManagedMemorySupported(0)) {
    return;
  }

  uint64_t n_data = 1024;
  uint64_t step = 1;
  allocation::AllocationPtr allocation =
      Alloc(platform::CUDAPlace(0), n_data * sizeof(int));
  int* data = static_cast<int*>(allocation->ptr());

  memset(data, 0, n_data * sizeof(int));          // first touched on the host
  write_kernel<<<1, 1024>>>(data, n_data, step);  // then written on the device

#ifdef PADDLE_WITH_CUDA
  // Report launch errors (e.g. an invalid configuration) at the launch site
  // instead of only as a wrong sum below.
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
  PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif

  int sum = 0;
  for (uint64_t i = 0; i < n_data; ++i) {
    sum += data[i];
  }
  // Compare like-signed values; n_data / step is 1024 here, so it fits in int.
  EXPECT_EQ(sum, static_cast<int>(n_data / step));
  allocation = nullptr;
}

// Device -> host: a kernel writes the managed buffer first; the host then
// reads those values, overwrites the buffer with memset, and reads it again.
// Skipped when device 0 does not support managed memory.
TEST(ManagedMemoryTest, D2HTest) {
  if (!platform::IsGPUManagedMemorySupported(0)) {
    return;
  }

  uint64_t n_data = 1024;
  uint64_t step = 1;
  AllocationPtr allocation =
      Alloc(platform::CUDAPlace(0), n_data * sizeof(int));
  int* data = static_cast<int*>(allocation->ptr());

  write_kernel<<<1, 1024>>>(data, n_data, step);  // first touched on the device

#ifdef PADDLE_WITH_CUDA
  // Report launch errors at the launch site rather than as a wrong value.
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
  PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif

  // The host must observe the kernel's writes (step == 1, so every element
  // was set to 1).
  int device_sum = 0;
  for (uint64_t i = 0; i < n_data; ++i) {
    device_sum += data[i];
  }
  EXPECT_EQ(device_sum, static_cast<int>(n_data / step));

  memset(data, 0, n_data * sizeof(int));  // then written on the host

  int sum = 0;
  for (uint64_t i = 0; i < n_data; ++i) {
    sum += data[i];
  }
  EXPECT_EQ(sum, 0);
}

// Allocates more than twice the memory reported available on device 0, writes
// roughly 1024 evenly spaced ints spread over the whole allocation, and checks
// that the device sees every write. Skipped when managed-memory
// oversubscription is unsupported.
TEST(ManagedMemoryTest, OversubscribeGPUMemoryTest) {
  if (!platform::IsGPUManagedMemoryOversubscriptionSupported(0)) {
    return;
  }

  uint64_t available_mem = platform::GpuAvailableMemToAlloc();
  // Requires more than 2 * available_mem bytes.
  uint64_t n_data = available_mem * 2 / sizeof(int) + 1;
  // Stride chosen so only ~1024 elements are touched, spread across the
  // entire allocation.
  uint64_t step = std::max(n_data / 1024, static_cast<uint64_t>(1));

  AllocationPtr data_allocation =
      Alloc(platform::CUDAPlace(0), n_data * sizeof(int));
  AllocationPtr sum_allocation = Alloc(platform::CUDAPlace(0), sizeof(int));
  int* data = static_cast<int*>(data_allocation->ptr());
  int* sum = static_cast<int*>(sum_allocation->ptr());
  (*sum) = 0;

  write_kernel<<<1, 1024>>>(data, n_data, step);
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
#endif

  // Same (default) stream as write_kernel, so it runs after all writes.
  sum_kernel<<<1, 1024>>>(data, n_data, step, sum);
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
#endif

#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif

  // One write of 1 per strided element: ceil(n_data / step). Compare in
  // uint64_t to avoid a signed/unsigned mismatch (*sum is never negative).
  const uint64_t expected = (n_data + step - 1) / step;
  EXPECT_EQ(static_cast<uint64_t>(*sum), expected);
}

// An impossible request (2^60 bytes) on device 0 must be reported by throwing
// memory::allocation::BadAlloc. Skipped when device 0 does not support
// managed memory.
TEST(ManagedMemoryTest, OOMExceptionTest) {
  if (!platform::IsGPUManagedMemorySupported(0)) {
    return;
  }

  EXPECT_THROW(Alloc(platform::CUDAPlace(0), size_t(1) << 60),
               memory::allocation::BadAlloc);
}

}  // namespace memory
}  // namespace paddle