// stream_safe_cuda_alloc_test.cu
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

#include <chrono>  // NOLINT
#include <memory>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace memory {

// Adds each thread's global id to every element of x[0, n) that it visits.
//
// The loop strides by the total number of launched threads
// (gridDim.x * blockDim.x). Element i is therefore visited exactly once, by
// thread (i % thread_num). After one launch, x[i] has grown by
// i % thread_num, for any 1-D launch configuration.
// atomicAdd is used so that launches on different streams touching the same
// buffer concurrently do not lose updates.
__global__ void add_kernel(int *x, int n) {
  int thread_num = gridDim.x * blockDim.x;
  int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = thread_id; i < n; i += thread_num) {
    atomicAdd(x + i, thread_id);
  }
}

// Fails the current test if platform::RecordedGpuMallocSize reports a non-zero
// byte count for `place`'s device. Presumably this counts GPU memory obtained
// through Paddle and not yet returned — confirm against gpu_info.h.
// ASSERT_* only returns from this helper; the caller keeps running.
void CheckMemLeak(const platform::CUDAPlace &place) {
  uint64_t cuda_malloc_size =
      platform::RecordedGpuMallocSize(place.GetDeviceId());
  // Compare against an unsigned literal so gtest does not mix signedness.
  ASSERT_EQ(cuda_malloc_size, uint64_t{0})
      << "Found " << cuda_malloc_size << " bytes memory that not released yet,"
      << " there may be a memory leak problem";
}

// Fixture that stresses the stream-safe CUDA allocator.
//
// SetUp creates `stream_num_` streams and one zeroed workspace per stream.
// Every stream then launches add_kernel on *every* workspace and records
// itself on it via RecordStream. The results are copied out, and the
// workspaces are dropped ("fast GC") before the device is synchronized, so
// the frees race with in-flight kernels.
// The test expects the allocator to defer reusing a workspace until all
// streams recorded on it are done; otherwise the copied results get corrupted.
class StreamSafeCUDAAllocTest : public ::testing::Test {
 protected:
  void SetUp() override {
    place_ = platform::CUDAPlace();
    stream_num_ = 64;
    grid_num_ = 1;
    block_num_ = 32;
    data_num_ = 131072;
    workspace_size_ = data_num_ * sizeof(int);

    // Create one stream per workspace and allocate a zeroed workspace on it.
    for (size_t i = 0; i < stream_num_; ++i) {
      gpuStream_t stream;
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream));
#endif

      std::shared_ptr<Allocation> allocation =
          AllocShared(place_, workspace_size_, stream);
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemset(allocation->ptr(), 0, allocation->size()));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemset(allocation->ptr(), 0, allocation->size()));
#endif

      streams_.emplace_back(stream);
      workspaces_.emplace_back(allocation);
    }

    // One workspace-sized slice per stream; slice i starts at int offset
    // i * data_num_ (see CopyResultAsync).
    result_ = AllocShared(place_, stream_num_ * workspace_size_);
  }

  // Surfaces kernel-launch errors, which a <<<...>>> launch does not return.
  // The runtime's last-error state is per host thread, so this is safe to
  // call from the worker threads of MultiThreadMUltiStreamRun.
  static void CheckLastGpuError() {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
#endif
  }

  // Blocks until all previously issued device work has finished.
  static void SyncDevice() {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
  }

  void SingleStreamRun(size_t idx) {
    // For every workspace i, stream `idx` launches a kernel that adds
    // (j % thread_num) to workspaces_[i][j]. The stream is then recorded on
    // the allocation so the allocator knows that stream still uses it.
    for (size_t i = 0; i < stream_num_; ++i) {
      int *x = reinterpret_cast<int *>(workspaces_[i]->ptr());
      add_kernel<<<grid_num_, block_num_, 0, streams_[idx]>>>(
          x, static_cast<int>(data_num_));
      CheckLastGpuError();
      RecordStream(workspaces_[i], streams_[idx]);
    }
  }

  // Copies every workspace into its slice of result_. No stream is passed, so
  // the copies go to the default stream. NOTE(review): this relies on
  // legacy default-stream semantics to order the copies after the kernels on
  // the streams created in SetUp — confirm if built with a per-thread
  // default stream.
  void CopyResultAsync() {
    for (size_t i = 0; i < stream_num_; ++i) {
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
          reinterpret_cast<int *>(result_->ptr()) + i * data_num_,
          workspaces_[i]->ptr(), workspace_size_, cudaMemcpyDeviceToDevice));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(
          reinterpret_cast<int *>(result_->ptr()) + i * data_num_,
          workspaces_[i]->ptr(), workspace_size_, hipMemcpyDeviceToDevice));
#endif
    }
  }

  // Drives all streams sequentially from the calling thread.
  void MultiStreamRun() {
    for (size_t i = 0; i < stream_num_; ++i) {
      SingleStreamRun(i);
    }
    CopyResultAsync();
    workspaces_.clear();  // fast_gc: free while device work may be pending
    SyncDevice();
  }

  // Drives each stream from its own host thread.
  void MultiThreadMUltiStreamRun() {
    std::vector<std::thread> threads;
    threads.reserve(stream_num_);
    for (size_t i = 0; i < stream_num_; ++i) {
      threads.emplace_back(&StreamSafeCUDAAllocTest::SingleStreamRun, this, i);
    }
    for (std::thread &thread : threads) {
      thread.join();
    }
    CopyResultAsync();
    workspaces_.clear();  // fast_gc: free while device work may be pending
    SyncDevice();
  }

  // Copies result_ to the host and checks the value of every element.
  // Each of the stream_num_ streams added (j % thread_num) to element j of
  // every workspace, so each element must equal (j % thread_num) * stream_num_.
  void CheckResult() {
    const size_t total_num = stream_num_ * data_num_;
    std::vector<int> result_host(total_num);
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(result_host.data(), result_->ptr(),
                                          total_num * sizeof(int),
                                          cudaMemcpyDeviceToHost));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(result_host.data(), result_->ptr(),
                                         total_num * sizeof(int),
                                         hipMemcpyDeviceToHost));
#endif
    const size_t thread_num = grid_num_ * block_num_;
    size_t mismatch_num = 0;
    for (size_t i = 0; i < stream_num_; ++i) {
      for (size_t j = 0; j < data_num_; ++j) {
        // Slice i begins at element i * data_num_, matching CopyResultAsync.
        const size_t expected = (j % thread_num) * stream_num_;
        if (static_cast<size_t>(result_host[i * data_num_ + j]) != expected) {
          ++mismatch_num;
        }
      }
    }
    EXPECT_EQ(mismatch_num, size_t{0})
        << mismatch_num << " of " << total_num
        << " result elements hold unexpected values";
    result_.reset();
  }

  void TearDown() override {
    SyncDevice();
    // Free anything a failed test body left behind before releasing caches.
    workspaces_.clear();
    result_.reset();

    for (gpuStream_t stream : streams_) {
      Release(place_, stream);
    }

    // Destroy every stream created in SetUp, including streams_[0].
    for (gpuStream_t stream : streams_) {
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream));
#endif
    }
    streams_.clear();

    CheckMemLeak(place_);
  }

  size_t stream_num_;      // number of streams / workspaces
  size_t grid_num_;        // blocks per add_kernel launch
  size_t block_num_;       // threads per block in add_kernel launches
  size_t data_num_;        // ints per workspace
  size_t workspace_size_;  // bytes per workspace
  platform::CUDAPlace place_;
  std::vector<gpuStream_t> streams_;
  std::vector<std::shared_ptr<Allocation>> workspaces_;
  std::shared_ptr<Allocation> result_;
};

// All streams are fed sequentially from the test's own host thread; the
// workspace contents are then verified.
TEST_F(StreamSafeCUDAAllocTest, CUDAMutilStreamTest) {
  this->MultiStreamRun();
  this->CheckResult();
}

// Each stream is fed from a dedicated host thread; the workspace contents
// are then verified.
TEST_F(StreamSafeCUDAAllocTest, CUDAMutilThreadMutilStreamTest) {
  this->MultiThreadMUltiStreamRun();
  this->CheckResult();
}

// Expects memory freed from an implicit-stream AllocShared to be handed back
// (same address) by Alloc on the device context's stream. After release,
// no GPU memory may remain outstanding.
TEST(StreamSafeCUDAAllocInterfaceTest, AllocInterfaceTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  size_t alloc_size = 256;

  std::shared_ptr<Allocation> allocation_implicit_stream =
      AllocShared(place, alloc_size);
  EXPECT_GE(allocation_implicit_stream->size(), alloc_size);

  void *address = allocation_implicit_stream->ptr();
  allocation_implicit_stream.reset();

  // Guard the downcast: dereferencing a failed dynamic_cast would crash.
  auto *dev_ctx = dynamic_cast<platform::CUDADeviceContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(place));
  ASSERT_NE(dev_ctx, nullptr);
  gpuStream_t default_stream = dev_ctx->stream();

  allocation::AllocationPtr allocation_unique =
      Alloc(place, alloc_size, default_stream);
  EXPECT_GE(allocation_unique->size(), alloc_size);
  EXPECT_EQ(allocation_unique->ptr(), address);
  allocation_unique.reset();

  Release(place);
  CheckMemLeak(place);
}

// Expects an allocation obtained through AllocatorFacade::GetAllocator and one
// obtained through AllocShared (implicit stream) to reuse the same memory.
TEST(StreamSafeCUDAAllocInterfaceTest, GetAllocatorInterfaceTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  auto &instance = allocation::AllocatorFacade::Instance();
  const std::shared_ptr<Allocator> &allocator = instance.GetAllocator(place);
  ASSERT_NE(allocator, nullptr);

  size_t alloc_size = 256;
  std::shared_ptr<Allocation> allocation_from_allocator =
      allocator->Allocate(alloc_size);
  EXPECT_GE(allocation_from_allocator->size(), alloc_size);
  void *address = allocation_from_allocator->ptr();
  allocation_from_allocator.reset();

  std::shared_ptr<Allocation> allocation_implicit_stream =
      AllocShared(place, alloc_size);
  EXPECT_GE(allocation_implicit_stream->size(), alloc_size);
  EXPECT_EQ(allocation_implicit_stream->ptr(), address);
  allocation_implicit_stream.reset();

  Release(place);
  CheckMemLeak(place);
}

// A zero-size allocation has a null pointer; recording a stream on it must
// not throw.
TEST(StreamSafeCUDAAllocInterfaceTest, ZeroSizeRecordStreamTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  std::shared_ptr<Allocation> zero_size_allocation = AllocShared(place, 0);
  EXPECT_EQ(zero_size_allocation->ptr(), nullptr);

  gpuStream_t stream;
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream));
#endif

  EXPECT_NO_THROW(RecordStream(zero_size_allocation, stream));

  // Drop the allocation while the stream recorded on it still exists, so
  // nothing can consult that stream after it has been destroyed.
  zero_size_allocation.reset();

#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream));
#endif
}

// GetStream must report the device context's stream for implicit-stream
// allocations and the explicit stream for AllocShared(place, size, stream).
TEST(StreamSafeCUDAAllocInterfaceTest, GetStreamInterfaceTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  size_t alloc_size = 256;

  auto *dev_ctx = dynamic_cast<platform::CUDADeviceContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(place));
  ASSERT_NE(dev_ctx, nullptr);
  gpuStream_t default_stream = dev_ctx->stream();
  std::shared_ptr<Allocation> allocation_implicit_stream =
      AllocShared(place, alloc_size);
  EXPECT_EQ(GetStream(allocation_implicit_stream), default_stream);

  gpuStream_t new_stream;
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&new_stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&new_stream));
#endif

  std::shared_ptr<Allocation> allocation_new_stream =
      AllocShared(place, alloc_size, new_stream);
  EXPECT_EQ(GetStream(allocation_new_stream), new_stream);

  // Free both allocations and release new_stream's cached memory while the
  // stream is still valid. Only then destroy it (same order as the fixture's
  // TearDown).
  allocation_implicit_stream.reset();
  allocation_new_stream.reset();
  Release(place, new_stream);

#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(new_stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(new_stream));
#endif

  Release(place);
  CheckMemLeak(place);
}

#ifdef PADDLE_WITH_CUDA
// While a CUDA graph is being captured, each stream-safe allocator entry point
// exercised below is expected to throw EnforceNotMet:
// allocation, release, allocator lookup, stream recording and stream query.
TEST(StreamSafeCUDAAllocInterfaceTest, CUDAGraphExceptionTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  size_t alloc_size = 1;
  // Allocated before capture begins; used for the RecordStream/GetStream checks.
  std::shared_ptr<Allocation> allocation = AllocShared(place, alloc_size);

  platform::BeginCUDAGraphCapture(place, cudaStreamCaptureModeGlobal);
  EXPECT_THROW(AllocShared(place, alloc_size), paddle::platform::EnforceNotMet);
  EXPECT_THROW(Alloc(place, alloc_size), paddle::platform::EnforceNotMet);
  EXPECT_THROW(Release(place), paddle::platform::EnforceNotMet);
  EXPECT_THROW(allocation::AllocatorFacade::Instance().GetAllocator(place),
               paddle::platform::EnforceNotMet);
  EXPECT_THROW(AllocShared(place, alloc_size, nullptr),
               paddle::platform::EnforceNotMet);
  EXPECT_THROW(Alloc(place, alloc_size, nullptr),
               paddle::platform::EnforceNotMet);
  EXPECT_THROW(Release(place, nullptr), paddle::platform::EnforceNotMet);
  EXPECT_THROW(RecordStream(allocation, nullptr),
               paddle::platform::EnforceNotMet);
  EXPECT_THROW(GetStream(allocation), paddle::platform::EnforceNotMet);
  platform::EndCUDAGraphCapture();

  allocation.reset();
  Release(place);
  CheckMemLeak(place);
}
#endif

// Exercises the allocator's retry path.
// allocation1 takes 3/4 of the available memory on stream1 and is freed right
// away; the comment below notes it stays cached rather than released.
// About one second later, a worker thread requests the same size on stream2.
// Two such blocks cannot fit at once, so the request is expected to fail
// first and then succeed on retry.
TEST(StreamSafeCUDAAllocRetryTest, RetryTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  gpuStream_t stream1, stream2;
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream2));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream2));
#endif
  size_t available_size = platform::GpuAvailableMemToAlloc();
  // alloc_size < available_size < 2 * alloc_size,
  // so the second alloc will fail and retry
  size_t alloc_size = available_size / 4 * 3;

  std::shared_ptr<Allocation> allocation1 =
      AllocShared(place, alloc_size, stream1);
  std::shared_ptr<Allocation> allocation2;

  std::thread th([&allocation2, &place, &stream2, alloc_size]() {
    std::this_thread::sleep_for(std::chrono::seconds(1));
    allocation2 = AllocShared(place, alloc_size, stream2);
  });
  allocation1.reset();  // free but not release
  th.join();
  // Avoid dereferencing a null allocation; keep going so streams get cleaned up.
  EXPECT_NE(allocation2, nullptr);
  if (allocation2) {
    EXPECT_GE(allocation2->size(), alloc_size);
  }
  allocation2.reset();

#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif

  Release(place, stream1);
  Release(place, stream2);

  // The streams were created above; destroy them once their allocators are
  // released so they do not leak.
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream2));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream2));
#endif

  CheckMemLeak(place);
}

}  // namespace memory
}  // namespace paddle