// stream_safe_cuda_alloc_test.cu
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"

namespace paddle {
namespace memory {

// Accumulation kernel used by the fixture below. Each thread starts at its
// flat global index and walks x[0, n) in steps of gridDim.x * blockDim.x
// (a grid-stride loop), atomically adding its own flat index to every element
// it visits. One launch therefore adds (i % (gridDim.x * blockDim.x)) to x[i].
__global__ void add_kernel(int *x, int n) {
  const unsigned int stride = gridDim.x * blockDim.x;
  const int self = blockIdx.x * blockDim.x + threadIdx.x;
  int idx = self;
  while (idx < n) {
    atomicAdd(&x[idx], self);
    idx += stride;
  }
}

// Fixture that exercises the stream-safe CUDA allocator with many streams.
//
// SetUp() sets up `stream_num_` streams: streams_[0] is the null stream and
// the rest are created here. For each stream it allocates one zero-filled
// buffer of `data_num_` ints on that stream.
//
// The *Run() helpers work in three steps:
//   1. Launch add_kernel on every buffer from every stream.
//   2. Call RecordStream() for each buffer used by a stream that does not
//      own it.
//   3. Drop all buffers while kernels may still be in flight ("fast_gc").
// The stream-safe allocator is expected to defer reusing that memory until
// the recorded streams finish.
//
// The buffers are gone by the time CheckResult() runs. Their final contents
// are therefore first copied (on the device, without blocking the host) into
// `result_`, a separate allocation that lives until TearDown().
class StreamSafeCUDAAllocTest : public ::testing::Test {
 protected:
#ifdef PADDLE_WITH_CUDA
  using GpuEvent = cudaEvent_t;
#else
  using GpuEvent = hipEvent_t;
#endif

  void SetUp() override {
    place_ = platform::CUDAPlace();
    stream_num_ = 64;
    grid_num_ = 1;
    block_num_ = 64;
    data_num_ = 64;
    default_stream = nullptr;

    streams_.reserve(stream_num_);
    streams_.emplace_back(default_stream);
    for (size_t i = 1; i < stream_num_; ++i) {
      gpuStream_t stream;
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream));
#endif
      streams_.emplace_back(stream);
    }

    const size_t allocation_size = data_num_ * sizeof(int);
    allocations_.reserve(stream_num_);
    for (size_t i = 0; i < stream_num_; ++i) {
      std::shared_ptr<Allocation> allocation =
          AllocShared(place_, allocation_size, streams_[i]);
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemset(allocation->ptr(), 0, allocation->size()));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemset(allocation->ptr(), 0, allocation->size()));
#endif
      allocations_.emplace_back(allocation);
    }

    // Device-side snapshot of every buffer (one data_num_-int slice per
    // buffer). It outlives the fast_gc clear so CheckResult() only ever
    // reads memory that is still owned by this fixture.
    result_ =
        AllocShared(place_, stream_num_ * allocation_size, default_stream);
  }

  // Launches add_kernel on every buffer from streams_[idx] and records that
  // stream on each buffer it does not own.
  void SingleStreamRun(size_t idx) {
    for (size_t i = 0; i < stream_num_; ++i) {
      int *x = reinterpret_cast<int *>(allocations_[i]->ptr());
      add_kernel<<<grid_num_, block_num_, 0, streams_[idx]>>>(x, data_num_);
      // Surface bad launch configurations immediately.
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
#endif
      if (i != idx) {
        RecordStream(allocations_[i].get(), streams_[idx]);
      }
    }
  }

  // Runs SingleStreamRun for every stream from the calling thread, then
  // snapshots the buffers and drops them.
  void MultiStreamRun() {
    for (size_t i = 0; i < stream_num_; ++i) {
      SingleStreamRun(i);
    }
    CollectResults();
    allocations_.clear();  // fast_gc
  }

  // Runs SingleStreamRun for every stream, each on its own std::thread, then
  // snapshots the buffers and drops them.
  void MultiThreadMUltiStreamRun() {
    std::vector<std::thread> threads;
    threads.reserve(stream_num_);
    for (size_t i = 0; i < stream_num_; ++i) {
      threads.emplace_back(&StreamSafeCUDAAllocTest::SingleStreamRun, this, i);
    }
    for (std::thread &thread : threads) {
      thread.join();
    }
    CollectResults();
    allocations_.clear();  // fast_gc
  }

  // Enqueues a device-to-device copy of every buffer into result_.
  //
  // Copy i runs on streams_[i], the stream that owns allocations_[i]. Before
  // it, streams_[i] waits on an event recorded on every other stream, so the
  // copy observes every kernel launched so far. The host is not blocked, so
  // the subsequent fast_gc still frees buffers with work pending.
  void CollectResults() {
    std::vector<GpuEvent> events(stream_num_);
    for (size_t i = 0; i < stream_num_; ++i) {
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaEventCreateWithFlags(&events[i], cudaEventDisableTiming));
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(events[i], streams_[i]));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipEventCreateWithFlags(&events[i], hipEventDisableTiming));
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(events[i], streams_[i]));
#endif
    }

    const size_t bytes = data_num_ * sizeof(int);
    int *result = reinterpret_cast<int *>(result_->ptr());
    for (size_t i = 0; i < stream_num_; ++i) {
      for (size_t j = 0; j < stream_num_; ++j) {
        if (j == i) {
          continue;  // work on streams_[i] is already ordered before the copy
        }
#ifdef PADDLE_WITH_CUDA
        PADDLE_ENFORCE_GPU_SUCCESS(
            cudaStreamWaitEvent(streams_[i], events[j], 0));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(
            hipStreamWaitEvent(streams_[i], events[j], 0));
#endif
      }
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
          result + i * data_num_, allocations_[i]->ptr(), bytes,
          cudaMemcpyDeviceToDevice, streams_[i]));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(
          result + i * data_num_, allocations_[i]->ptr(), bytes,
          hipMemcpyDeviceToDevice, streams_[i]));
#endif
    }

    // The waits above are already enqueued. The runtime releases a
    // still-pending event's resources once the device completes it.
    for (GpuEvent event : events) {
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(event));
#endif
    }
  }

  // Copies the snapshot in result_ back to the host and checks each element.
  // Element j of every buffer must equal (j % thread_num) * stream_num_:
  // each buffer received one add_kernel launch per stream, and each launch
  // adds j % thread_num to element j.
  void CheckResult() {
    const size_t total = stream_num_ * data_num_;
    std::vector<int> host_result(total);
    // The snapshot copies were issued on several streams; wait for all of
    // them before reading result_.
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(host_result.data(), result_->ptr(),
                                          total * sizeof(int),
                                          cudaMemcpyDeviceToHost));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
    PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(host_result.data(), result_->ptr(),
                                         total * sizeof(int),
                                         hipMemcpyDeviceToHost));
#endif
    const size_t thread_num = grid_num_ * block_num_;
    for (size_t i = 0; i < stream_num_; ++i) {
      for (size_t j = 0; j < data_num_; ++j) {
        EXPECT_EQ(host_result[i * data_num_ + j],
                  static_cast<int>((j % thread_num) * stream_num_))
            << "allocation " << i << ", element " << j;
      }
    }
  }

  void TearDown() override {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
    // Drop everything this fixture still holds before releasing cached
    // memory. Whatever RecordedGpuMallocSize() reports afterwards is then
    // not owned by the fixture.
    allocations_.clear();
    result_.reset();

    for (gpuStream_t stream : streams_) {
      Release(place_, stream);
    }

    // streams_[0] is the null stream and was not created by SetUp().
    for (size_t i = 1; i < stream_num_; ++i) {
#ifdef PADDLE_WITH_CUDA
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(streams_[i]));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(streams_[i]));
#endif
    }

    uint64_t cuda_malloc_size =
        platform::RecordedGpuMallocSize(place_.GetDeviceId());
    ASSERT_EQ(cuda_malloc_size, 0u) << "Found " << cuda_malloc_size
                                    << " bytes memory that not released yet,"
                                    << " there may be a memory leak problem";
  }

  size_t stream_num_;
  size_t grid_num_;
  size_t block_num_;
  size_t data_num_;
  platform::CUDAPlace place_;
  gpuStream_t default_stream;
  std::vector<gpuStream_t> streams_;
  std::vector<std::shared_ptr<Allocation>> allocations_;
  // Device snapshot of all buffers, filled by CollectResults().
  std::shared_ptr<Allocation> result_;
};

// Launches every stream's kernels from the calling thread (MultiStreamRun)
// and then verifies the accumulated buffer contents (CheckResult).
TEST_F(StreamSafeCUDAAllocTest, CUDAMutilStreamTest) {
  this->MultiStreamRun();
  this->CheckResult();
}

// Launches each stream's kernels from its own std::thread
// (MultiThreadMUltiStreamRun) and then verifies the accumulated buffer
// contents (CheckResult).
TEST_F(StreamSafeCUDAAllocTest, CUDAMutilThreadMutilStreamTest) {
  this->MultiThreadMUltiStreamRun();
  this->CheckResult();
}

// Exercises both allocation entry points.
//
// First it allocates through AllocShared() without a stream and drops the
// only reference. It then allocates the same size through Alloc() on the
// null stream and expects the returned buffer to start at the address the
// first one had.
TEST(StreamSafeCUDAAllocInterfaceTest, AllocInterfaceTest) {
  auto gpu_place = platform::CUDAPlace();
  const size_t requested = 256;

  std::shared_ptr<Allocation> implicit_alloc =
      AllocShared(gpu_place, requested);
  EXPECT_GE(implicit_alloc->size(), requested);
  void *const first_address = implicit_alloc->ptr();
  implicit_alloc.reset();

  gpuStream_t null_stream = nullptr;
  allocation::AllocationPtr unique_alloc =
      Alloc(gpu_place, requested, null_stream);
  EXPECT_GE(unique_alloc->size(), requested);
  EXPECT_EQ(unique_alloc->ptr(), first_address);
}

// Allocation under memory pressure across two streams.
//
// allocation1 takes about 3/4 of the allocatable GPU memory on stream1. A
// second thread waits one second and then requests the same size on stream2.
// That request cannot fit while allocation1 is held (2 * alloc_size exceeds
// the available memory). Meanwhile the main thread frees allocation1, and the
// test expects the second request to succeed.
TEST(StreamSafeCUDAAllocRetryTest, RetryTest) {
  platform::CUDAPlace place = platform::CUDAPlace();
  gpuStream_t stream1, stream2;
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream2));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream2));
#endif
  size_t available_size = platform::GpuAvailableMemToAlloc();
  // alloc_size < available_size < 2 * alloc_size
  size_t alloc_size = available_size / 4 * 3;

  std::shared_ptr<Allocation> allocation1 =
      AllocShared(place, alloc_size, stream1);
  std::shared_ptr<Allocation> allocation2;

  std::thread th([&allocation2, &place, &stream2, alloc_size]() {
    std::this_thread::sleep_for(std::chrono::seconds(1));
    allocation2 = AllocShared(place, alloc_size, stream2);
  });
  allocation1.reset();  // free but not release
  th.join();
  EXPECT_GE(allocation2->size(), alloc_size);
  allocation2.reset();

#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif

  Release(place, stream1);
  Release(place, stream2);

  // The streams were created by this test, so destroy them here; previously
  // they were leaked.
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream2));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream1));
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream2));
#endif
}

}  // namespace memory
}  // namespace paddle