// File: stream_safe_cuda_alloc_test.cu
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <chrono>  // NOLINT
#include <memory>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/stream.h"

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>

#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

namespace paddle {
namespace memory {

// Element-wise update: y[i] += x[i] + 1 for every i in [0, n).
//
// Each thread starts at its global index and advances by the total number of
// launched threads (gridDim.x * blockDim.x), so any 1-D launch configuration
// covers all n elements. x and y must be device-accessible buffers of at
// least n ints; x is only read, y is read-modify-written.
__global__ void add_kernel(int *x, int *y, int n) {
  int thread_num = gridDim.x * blockDim.x;
  int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = thread_id; i < n; i += thread_num) {
    y[i] += x[i] + 1;
  }
}

// Fails the current test if platform::RecordedGpuMallocSize() reports any
// bytes still held on `place`'s device.
//
// Note: ASSERT_* inside a void helper only returns from this helper, not from
// the calling TEST body; callers invoke it as their last step.
void CheckMemLeak(const platform::CUDAPlace &place) {
  uint64_t cuda_malloc_size =
      platform::RecordedGpuMallocSize(place.GetDeviceId());
  // Compare against an unsigned zero so gtest's templated comparison does not
  // mix signed and unsigned operands (-Wsign-compare).
  ASSERT_EQ(cuda_malloc_size, static_cast<uint64_t>(0))
      << "Found " << cuda_malloc_size << " bytes memory that not released yet,"
      << " there may be a memory leak problem";
}

// Allocates through AllocShared (no stream argument) and frees it. It then
// allocates the same size through Alloc with the device context's stream
// passed explicitly. The second allocation is expected to reuse the freed
// address, which only holds if both calls are served by the same cached
// allocator.
TEST(StreamSafeCUDAAllocInterfaceTest, AllocInterfaceTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  size_t alloc_size = 256;

  std::shared_ptr<Allocation> allocation_implicit_stream =
      AllocShared(place, alloc_size);
  EXPECT_GE(allocation_implicit_stream->size(), alloc_size);

  void *address = allocation_implicit_stream->ptr();
  allocation_implicit_stream.reset();

  // Guard the downcast instead of dereferencing a possibly-null pointer.
  auto *dev_ctx = dynamic_cast<platform::CUDADeviceContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(place));
  ASSERT_NE(dev_ctx, nullptr);
  gpuStream_t default_stream = dev_ctx->stream();

  allocation::AllocationPtr allocation_unique =
      Alloc(place, alloc_size,
            phi::Stream(reinterpret_cast<phi::StreamId>(default_stream)));
  EXPECT_GE(allocation_unique->size(), alloc_size);
  EXPECT_EQ(allocation_unique->ptr(), address);
  allocation_unique.reset();

  Release(place);
  CheckMemLeak(place);
}

// Frees a block obtained through Alloc(place, size). It then allocates the
// same size directly from AllocatorFacade::GetAllocator(place), expecting
// the freed address to be handed out again.
TEST(StreamSafeCUDAAllocInterfaceTest, GetAllocatorInterfaceTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  size_t alloc_size = 256;

  allocation::AllocationPtr allocation_implicit_stream =
      Alloc(place, alloc_size);
  EXPECT_GE(allocation_implicit_stream->size(), alloc_size);
  void *address = allocation_implicit_stream->ptr();
  allocation_implicit_stream.reset();

  auto &instance = allocation::AllocatorFacade::Instance();
  const std::shared_ptr<Allocator> &allocator = instance.GetAllocator(place);
  ASSERT_NE(allocator, nullptr);

  allocation::AllocationPtr allocation_from_allocator =
      allocator->Allocate(alloc_size);
  EXPECT_GE(allocation_from_allocator->size(), alloc_size);
  EXPECT_EQ(allocation_from_allocator->ptr(), address);
  allocation_from_allocator.reset();

  Release(place);
  CheckMemLeak(place);
}

// A zero-size allocation has a null ptr(). Calling RecordStream on it with a
// freshly created stream must not throw.
TEST(StreamSafeCUDAAllocInterfaceTest, ZeroSizeRecordStreamTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  std::shared_ptr<Allocation> zero_size_allocation = AllocShared(place, 0);
  EXPECT_EQ(zero_size_allocation->ptr(), nullptr);

  gpuStream_t stream;
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream));
#endif

  EXPECT_NO_THROW(RecordStream(zero_size_allocation, stream));

  // Free the allocation while `stream` is still valid, so the allocator never
  // handles an allocation whose recorded stream has already been destroyed.
  zero_size_allocation.reset();

#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream));
#endif
}

// GetStream() must return the stream an allocation is bound to. Without a
// stream argument, that is the device context's stream; otherwise it is the
// stream passed to AllocShared.
TEST(StreamSafeCUDAAllocInterfaceTest, GetStreamInterfaceTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  size_t alloc_size = 256;

  auto *dev_ctx = dynamic_cast<platform::CUDADeviceContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(place));
  ASSERT_NE(dev_ctx, nullptr);
  gpuStream_t default_stream = dev_ctx->stream();

  std::shared_ptr<Allocation> allocation_implicit_stream =
      AllocShared(place, alloc_size);
  EXPECT_EQ(GetStream(allocation_implicit_stream), default_stream);

  gpuStream_t new_stream;
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&new_stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&new_stream));
#endif

  std::shared_ptr<Allocation> allocation_new_stream =
      AllocShared(place, alloc_size,
                  phi::Stream(reinterpret_cast<phi::StreamId>(new_stream)));
  EXPECT_EQ(GetStream(allocation_new_stream), new_stream);

  // Free both allocations before destroying `new_stream`. Otherwise the
  // allocation bound to it would be freed against a dead stream handle.
  allocation_implicit_stream.reset();
  allocation_new_stream.reset();

#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(new_stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(new_stream));
#endif

  Release(place);
  CheckMemLeak(place);
}

// Exercises the allocator's retry path.
//
// allocation1 takes 3/4 of the allocatable device memory on stream1 and is
// then freed but not released. A second thread requests the same size on
// stream2 after a one-second delay. Since alloc_size < available_size <
// 2 * alloc_size, that request cannot be satisfied directly and must succeed
// via retry.
TEST(StreamSafeCUDAAllocRetryTest, RetryTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  gpuStream_t stream1, stream2;
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream2));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream2));
#endif
  size_t available_size = platform::GpuAvailableMemToAlloc();
  // alloc_size < available_size < 2 * alloc_size,
  // so the second alloc will fail and retry
  size_t alloc_size = available_size / 4 * 3;

  allocation::AllocationPtr allocation1 = Alloc(
      place, alloc_size, phi::Stream(reinterpret_cast<phi::StreamId>(stream1)));
  allocation::AllocationPtr allocation2;

  std::thread th([&allocation2, &place, &stream2, alloc_size]() {
    std::this_thread::sleep_for(std::chrono::seconds(1));
    allocation2 = Alloc(place, alloc_size,
                        phi::Stream(reinterpret_cast<phi::StreamId>(stream2)));
  });
  allocation1.reset();  // free but not release
  th.join();
  EXPECT_GE(allocation2->size(), alloc_size);
  allocation2.reset();

#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif

  Release(place, stream1);
  Release(place, stream2);

  // The streams were created by this test; destroy them so they do not leak
  // into later tests (same Release -> destroy -> leak-check order as the
  // StreamSafeCUDAAllocTest fixture's TearDown).
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream2));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream2));
#endif

  CheckMemLeak(place);
}

// Fixture that gives each of `stream_num_` streams its own zeroed device
// workspace, device result buffer and host result buffer.
//
// SingleStreamRun(i) adds workspace i and workspace i - 1 (the neighbour)
// into result i on streams_[i], then calls RecordStream on the neighbour's
// workspace. This is presumably so the allocator does not reuse that memory
// until streams_[i] is done with it, even after the fixture drops its own
// reference.
class StreamSafeCUDAAllocTest : public ::testing::Test {
 protected:
  // Creates the streams and allocates and zeroes every per-stream buffer.
  void SetUp() override {
    place_ = platform::CUDAPlace();
    stream_num_ = 64;
    grid_num_ = 1;
    block_num_ = 32;
    data_num_ = 131072;
    workspace_size_ = data_num_ * sizeof(int);

    for (size_t i = 0; i < stream_num_; ++i) {
      gpuStream_t stream;
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream));
#endif

      std::shared_ptr<phi::Allocation> workspace_allocation =
          AllocShared(place_, workspace_size_,
                      phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
      std::shared_ptr<phi::Allocation> result_allocation =
          AllocShared(place_, workspace_size_,
                      phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
      std::shared_ptr<phi::Allocation> host_result_allocation =
          AllocShared(platform::CPUPlace(), workspace_size_);

      // With zeroed inputs, the two add_kernel passes in SingleStreamRun leave
      // every result element at (0 + 1) + (0 + 1) == 2, as CheckResult expects.
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemset(workspace_allocation->ptr(), 0,
                                            workspace_allocation->size()));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemset(result_allocation->ptr(), 0, result_allocation->size()));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipMemset(workspace_allocation->ptr(), 0,
                                           workspace_allocation->size()));
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemset(result_allocation->ptr(), 0, result_allocation->size()));
#endif

      streams_.emplace_back(stream);
      workspaces_.emplace_back(workspace_allocation);
      results_.emplace_back(result_allocation);
      host_results_.emplace_back(host_result_allocation);
    }
  }

  // Kernel launches do not return an error code. This reports configuration
  // errors from the last launch issued by the calling host thread.
  static void CheckKernelLaunch() {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
#endif
  }

  // Enqueues two add_kernel passes on streams_[idx] into results_[idx]. The
  // first reads workspaces_[idx]; the second reads the neighbouring workspace
  // (idx - 1, or idx itself when idx == 0). The neighbouring workspace is then
  // recorded on streams_[idx]. MultiThreadMultiStreamRun calls this
  // concurrently for different idx.
  void SingleStreamRun(size_t idx) {
    int *y = reinterpret_cast<int *>(results_[idx]->ptr());
    size_t neighbouring_idx = idx > 0 ? idx - 1 : idx;

    add_kernel<<<grid_num_, block_num_, 0, streams_[idx]>>>(
        reinterpret_cast<int *>(workspaces_[idx]->ptr()), y, data_num_);
    CheckKernelLaunch();
    add_kernel<<<grid_num_, block_num_, 0, streams_[idx]>>>(
        reinterpret_cast<int *>(workspaces_[neighbouring_idx]->ptr()), y,
        data_num_);
    CheckKernelLaunch();
    RecordStream(workspaces_[neighbouring_idx], streams_[idx]);
  }

  // Drives every stream from this thread, dropping each workspace right
  // after use. It must run in reverse order: streams_[i] reads
  // workspaces_[i - 1], which therefore must still be held when streams_[i]'s
  // kernels are launched.
  void MultiStreamRun() {
    for (size_t i = stream_num_; i > 0; --i) {
      size_t idx = i - 1;
      SingleStreamRun(idx);
      workspaces_[idx].reset();  // fast GC
    }
  }

  // Drives each stream from its own host thread, then drops all workspaces.
  void MultiThreadMultiStreamRun() {
    std::vector<std::thread> threads;
    threads.reserve(stream_num_);
    for (size_t i = 0; i < stream_num_; ++i) {
      threads.emplace_back(&StreamSafeCUDAAllocTest::SingleStreamRun, this, i);
    }
    for (std::thread &t : threads) {
      t.join();
    }
    workspaces_.clear();
  }

#ifdef PADDLE_WITH_CUDA
  // Captures one add_kernel launch (result += data + 1) into a CUDA graph and
  // replays it `replay_times` times. Every result element must then equal
  // replay_times. Sets testing_cuda_graph_ so that TearDown skips the leak
  // check.
  void CUDAGraphRun() {
    testing_cuda_graph_ = true;
    platform::BeginCUDAGraphCapture(platform::CUDAPlace(),
                                    cudaStreamCaptureModeGlobal);

    std::shared_ptr<Allocation> data_allocation =
        AllocShared(platform::CUDAPlace(), workspace_size_);
    std::shared_ptr<Allocation> result_allocation =
        AllocShared(platform::CUDAPlace(), workspace_size_);

    int *data = static_cast<int *>(data_allocation->ptr());
    int *result = static_cast<int *>(result_allocation->ptr());

    gpuStream_t main_stream = GetStream(data_allocation);
    gpuStream_t other_stream;
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&other_stream));

    add_kernel<<<grid_num_, block_num_, 0, main_stream>>>(data, result,
                                                          data_num_);
    RecordStream(data_allocation, other_stream);

    std::unique_ptr<platform::CUDAGraph> cuda_graph =
        platform::EndCUDAGraphCapture();

    // Both buffers were allocated during capture, and nothing in the captured
    // work initializes them. Zero them once, outside the graph, before
    // replaying; otherwise the expected value below would depend on the
    // contents of uninitialized device memory.
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaMemsetAsync(data, 0, workspace_size_, main_stream));
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaMemsetAsync(result, 0, workspace_size_, main_stream));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(main_stream));

    int replay_times = 10;
    for (int i = 0; i < replay_times; ++i) {
      cuda_graph->Replay();
    }

    std::shared_ptr<Allocation> host_result_allocation =
        AllocShared(platform::CPUPlace(), workspace_size_);
    Copy(host_result_allocation->place(), host_result_allocation->ptr(),
         result_allocation->place(), result_allocation->ptr(), workspace_size_,
         main_stream);
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(main_stream));

    int *host_result = static_cast<int *>(host_result_allocation->ptr());
    for (size_t i = 0; i < data_num_; ++i) {
      EXPECT_EQ(host_result[i], replay_times);
    }

    data_allocation.reset();
    result_allocation.reset();
    // NOTE(review): release() drops ownership without destroying the graph,
    // so the CUDAGraph object is leaked. This may be deliberate (TearDown
    // notes that memory release for the CUDA Graph pool is forbidden); confirm
    // before switching to reset().
    cuda_graph.release();
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(other_stream));
  }
#endif

  // Copies every device result to its host buffer on the owning stream, waits
  // for the device, and expects every element to equal 2.
  void CheckResult() {
    for (size_t i = 0; i < stream_num_; ++i) {
      Copy(host_results_[i]->place(), host_results_[i]->ptr(),
           results_[i]->place(), results_[i]->ptr(), workspace_size_,
           streams_[i]);
    }
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif

    for (size_t i = 0; i < stream_num_; ++i) {
      int *result = static_cast<int *>(host_results_[i]->ptr());
      for (size_t j = 0; j < data_num_; ++j) {
        EXPECT_EQ(result[j], 2);
      }
    }
  }

  // Drops all buffers, releases each stream's cached memory, destroys the
  // streams and, unless a CUDA graph was captured, checks for leaks.
  void TearDown() override {
    workspaces_.clear();
    results_.clear();
    host_results_.clear();
    for (gpuStream_t stream : streams_) {
      Release(place_, stream);
    }

    // Iterate over the streams actually created, which is robust even if
    // SetUp stopped part-way.
    for (gpuStream_t stream : streams_) {
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream));
#endif
    }

    // Memory release for CUDA Graph memory pool is forbidden
    if (!testing_cuda_graph_) {
      CheckMemLeak(place_);
    }
  }

  bool testing_cuda_graph_{false};
  size_t stream_num_;
  size_t grid_num_;
  size_t block_num_;
  size_t data_num_;
  size_t workspace_size_;
  platform::CUDAPlace place_;
  std::vector<gpuStream_t> streams_;
  std::vector<std::shared_ptr<phi::Allocation>> workspaces_;
  std::vector<std::shared_ptr<phi::Allocation>> results_;
  std::vector<std::shared_ptr<phi::Allocation>> host_results_;
};

// All streams are driven one after another from the test's own thread.
TEST_F(StreamSafeCUDAAllocTest, CUDAMutilStreamTest) {
  this->MultiStreamRun();
  this->CheckResult();
}

// Every stream is driven from a dedicated host thread.
TEST_F(StreamSafeCUDAAllocTest, CUDAMutilThreadMutilStreamTest) {
  this->MultiThreadMultiStreamRun();
  this->CheckResult();
}

#ifdef PADDLE_WITH_CUDA
// Single-thread multi-stream run with a CUDA Graph capture/replay in between.
TEST_F(StreamSafeCUDAAllocTest, CUDAGraphTest) {
  this->MultiStreamRun();
  this->CUDAGraphRun();
  this->CheckResult();
}
#endif

}  // namespace memory
}  // namespace paddle