pinned_memory_test.cu 6.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
C
chengduoZH 已提交
14 15
#include <cstdint>
#include <unordered_map>

#include <gtest/gtest.h>
16 17 18 19 20 21 22 23 24

#include "paddle/fluid/memory/detail/memory_block.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"

#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"

C
chengduoZH 已提交
25 26
// This unit test compares the performance of a copy/compute pipeline that
// stages data in pinned host memory against one that uses pageable host
// memory. In general, the pinned-memory version should be faster.
27 28 29 30 31 32 33 34 35
// Replaces every element of `output` with output[i] * output[i] / 100, in
// place.
//
// Uses a grid-stride loop, so any 1-D launch configuration processes all
// `dim` elements; the caller's grid does not have to cover `dim` exactly.
// The index is 64-bit so `i += stride` cannot overflow near INT_MAX.
//
// Params:
//   output - device pointer to at least `dim` elements of T.
//   dim    - number of elements to process (nothing is done if <= 0).
template <typename T>
__global__ void Kernel(T* output, int dim) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < dim; i += stride) {
    const T value = output[i];
    output[i] = value * value / 100;
  }
}

// Runs a two-stream copy/compute pipeline and returns the average time of one
// round.
//
// In each of `repeat` rounds, for every one of `iteration` buffers:
//   1. copy a host buffer (allocated in `Place`) to the GPU on
//      `computation_stream`,
//   2. launch `Kernel` on that stream (x -> x * x / 100, in place),
//   3. record an event and make `copying_stream` wait on it,
//   4. copy the result back to a second host buffer on `copying_stream`.
// The final host results are then checked against a host-side reference.
//
// Template parameter:
//   Place - host place used for the staging buffers; the test below
//           instantiates it with CPUPlace and CUDAPinnedPlace.
// Returns:
//   event-measured time from `start_e` to `stop_e` divided by the number of
//   rounds (milliseconds, as reported by *EventElapsedTime).
template <typename Place>
float test_pinned_memory() {
  Place cpu_place;
  paddle::platform::CUDAPlace cuda_place;

  const int data_size = 4096;
  const int iteration = 10;
  // Number of times the whole pipeline is repeated; the returned time is
  // averaged over these rounds.
  const int repeat = 30;
  const size_t bytes = data_size * sizeof(float);

  // Launch configuration derived from data_size so every element is covered.
  const int threads = 1024;
  const int blocks = (data_size + threads - 1) / threads;

  // create event start and end
  gpuEvent_t start_e, stop_e, copying_e;
  float elapsedTime = 0;

#ifdef PADDLE_WITH_HIP
  hipEventCreate(&start_e);
  hipEventCreate(&stop_e);
  hipEventCreate(&copying_e);
#else
  cudaEventCreate(&start_e);
  cudaEventCreate(&stop_e);
  cudaEventCreate(&copying_e);
#endif

  // create computation stream, data copying stream
  gpuStream_t computation_stream, copying_stream;
#ifdef PADDLE_WITH_HIP
  hipStreamCreate(&computation_stream);
  hipStreamCreate(&copying_stream);
#else
  cudaStreamCreate(&computation_stream);
  cudaStreamCreate(&copying_stream);
#endif

  // create record event, pinned memory, gpu memory
  std::vector<gpuEvent_t> record_event(iteration);
  std::vector<float*> input_pinned_mem(iteration);
  std::vector<float*> gpu_mem(iteration);
  std::vector<float*> output_pinned_mem(iteration);

  // initial data
  for (int j = 0; j < iteration; ++j) {
    // The record events are only used for cross-stream ordering (never for
    // elapsed time), so each is created exactly once, without timing.
    // Creating it a second time would leak the first handle.
#ifdef PADDLE_WITH_HIP
    hipEventCreateWithFlags(&record_event[j], hipEventDisableTiming);
#else
    cudaEventCreateWithFlags(&record_event[j], cudaEventDisableTiming);
#endif
    input_pinned_mem[j] =
        static_cast<float*>(paddle::memory::Alloc(cpu_place, bytes));
    output_pinned_mem[j] =
        static_cast<float*>(paddle::memory::Alloc(cpu_place, bytes));
    gpu_mem[j] = static_cast<float*>(paddle::memory::Alloc(cuda_place, bytes));

    for (int k = 0; k < data_size; ++k) {
      input_pinned_mem[j][k] = k;
    }
  }

#ifdef PADDLE_WITH_HIP
  hipEventRecord(start_e, computation_stream);
#else
  cudaEventRecord(start_e, computation_stream);
#endif

  // computation
  // NOTE(review): gpu_mem[i] is reused across rounds, and computation_stream
  // does not wait for the previous round's GPU->CPU copy on copying_stream.
  // The next round's CPU->GPU copy may therefore overlap it. The final
  // round's GPU->CPU copy is ordered after the last kernel (via
  // record_event[i]), and copying_stream is in-order, so the values checked
  // below are unaffected.
  for (int m = 0; m < repeat; ++m) {
    for (int i = 0; i < iteration; ++i) {
      // cpu -> GPU on computation stream.
      // note: this operation is async for pinned memory.
      paddle::memory::Copy(cuda_place, gpu_mem[i], cpu_place,
                           input_pinned_mem[i], bytes, computation_stream);

      // call kernel on computation stream.
      Kernel<<<blocks, threads, 0, computation_stream>>>(gpu_mem[i],
                                                         data_size);

#ifdef PADDLE_WITH_HIP
      // record event_computation on computation stream
      hipEventRecord(record_event[i], computation_stream);

      // wait event_computation on copy stream.
      // note: this operation is async.
      hipStreamWaitEvent(copying_stream, record_event[i], 0);
#else
      // record event_computation on computation stream
      cudaEventRecord(record_event[i], computation_stream);

      // wait event_computation on copy stream.
      // note: this operation is async.
      cudaStreamWaitEvent(copying_stream, record_event[i], 0);
#endif
      // copy data GPU->CPU, on copy stream.
      // note: this operation is async for pinned memory.
      paddle::memory::Copy(cpu_place, output_pinned_mem[i], cuda_place,
                           gpu_mem[i], bytes, copying_stream);
    }
  }

  // Join copying_stream back into computation_stream so that stop_e
  // completes only after every copy has finished.
#ifdef PADDLE_WITH_HIP
  hipEventRecord(copying_e, copying_stream);
  hipStreamWaitEvent(computation_stream, copying_e, 0);

  hipEventRecord(stop_e, computation_stream);

  hipEventSynchronize(start_e);
  hipEventSynchronize(stop_e);
  hipEventElapsedTime(&elapsedTime, start_e, stop_e);
  // Surface kernel-launch errors here instead of only as a confusing data
  // mismatch below.
  EXPECT_EQ(hipSuccess, hipGetLastError());
#else
  cudaEventRecord(copying_e, copying_stream);
  cudaStreamWaitEvent(computation_stream, copying_e, 0);

  cudaEventRecord(stop_e, computation_stream);

  cudaEventSynchronize(start_e);
  cudaEventSynchronize(stop_e);
  cudaEventElapsedTime(&elapsedTime, start_e, stop_e);
  // Surface kernel-launch errors here instead of only as a confusing data
  // mismatch below.
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
#endif

  // Verify against a host-side reference of the same computation.
  for (int l = 0; l < iteration; ++l) {
    for (int k = 0; k < data_size; ++k) {
      float temp = input_pinned_mem[l][k];
      temp = temp * temp / 100;
      EXPECT_FLOAT_EQ(temp, output_pinned_mem[l][k]);
    }
  }

  // destroy resource
#ifdef PADDLE_WITH_HIP
  hipEventDestroy(copying_e);
  hipEventDestroy(start_e);
  hipEventDestroy(stop_e);
#else
  cudaEventDestroy(copying_e);
  cudaEventDestroy(start_e);
  cudaEventDestroy(stop_e);
#endif
  for (int j = 0; j < iteration; ++j) {
#ifdef PADDLE_WITH_HIP
    hipEventDestroy(record_event[j]);
#else
    cudaEventDestroy(record_event[j]);
#endif
    paddle::memory::Free(cpu_place, input_pinned_mem[j]);
    paddle::memory::Free(cpu_place, output_pinned_mem[j]);
    paddle::memory::Free(cuda_place, gpu_mem[j]);
  }

  // All work on both streams completed before stop_e, so they can be
  // released now.
#ifdef PADDLE_WITH_HIP
  hipStreamDestroy(computation_stream);
  hipStreamDestroy(copying_stream);
#else
  cudaStreamDestroy(computation_stream);
  cudaStreamDestroy(copying_stream);
#endif

  return elapsedTime / repeat;
}

C
chengduoZH 已提交
189 190 191 192 193 194
// Runs the same copy/compute pipeline with pageable (CPUPlace) and pinned
// (CUDAPinnedPlace) host staging buffers and expects pinned to be faster.
TEST(CPUANDCUDAPinned, CPUAllocatorAndCUDAPinnedAllocator) {
  // Generally speaking, operations on pinned memory are faster than those on
  // pageable (unpinned) memory, but if this unit test fails frequently,
  // please disable it for the time being.
  float pageable_ms = test_pinned_memory<paddle::platform::CPUPlace>();
  float pinned_ms = test_pinned_memory<paddle::platform::CUDAPinnedPlace>();

  // Timing comparisons are noisy; report both values on failure.
  EXPECT_GT(pageable_ms, pinned_ms)
      << "pageable: " << pageable_ms << " ms/round, pinned: " << pinned_ms
      << " ms/round";
}