/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14
#define GLOG_NO_ABBREVIATED_SEVERITIES
15

Y
Yi Wang 已提交
16
#include "paddle/fluid/memory/detail/system_allocator.h"
17

D
dzhwinter 已提交
18 19 20 21
#ifdef _WIN32
#include <malloc.h>
#include <windows.h>  // VirtualLock/VirtualUnlock
#else
22
#include <sys/mman.h>  // for mlock and munlock
D
dzhwinter 已提交
23 24 25
#endif
#include <stdlib.h>   // for malloc and free
#include <algorithm>  // for std::max
26 27

#include "gflags/gflags.h"
Y
Yi Wang 已提交
28 29 30 31
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
32

S
sneaxiy 已提交
33
DECLARE_bool(use_pinned_memory);
34
DECLARE_double(fraction_of_gpu_memory_to_use);
Z
zhhsplendid 已提交
35 36 37
DECLARE_uint64(gpu_init_memory_in_mb);
DECLARE_uint64(gpu_reallocate_memory_in_mb);

38 39 40 41
namespace paddle {
namespace memory {
namespace detail {

D
dzhwinter 已提交
42
void* AlignedMalloc(size_t size) {
G
gongweibao 已提交
43
  void* p = nullptr;
D
dzhwinter 已提交
44
  size_t alignment = 32ul;
T
tensor-tang 已提交
45
#ifdef PADDLE_WITH_MKLDNN
46 47
  // refer to https://github.com/01org/mkl-dnn/blob/master/include/mkldnn.hpp
  // memory alignment
D
dzhwinter 已提交
48 49 50 51
  alignment = 4096ul;
#endif
#ifdef _WIN32
  p = _aligned_malloc(size, alignment);
52
#else
D
dzhwinter 已提交
53
  PADDLE_ENFORCE_EQ(posix_memalign(&p, alignment, size), 0, "Alloc %ld error!",
G
gongweibao 已提交
54
                    size);
55 56
#endif
  PADDLE_ENFORCE(p, "Fail to allocate CPU memory: size = %d .", size);
D
dzhwinter 已提交
57 58 59 60 61 62 63 64 65 66 67 68
  return p;
}

void* CPUAllocator::Alloc(size_t* index, size_t size) {
  // According to http://www.cplusplus.com/reference/cstdlib/malloc/,
  // malloc might not return nullptr if size is zero, but the returned
  // pointer shall not be dereferenced -- so we make it nullptr.
  if (size <= 0) return nullptr;

  *index = 0;  // unlock memory

  void* p = AlignedMalloc(size);
69 70 71

  if (p != nullptr) {
    if (FLAGS_use_pinned_memory) {
Y
Yi Wang 已提交
72
      *index = 1;
D
dzhwinter 已提交
73 74 75
#ifdef _WIN32
      VirtualLock(p, size);
#else
76
      mlock(p, size);  // lock memory
D
dzhwinter 已提交
77
#endif
78
    }
79
  }
80

81 82 83
  return p;
}

L
liaogang 已提交
84
// Releases host memory obtained from CPUAllocator::Alloc.
//
// `index == 1` marks the pinned-memory path of Alloc. Those pages are
// unlocked (munlock / VirtualUnlock) before the buffer is handed back to
// free() or _aligned_free(), matching how AlignedMalloc obtained it.
void CPUAllocator::Free(void* p, size_t size, size_t index) {
  const bool needs_unlock = (index == 1) && (p != nullptr);

#ifdef _WIN32
  if (needs_unlock) VirtualUnlock(p, size);
  _aligned_free(p);
#else
  if (needs_unlock) munlock(p, size);
  free(p);
#endif
}

L
liaogang 已提交
99
// CPUAllocator hands out host memory only, so it never reports GPU placement.
bool CPUAllocator::UseGpu() const {
  const bool kOnGpu = false;
  return kOnGpu;
}
L
liaogang 已提交
100

101
#ifdef PADDLE_WITH_CUDA
102

Y
Yi Wang 已提交
103
void* GPUAllocator::Alloc(size_t* index, size_t size) {
104 105
  // CUDA documentation doesn't explain if cudaMalloc returns nullptr
  // if size is 0.  We just make sure it does.
L
liaogang 已提交
106
  if (size <= 0) return nullptr;
107
  void* p;
Y
Yu Yang 已提交
108 109 110 111 112 113
  int prev_id;
  cudaGetDevice(&prev_id);
  if (prev_id != gpu_id_) {
    cudaSetDevice(gpu_id_);
  }

114
  cudaError_t result = cudaMalloc(&p, size);
Y
Yu Yang 已提交
115 116 117 118 119

  if (prev_id != gpu_id_) {
    cudaSetDevice(prev_id);
  }

L
liaogang 已提交
120
  if (result == cudaSuccess) {
Y
Yi Wang 已提交
121
    *index = 0;
122
    gpu_alloc_size_ += size;
L
liaogang 已提交
123
    return p;
124
  } else {
Z
zhhsplendid 已提交
125 126 127 128 129 130 131 132 133 134 135 136
    LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
                 << " MB GPU memory. Please shrink "
                    "FLAGS_fraction_of_gpu_memory_to_use or "
                    "FLAGS_gpu_init_memory_in_mb or "
                    "FLAGS_gpu_reallocate_memory_in_mb"
                    "environment variable to a lower value. "
                 << "Current FLAGS_fraction_of_gpu_memory_to_use value is "
                 << FLAGS_fraction_of_gpu_memory_to_use
                 << ". Current FLAGS_gpu_init_memory_in_mb value is "
                 << FLAGS_gpu_init_memory_in_mb
                 << ". Current FLAGS_gpu_reallocate_memory_in_mb value is "
                 << FLAGS_gpu_reallocate_memory_in_mb;
137
    return nullptr;
L
liaogang 已提交
138
  }
139 140
}

L
liaogang 已提交
141
void GPUAllocator::Free(void* p, size_t size, size_t index) {
142 143 144 145 146 147 148 149 150 151 152 153
  cudaError_t err;

  if (index == 0) {
    PADDLE_ASSERT(gpu_alloc_size_ >= size);
    gpu_alloc_size_ -= size;
    err = cudaFree(p);
  } else {
    PADDLE_ASSERT(fallback_alloc_size_ >= size);
    fallback_alloc_size_ -= size;
    err = cudaFreeHost(p);
  }

154 155 156 157 158 159
  // Purposefully allow cudaErrorCudartUnloading, because
  // that is returned if you ever call cudaFree after the
  // driver has already shutdown. This happens only if the
  // process is terminating, in which case we don't care if
  // cudaFree succeeds.
  if (err != cudaErrorCudartUnloading) {
L
liaogang 已提交
160
    PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
161 162 163
  }
}

L
liaogang 已提交
164
// GPUAllocator's primary allocations (index 0) come from cudaMalloc, so it
// reports GPU placement.
bool GPUAllocator::UseGpu() const {
  const bool kOnGpu = true;
  return kOnGpu;
}
L
liaogang 已提交
165

C
chengduoZH 已提交
166 167
// PINNED memory allows direct DMA transfers by the GPU to and from system
// memory. It’s locked to a physical address.
Y
Yi Wang 已提交
168
void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) {
C
chengduoZH 已提交
169
  if (size <= 0) return nullptr;
C
chengduoZH 已提交
170

171
  // NOTE: here, we use CUDAPinnedMaxAllocSize as the maximum memory size
C
chengduoZH 已提交
172
  // of host pinned allocation. Allocates too much would reduce
C
chengduoZH 已提交
173
  // the amount of memory available to the underlying system for paging.
C
chengduoZH 已提交
174
  size_t usable =
175
      paddle::platform::CUDAPinnedMaxAllocSize() - cuda_pinnd_alloc_size_;
C
chengduoZH 已提交
176

C
chengduoZH 已提交
177 178 179 180 181 182
  if (size > usable) {
    LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
                 << " MB pinned memory."
                 << ", available " << usable / 1024.0 / 1024.0 << " MB";
    return nullptr;
  }
C
chengduoZH 已提交
183

C
chengduoZH 已提交
184
  void* p;
C
chengduoZH 已提交
185
  // PINNED memory is visible to all CUDA contexts.
D
Dun Liang 已提交
186
  cudaError_t result = cudaHostAlloc(&p, size, cudaHostAllocPortable);
C
chengduoZH 已提交
187

C
chengduoZH 已提交
188
  if (result == cudaSuccess) {
Y
Yi Wang 已提交
189
    *index = 1;  // PINNED memory
C
chengduoZH 已提交
190
    cuda_pinnd_alloc_size_ += size;
C
chengduoZH 已提交
191
    return p;
C
chengduoZH 已提交
192
  } else {
D
Dun Liang 已提交
193
    LOG(WARNING) << "cudaHostAlloc failed.";
C
chengduoZH 已提交
194
    return nullptr;
C
chengduoZH 已提交
195 196 197 198 199 200 201 202 203
  }

  return nullptr;
}

void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
  cudaError_t err;
  PADDLE_ASSERT(index == 1);

C
chengduoZH 已提交
204 205
  PADDLE_ASSERT(cuda_pinnd_alloc_size_ >= size);
  cuda_pinnd_alloc_size_ -= size;
C
chengduoZH 已提交
206 207 208
  err = cudaFreeHost(p);

  // Purposefully allow cudaErrorCudartUnloading, because
C
chengduoZH 已提交
209
  // that is returned if you ever call cudaFreeHost after the
C
chengduoZH 已提交
210 211
  // driver has already shutdown. This happens only if the
  // process is terminating, in which case we don't care if
C
chengduoZH 已提交
212
  // cudaFreeHost succeeds.
C
chengduoZH 已提交
213 214 215 216 217
  if (err != cudaErrorCudartUnloading) {
    PADDLE_ENFORCE(err, "cudaFreeHost failed in GPUPinnedAllocator::Free.");
  }
}

C
chengduoZH 已提交
218
// Pinned buffers come from cudaHostAlloc (host memory), so this allocator
// does not report GPU placement.
bool CUDAPinnedAllocator::UseGpu() const {
  const bool kOnGpu = false;
  return kOnGpu;
}
C
chengduoZH 已提交
219

L
Luo Tao 已提交
220
#endif
221 222 223 224

}  // namespace detail
}  // namespace memory
}  // namespace paddle