malloc_test.cu 7.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17 18 19 20 21 22 23
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/stream.h"

24
#ifdef PADDLE_WITH_CUDA
25 26
#include <cuda.h>
#include <cuda_runtime.h>
27 28 29 30 31 32
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

33 34 35 36 37 38 39
namespace paddle {
namespace memory {

// Number of GPU contexts (each with its own stream) created by the
// multi-stream tests below.
constexpr int NUM_STREAMS = 8;
// Number of floats allocated for, and written into, each test buffer.
constexpr int N = 2;
// Absolute tolerance used when comparing kernel output on the host.
constexpr float DELTA = 0.1f;

// Owning list of the GPU device contexts a test creates.
using CudaDevCtxVec = std::vector<std::unique_ptr<phi::GPUContext>>;
41 42 43 44 45 46 47 48 49 50 51

// Writes x[i] = 3.14159f * i for every i in [0, n).
//
// The loop advances by the total number of launched threads (grid-stride), so
// any 1-D launch configuration covers all n elements; the tests in this file
// launch a single block of 64 threads. The constant is a float literal so the
// arithmetic stays in single precision, matching the host-side reference
// expression in CheckKernelOutput (a bare double literal would force a double
// multiply followed by a narrowing conversion).
__global__ void kernel(float *x, int n) {
  const int stride = blockDim.x * gridDim.x;
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += stride) {
    x[i] = 3.14159f * i;
  }
}

// Copies the first `n` floats of the device buffer `x` to the host and checks
// that element i lies within DELTA of 3.14159f * i, the value `kernel` writes.
//
// The blocking device-to-host copy is issued once, before the comparison loop.
// Previously it sat inside the loop and re-copied the entire buffer for every
// element. If the copy fails, the function stops (ASSERT) instead of comparing
// uninitialized host memory. The tests in this file call it only after a
// device-wide synchronization.
void CheckKernelOutput(float *x, int n) {
  // Nothing to copy or compare; also avoids sizing a buffer from a negative n.
  if (n <= 0) {
    return;
  }
  const size_t bytes = static_cast<size_t>(n) * sizeof(float);
  std::vector<float> host_x(static_cast<size_t>(n));
#ifdef PADDLE_WITH_HIP
  ASSERT_EQ(hipSuccess,
            hipMemcpy(host_x.data(), x, bytes, hipMemcpyDeviceToHost));
#else
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(host_x.data(), x, bytes, cudaMemcpyDeviceToHost));
#endif
  for (int i = 0; i < n; ++i) {
    EXPECT_GE(host_x[i] + DELTA, 3.14159f * i);
    EXPECT_LE(host_x[i] - DELTA, 3.14159f * i);
  }
}

66 67
// Allocates two N-float buffers bound to `ctx`'s stream, one after the other,
// and launches `kernel` on that stream to fill each of them.
//
// Outputs:
//   *data        - device pointer of the first allocation.
//   *second_data - device pointer of the second allocation.
//
// Launches are asynchronous; callers must synchronize before reading the
// buffers. Launch-configuration errors are checked immediately after each
// launch via the runtime's last-error state, which is tracked per host thread,
// so this is valid when called from worker threads as well.
//
// NOTE(review): both allocations are released before this function returns
// (the first when `allocation_ptr` is reassigned, the second when it goes out
// of scope). The callers still read through the returned pointers after
// synchronizing, which presumably relies on the stream-ordered allocator
// caching freed blocks rather than returning them to the driver — confirm this
// is intended.
void MultiStreamCompute(float **data,
                        float **second_data,
                        const phi::GPUContext &ctx) {
  // First allocation on this context's stream.
  AllocationPtr allocation_ptr =
      Alloc(ctx.GetPlace(),
            N * sizeof(float),
            phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
  EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
  *data = reinterpret_cast<float *>(allocation_ptr->ptr());
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(), *data, N);
  EXPECT_EQ(hipSuccess, hipGetLastError());
#else
  kernel<<<1, 64, 0, ctx.stream()>>>(*data, N);
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
#endif

  // Allocate and compute on the same stream again. The new allocation is
  // requested while the first one is still held; assigning it to
  // `allocation_ptr` then releases the first allocation even though its kernel
  // may still be running on the stream.
  allocation_ptr =
      Alloc(ctx.GetPlace(),
            N * sizeof(float),
            phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
  EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
  *second_data = reinterpret_cast<float *>(allocation_ptr->ptr());
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL(
      (kernel), dim3(1), dim3(64), 0, ctx.stream(), *second_data, N);
  EXPECT_EQ(hipSuccess, hipGetLastError());
#else
  kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N);
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
#endif
}

L
Leo Chen 已提交
97
// Runs MultiStreamCompute on NUM_STREAMS GPU contexts from a single host
// thread. Beforehand, a kernel is launched on the default stream into a buffer
// that is released right away, while the kernel may still be running. After a
// device-wide synchronization, every buffer written on the per-context streams
// must hold the expected kernel output.
TEST(Malloc, GPUContextMultiStream) {
  auto place = platform::CUDAPlace(0);
  platform::SetDeviceId(0);

  // Builds a fully initialized GPUContext for `place`. All allocators come from
  // the global AllocatorFacade, and the device allocator is the one associated
  // with the context's own stream.
  auto make_context = [&place]() {
    auto &facade = paddle::memory::allocation::AllocatorFacade::Instance();
    auto ctx = std::unique_ptr<phi::GPUContext>(new phi::GPUContext(place));
    ctx->SetAllocator(facade.GetAllocator(place, ctx->stream()).get());
    ctx->SetHostAllocator(
        facade.GetAllocator(paddle::platform::CPUPlace()).get());
    ctx->SetZeroAllocator(facade.GetZeroAllocator(place).get());
    ctx->SetPinnedAllocator(
        facade.GetAllocator(paddle::platform::CUDAPinnedPlace()).get());
    ctx->PartialInitWithAllocator();
    return ctx;
  };

  AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
  EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
  float *main_stream_data =
      reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());

  float *data[NUM_STREAMS];
  float *second_data[NUM_STREAMS];
  CudaDevCtxVec dev_ctx;

  // Launch on the default stream and surface any launch-configuration error.
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
  EXPECT_EQ(hipSuccess, hipGetLastError());
#else
  kernel<<<1, 64>>>(main_stream_data, N);
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
#endif
  // Released while the kernel may still be writing to it.
  main_stream_alloc_ptr.reset();

  for (int i = 0; i < NUM_STREAMS; ++i) {
    dev_ctx.emplace_back(make_context());
    MultiStreamCompute(&data[i], &second_data[i], *dev_ctx.back());
  }

#ifdef PADDLE_WITH_HIP
  EXPECT_EQ(hipSuccess, hipDeviceSynchronize());
#else
  EXPECT_EQ(cudaSuccess, cudaDeviceSynchronize());
#endif
  for (int i = 0; i < NUM_STREAMS; ++i) {
    CheckKernelOutput(data[i], N);
    CheckKernelOutput(second_data[i], N);
  }
}

L
Leo Chen 已提交
151
// Same scenario as GPUContextMultiStream, but each context's
// MultiStreamCompute call runs on its own host thread. All threads are joined
// and the device is synchronized before the outputs are checked.
TEST(Malloc, GPUContextMultiThreadMultiStream) {
  auto place = platform::CUDAPlace(0);
  platform::SetDeviceId(0);

  // Builds a fully initialized GPUContext for `place`. All allocators come from
  // the global AllocatorFacade, and the device allocator is the one associated
  // with the context's own stream.
  auto make_context = [&place]() {
    auto &facade = paddle::memory::allocation::AllocatorFacade::Instance();
    auto ctx = std::unique_ptr<phi::GPUContext>(new phi::GPUContext(place));
    ctx->SetAllocator(facade.GetAllocator(place, ctx->stream()).get());
    ctx->SetHostAllocator(
        facade.GetAllocator(paddle::platform::CPUPlace()).get());
    ctx->SetZeroAllocator(facade.GetZeroAllocator(place).get());
    ctx->SetPinnedAllocator(
        facade.GetAllocator(paddle::platform::CUDAPinnedPlace()).get());
    ctx->PartialInitWithAllocator();
    return ctx;
  };

  AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
  EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
  float *main_stream_data =
      reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());

  float *data[NUM_STREAMS];
  float *second_data[NUM_STREAMS];
  CudaDevCtxVec dev_ctx;
  std::vector<std::thread> threads;
  dev_ctx.reserve(NUM_STREAMS);
  threads.reserve(NUM_STREAMS);

  // Launch on the default stream and surface any launch-configuration error.
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
  EXPECT_EQ(hipSuccess, hipGetLastError());
#else
  kernel<<<1, 64>>>(main_stream_data, N);
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
#endif
  // Released while the kernel may still be writing to it.
  main_stream_alloc_ptr.reset();

  for (int i = 0; i < NUM_STREAMS; ++i) {
    dev_ctx.emplace_back(make_context());
    // The context object lives behind a unique_ptr, so the reference handed to
    // the thread stays valid even if `dev_ctx` reallocates.
    threads.emplace_back(MultiStreamCompute,
                         &data[i],
                         &second_data[i],
                         std::cref(*dev_ctx.back()));
  }

  for (auto &worker : threads) {
    worker.join();
  }

#ifdef PADDLE_WITH_HIP
  EXPECT_EQ(hipSuccess, hipDeviceSynchronize());
#else
  EXPECT_EQ(cudaSuccess, cudaDeviceSynchronize());
#endif
  for (int i = 0; i < NUM_STREAMS; ++i) {
    CheckKernelOutput(data[i], N);
    CheckKernelOutput(second_data[i], N);
  }
}

// A zero-byte request on a CUDA place must still yield an allocation handle,
// and that handle must report a size of zero.
TEST(Malloc, AllocZero) {
  auto place = platform::CUDAPlace(0);
  AllocationPtr allocation_ptr = Alloc(place, 0);
  // Fail the test cleanly instead of crashing on a null dereference.
  ASSERT_NE(allocation_ptr.get(), nullptr);
  // size() is unsigned, so the former `EXPECT_GE(size(), 0)` could never fail.
  // A zero-byte request is presumably served by the zero-size allocator and
  // should not be rounded up.
  EXPECT_EQ(allocation_ptr->size(), 0u);
}
215 216 217 218 219 220 221

// Allocating on a default-constructed CUDAPlace with an explicit stream handle
// (id 0) must return an allocation of exactly the requested size.
TEST(Malloc, AllocWithStream) {
  size_t size = 1024;
  AllocationPtr allocation = Alloc(platform::CUDAPlace(), size, phi::Stream(0));
  ASSERT_NE(allocation.get(), nullptr);
  // Compare against `size` (size_t) rather than a repeated int literal. This
  // avoids a signed/unsigned comparison inside gtest and keeps the value in one
  // place.
  EXPECT_EQ(allocation->size(), size);
}

222 223
}  // namespace memory
}  // namespace paddle