malloc_test.cu 6.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

#include <functional>
#include <memory>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace memory {

// Number of device contexts (one stream each) exercised by the multi-stream
// tests.
constexpr int NUM_STREAMS = 8;
// Number of floats in every buffer that is allocated and filled by `kernel`.
constexpr int N = 2;
// Absolute tolerance used when comparing kernel output to its expected value.
constexpr float DELTA = 0.1f;

// Owning list of device contexts; keeps each context alive for as long as the
// test needs its stream.
typedef std::vector<std::unique_ptr<platform::CUDADeviceContext>> CudaDevCtxVec;

// Writes x[i] = 3.14159f * i for every i in [0, n).
//
// Grid-stride loop: each thread starts at its global index and advances by the
// total number of launched threads, so any launch configuration covers all n
// elements. The constant is a float literal so the multiply is done in single
// precision (a bare `3.14159` would promote the arithmetic to double) and the
// stored value equals the host-side expectation `3.14159f * i`.
__global__ void kernel(float *x, int n) {
  const int stride = static_cast<int>(blockDim.x * gridDim.x);
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += stride) {
    x[i] = 3.14159f * static_cast<float>(i);
  }
}

// Copies the first `n` floats of device buffer `x` to the host and checks that
// element i equals 3.14159f * i to within DELTA.
//
// The buffer is copied exactly once, up front, rather than once per element.
// If the copy fails, the function stops (ASSERT_*) instead of comparing data
// the device never produced. Non-positive `n` means there is nothing to check.
void CheckKernelOutput(float *x, int n) {
  if (n <= 0) {
    return;
  }
  std::vector<float> host_x(static_cast<size_t>(n));
  const size_t bytes = host_x.size() * sizeof(float);
#ifdef PADDLE_WITH_HIP
  ASSERT_EQ(hipSuccess,
            hipMemcpy(host_x.data(), x, bytes, hipMemcpyDeviceToHost));
#else
  ASSERT_EQ(cudaSuccess,
            cudaMemcpy(host_x.data(), x, bytes, cudaMemcpyDeviceToHost));
#endif
  for (int i = 0; i < n; ++i) {
    // Same condition as host_x[i] + DELTA >= exp && host_x[i] - DELTA <= exp.
    EXPECT_NEAR(host_x[i], 3.14159f * i, DELTA);
  }
}

// Allocates two N-float device buffers through `ctx` and fills each of them by
// launching `kernel` on ctx.stream(). The raw device pointers are handed back
// through *data and *second_data so the caller can verify them after
// synchronizing the device.
//
// NOTE(review): `allocation_ptr` is reassigned before the second Alloc and is
// destroyed on return. Both buffers therefore go back to the allocator while
// the kernels may still be running, and before the caller reads them. This
// looks intentional (it exercises reuse on the same stream). Reading the
// pointers afterwards presumably relies on the allocator caching freed blocks
// rather than returning them to the driver — confirm.
void MultiStreamCompute(float **data, float **second_data,
                        const platform::CUDADeviceContext &ctx) {
  // First allocation and launch on this context's stream.
  AllocationPtr allocation_ptr = Alloc(ctx, N * sizeof(float));
  EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
  *data = reinterpret_cast<float *>(allocation_ptr->ptr());
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(), *data, N);
  // A kernel launch returns no status; query the last-error state instead.
  EXPECT_EQ(hipSuccess, hipGetLastError());
#else
  kernel<<<1, 64, 0, ctx.stream()>>>(*data, N);
  // A kernel launch returns no status; query the last-error state instead.
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
#endif

  // Allocate and compute on the same stream again. The assignment replaces
  // (and thereby releases) the first allocation.
  allocation_ptr = Alloc(ctx, N * sizeof(float));
  EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
  *second_data = reinterpret_cast<float *>(allocation_ptr->ptr());
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(), *second_data,
                     N);
  EXPECT_EQ(hipSuccess, hipGetLastError());
#else
  kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N);
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
#endif
}

// Builds a CUDADeviceContext for `place`. Its device allocator is obtained from
// the AllocatorFacade for the context's own stream; its host and zero
// allocators also come from the facade. The context is then initialized with
// those allocators.
static std::unique_ptr<platform::CUDADeviceContext> CreateStreamBoundContext(
    const platform::CUDAPlace &place) {
  auto ctx = std::unique_ptr<platform::CUDADeviceContext>(
      new platform::CUDADeviceContext(place));
  auto &facade = paddle::memory::allocation::AllocatorFacade::Instance();
  ctx->SetAllocator(facade.GetAllocator(place, ctx->stream()).get());
  ctx->SetHostAllocator(
      facade.GetAllocator(paddle::platform::CPUPlace()).get());
  ctx->SetZeroAllocator(facade.GetZeroAllocator(place).get());
  ctx->PartialInitWithAllocator();
  return ctx;
}

// Allocates an N-float buffer with Alloc(place, ...) and fills it with `kernel`
// on the default stream. The allocation is released immediately, without
// waiting for the kernel to finish.
static void ComputeOnDefaultStreamAndRelease(const platform::CUDAPlace &place) {
  AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
  EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
  float *main_stream_data =
      reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
  // A kernel launch returns no status; query the last-error state instead.
  EXPECT_EQ(hipSuccess, hipGetLastError());
#else
  kernel<<<1, 64>>>(main_stream_data, N);
  // A kernel launch returns no status; query the last-error state instead.
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
#endif
  // Released while the kernel may still be in flight.
  main_stream_alloc_ptr.reset();
}

// Waits for all outstanding device work. Then checks both buffers that
// MultiStreamCompute produced for each of the NUM_STREAMS contexts.
static void SynchronizeAndCheck(float *const *data,
                                float *const *second_data) {
#ifdef PADDLE_WITH_HIP
  EXPECT_EQ(hipSuccess, hipDeviceSynchronize());
#else
  EXPECT_EQ(cudaSuccess, cudaDeviceSynchronize());
#endif
  for (int i = 0; i < NUM_STREAMS; ++i) {
    CheckKernelOutput(data[i], N);
    CheckKernelOutput(second_data[i], N);
  }
}

// Runs MultiStreamCompute sequentially from the test thread on NUM_STREAMS
// independent contexts, after some default-stream work.
TEST(Malloc, CUDADeviceContextMultiStream) {
  auto place = platform::CUDAPlace(0);
  platform::SetDeviceId(0);

  ComputeOnDefaultStreamAndRelease(place);

  float *data[NUM_STREAMS];
  float *second_data[NUM_STREAMS];
  CudaDevCtxVec dev_ctx;
  dev_ctx.reserve(NUM_STREAMS);
  for (int i = 0; i < NUM_STREAMS; ++i) {
    dev_ctx.emplace_back(CreateStreamBoundContext(place));
    MultiStreamCompute(&data[i], &second_data[i], *dev_ctx[i]);
  }

  SynchronizeAndCheck(data, second_data);
}

// Same as above, but each context is driven from its own host thread.
TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
  auto place = platform::CUDAPlace(0);
  platform::SetDeviceId(0);

  ComputeOnDefaultStreamAndRelease(place);

  float *data[NUM_STREAMS];
  float *second_data[NUM_STREAMS];
  CudaDevCtxVec dev_ctx;
  dev_ctx.reserve(NUM_STREAMS);
  std::vector<std::thread> threads;
  threads.reserve(NUM_STREAMS);
  for (int i = 0; i < NUM_STREAMS; ++i) {
    dev_ctx.emplace_back(CreateStreamBoundContext(place));
    // Each worker writes only its own data[i] / second_data[i] slots. The
    // context object is heap-owned, so the reference stays valid while
    // dev_ctx grows.
    threads.emplace_back(MultiStreamCompute, &data[i], &second_data[i],
                         std::cref(*dev_ctx[i]));
  }
  for (auto &worker : threads) {
    worker.join();
  }

  SynchronizeAndCheck(data, second_data);
}

// Requesting zero bytes on the GPU is expected to still return a non-null
// allocation holder whose size() can be queried.
TEST(Malloc, AllocZero) {
  auto place = platform::CUDAPlace(0);
  AllocationPtr allocation_ptr = Alloc(place, 0);
  // Guard the dereference below, so a null result fails the test cleanly
  // instead of crashing the test binary.
  ASSERT_TRUE(allocation_ptr != nullptr);
  // NOTE(review): size() presumably returns an unsigned type (size_t), so this
  // only checks that the accessor is callable. Consider asserting the exact
  // size once the zero-size contract of Alloc is confirmed.
  EXPECT_GE(allocation_ptr->size(), 0u);
}
}  // namespace memory
}  // namespace paddle