system_allocator.cc 6.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14
#define GLOG_NO_ABBREVIATED_SEVERITIES
15

Y
Yi Wang 已提交
16
#include "paddle/fluid/memory/detail/system_allocator.h"
17

D
dzhwinter 已提交
18 19 20 21
#ifdef _WIN32
#include <malloc.h>
#include <windows.h>  // VirtualLock/VirtualUnlock
#else
22
#include <sys/mman.h>  // for mlock and munlock
D
dzhwinter 已提交
23 24 25
#endif
#include <stdlib.h>   // for malloc and free
#include <algorithm>  // for std::max
26 27

#include "gflags/gflags.h"
Y
Yi Wang 已提交
28 29 30
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
31 32 33
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
34

S
sneaxiy 已提交
35
DECLARE_bool(use_pinned_memory);
36
DECLARE_double(fraction_of_gpu_memory_to_use);
37 38
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
Z
zhhsplendid 已提交
39

40 41 42 43
namespace paddle {
namespace memory {
namespace detail {

D
dzhwinter 已提交
44
void* AlignedMalloc(size_t size) {
G
gongweibao 已提交
45
  void* p = nullptr;
D
dzhwinter 已提交
46
  size_t alignment = 32ul;
T
tensor-tang 已提交
47
#ifdef PADDLE_WITH_MKLDNN
48 49
  // refer to https://github.com/01org/mkl-dnn/blob/master/include/mkldnn.hpp
  // memory alignment
D
dzhwinter 已提交
50 51 52 53
  alignment = 4096ul;
#endif
#ifdef _WIN32
  p = _aligned_malloc(size, alignment);
54
#else
D
dzhwinter 已提交
55
  PADDLE_ENFORCE_EQ(posix_memalign(&p, alignment, size), 0, "Alloc %ld error!",
G
gongweibao 已提交
56
                    size);
57
#endif
58
  PADDLE_ENFORCE_NOT_NULL(p, "Fail to allocate CPU memory: size = %d .", size);
D
dzhwinter 已提交
59 60 61 62 63 64 65 66 67 68 69 70
  return p;
}

void* CPUAllocator::Alloc(size_t* index, size_t size) {
  // According to http://www.cplusplus.com/reference/cstdlib/malloc/,
  // malloc might not return nullptr if size is zero, but the returned
  // pointer shall not be dereferenced -- so we make it nullptr.
  if (size <= 0) return nullptr;

  *index = 0;  // unlock memory

  void* p = AlignedMalloc(size);
71 72 73

  if (p != nullptr) {
    if (FLAGS_use_pinned_memory) {
Y
Yi Wang 已提交
74
      *index = 1;
D
dzhwinter 已提交
75 76 77
#ifdef _WIN32
      VirtualLock(p, size);
#else
78
      mlock(p, size);  // lock memory
D
dzhwinter 已提交
79
#endif
80
    }
81
  }
82

83 84 85
  return p;
}

L
liaogang 已提交
86
// Releases a block obtained from CPUAllocator::Alloc. Blocks tagged with
// index 1 were page-locked by Alloc and are unlocked before being freed.
void CPUAllocator::Free(void* p, size_t size, size_t index) {
  const bool was_locked = (index == 1);
  if (was_locked && p != nullptr) {
#ifndef _WIN32
    munlock(p, size);
#else
    VirtualUnlock(p, size);
#endif
  }

  // Must match the allocation routine used by AlignedMalloc.
#ifndef _WIN32
  free(p);
#else
  _aligned_free(p);
#endif
}

L
liaogang 已提交
101
bool CPUAllocator::UseGpu() const {
  // Blocks come from AlignedMalloc, i.e. plain host memory.
  return false;
}
L
liaogang 已提交
102

103
#ifdef PADDLE_WITH_CUDA
104

Y
Yi Wang 已提交
105
// Allocates `size` bytes of device memory on gpu_id_ with cudaMalloc.
// On success sets *index = 0, adds `size` to gpu_alloc_size_ and returns the
// pointer. Returns nullptr for size == 0 or when cudaMalloc fails; the latter
// also logs a warning pointing at the GPU-memory flags.
void* GPUAllocator::Alloc(size_t* index, size_t size) {
  // CUDA documentation doesn't explain if cudaMalloc returns nullptr
  // if size is 0.  We just make sure it does.
  if (size == 0) return nullptr;

  paddle::platform::CUDADeviceGuard guard(gpu_id_);

  void* p = nullptr;
  cudaError_t result = cudaMalloc(&p, size);

  if (result == cudaSuccess) {
    *index = 0;
    gpu_alloc_size_ += size;
    return p;
  }

  // A failed cudaMalloc is also recorded in the runtime's "last error" slot.
  // The failure is handled here (by returning nullptr), so clear that slot;
  // otherwise an unrelated later cudaGetLastError() check, e.g. after a
  // kernel launch, would report this allocation failure.
  (void)cudaGetLastError();

  LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
               << " MB GPU memory. Please shrink "
                  "FLAGS_fraction_of_gpu_memory_to_use or "
                  "FLAGS_initial_gpu_memory_in_mb or "
                  "FLAGS_reallocate_gpu_memory_in_mb "
                  "environment variable to a lower value. "
               << "Current FLAGS_fraction_of_gpu_memory_to_use value is "
               << FLAGS_fraction_of_gpu_memory_to_use
               << ". Current FLAGS_initial_gpu_memory_in_mb value is "
               << FLAGS_initial_gpu_memory_in_mb
               << ". Current FLAGS_reallocate_gpu_memory_in_mb value is "
               << FLAGS_reallocate_gpu_memory_in_mb;
  return nullptr;
}

L
liaogang 已提交
136
// Releases a block obtained from GPUAllocator::Alloc (index must be 0) and
// subtracts `size` from gpu_alloc_size_.
void GPUAllocator::Free(void* p, size_t size, size_t index) {
  PADDLE_ENFORCE_EQ(index, 0);
  PADDLE_ENFORCE_GE(gpu_alloc_size_, size);
  gpu_alloc_size_ -= size;

  cudaError_t err = cudaFree(p);

  // Purposefully allow cudaErrorCudartUnloading, because
  // that is returned if you ever call cudaFree after the
  // driver has already shutdown. This happens only if the
  // process is terminating, in which case we don't care if
  // cudaFree succeeds.
  if (err != cudaErrorCudartUnloading) {
    PADDLE_ENFORCE(err, "cudaFree failed in GPUAllocator::Free.");
  }
}

L
liaogang 已提交
153
bool GPUAllocator::UseGpu() const {
  // Blocks come from cudaMalloc, i.e. device memory.
  return true;
}
L
liaogang 已提交
154

C
chengduoZH 已提交
155 156
// Pinned (page-locked) host memory allows the GPU to DMA directly to and from
// system memory, because its pages are locked to fixed physical addresses.
Y
Yi Wang 已提交
157
// Allocates `size` bytes of pinned host memory with cudaHostAlloc, subject
// to the CUDAPinnedMaxAllocSize() budget tracked in cuda_pinnd_alloc_size_.
// On success sets *index = 1 and returns the pointer. Returns nullptr for
// size == 0, when the budget would be exceeded, or when cudaHostAlloc fails.
void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) {
  if (size == 0) return nullptr;

  // NOTE: here, we use CUDAPinnedMaxAllocSize as the maximum memory size
  // of host pinned allocation. Allocating too much would reduce the amount
  // of memory available to the underlying system for paging.
  const size_t max_alloc_size = paddle::platform::CUDAPinnedMaxAllocSize();
  // Avoid unsigned wrap-around if the limit is ever smaller than what is
  // already allocated (e.g. if it is recomputed with a lower value).
  const size_t usable = max_alloc_size > cuda_pinnd_alloc_size_
                            ? max_alloc_size - cuda_pinnd_alloc_size_
                            : 0;

  if (size > usable) {
    LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
                 << " MB pinned memory, available "
                 << usable / 1024.0 / 1024.0 << " MB";
    return nullptr;
  }

  void* p = nullptr;
  // PINNED memory is visible to all CUDA contexts.
  cudaError_t result = cudaHostAlloc(&p, size, cudaHostAllocPortable);

  if (result != cudaSuccess) {
    // The failure is handled here, so clear the runtime's "last error" slot;
    // otherwise a later, unrelated cudaGetLastError() check would report it.
    (void)cudaGetLastError();
    LOG(WARNING) << "cudaHostAlloc failed: " << cudaGetErrorString(result);
    return nullptr;
  }

  *index = 1;  // PINNED memory
  cuda_pinnd_alloc_size_ += size;
  return p;
}

// Releases a block obtained from CUDAPinnedAllocator::Alloc (index must be 1)
// and subtracts `size` from cuda_pinnd_alloc_size_.
void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
  PADDLE_ENFORCE_EQ(index, 1);

  PADDLE_ENFORCE_GE(cuda_pinnd_alloc_size_, size);
  cuda_pinnd_alloc_size_ -= size;

  cudaError_t err = cudaFreeHost(p);

  // Purposefully allow cudaErrorCudartUnloading, because
  // that is returned if you ever call cudaFreeHost after the
  // driver has already shutdown. This happens only if the
  // process is terminating, in which case we don't care if
  // cudaFreeHost succeeds.
  if (err != cudaErrorCudartUnloading) {
    PADDLE_ENFORCE(err, "cudaFreeHost failed in CUDAPinnedAllocator::Free.");
  }
}

C
chengduoZH 已提交
207
bool CUDAPinnedAllocator::UseGpu() const {
  // Blocks come from cudaHostAlloc: they are host memory, not device memory.
  return false;
}
C
chengduoZH 已提交
208

L
Luo Tao 已提交
209
#endif
210 211 212 213

}  // namespace detail
}  // namespace memory
}  // namespace paddle