/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define GLOG_NO_ABBREVIATED_SEVERITIES

#include "paddle/fluid/memory/detail/system_allocator.h"

#ifdef _WIN32
#include <malloc.h>
#include <windows.h>  // VirtualLock/VirtualUnlock
#else
#include <sys/mman.h>  // for mlock and munlock
#endif
#include <stdlib.h>   // for malloc and free
#include <algorithm>  // for std::max

#include "gflags/gflags.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

DECLARE_bool(use_pinned_memory);
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);

namespace paddle {
namespace memory {
namespace detail {

void* AlignedMalloc(size_t size) {
G
gongweibao 已提交
46
  void* p = nullptr;
D
dzhwinter 已提交
47
  size_t alignment = 32ul;
T
tensor-tang 已提交
48
#ifdef PADDLE_WITH_MKLDNN
49 50
  // refer to https://github.com/01org/mkl-dnn/blob/master/include/mkldnn.hpp
  // memory alignment
D
dzhwinter 已提交
51 52 53 54
  alignment = 4096ul;
#endif
#ifdef _WIN32
  p = _aligned_malloc(size, alignment);
55
#else
D
dzhwinter 已提交
56
  PADDLE_ENFORCE_EQ(posix_memalign(&p, alignment, size), 0, "Alloc %ld error!",
G
gongweibao 已提交
57
                    size);
58
#endif
59
  PADDLE_ENFORCE_NOT_NULL(p, "Fail to allocate CPU memory: size = %d .", size);
D
dzhwinter 已提交
60 61 62 63 64 65 66 67 68 69 70 71
  return p;
}

void* CPUAllocator::Alloc(size_t* index, size_t size) {
  // According to http://www.cplusplus.com/reference/cstdlib/malloc/,
  // malloc might not return nullptr if size is zero, but the returned
  // pointer shall not be dereferenced -- so we make it nullptr.
  if (size <= 0) return nullptr;

  *index = 0;  // unlock memory

  void* p = AlignedMalloc(size);
72 73 74

  if (p != nullptr) {
    if (FLAGS_use_pinned_memory) {
Y
Yi Wang 已提交
75
      *index = 1;
D
dzhwinter 已提交
76 77 78
#ifdef _WIN32
      VirtualLock(p, size);
#else
79
      mlock(p, size);  // lock memory
D
dzhwinter 已提交
80
#endif
81
    }
82
  }
83

84 85 86
  return p;
}

// Releases host memory obtained from CPUAllocator::Alloc.
// An `index` of 1 marks a block for which Alloc requested page-locking; such
// a block is unlocked first. The memory is then returned with the
// deallocator matching AlignedMalloc (_aligned_free on Windows, free
// elsewhere).
void CPUAllocator::Free(void* p, size_t size, size_t index) {
  const bool needs_unlock = (index == 1) && (p != nullptr);
  if (needs_unlock) {
#ifdef _WIN32
    VirtualUnlock(p, size);
#else
    munlock(p, size);
#endif
  }

#ifdef _WIN32
  _aligned_free(p);
#else
  free(p);
#endif
}

// CPUAllocator only hands out host memory.
bool CPUAllocator::UseGpu() const {
  return false;
}

#ifdef PADDLE_WITH_CUDA

void* GPUAllocator::Alloc(size_t* index, size_t size) {
107 108
  // CUDA documentation doesn't explain if cudaMalloc returns nullptr
  // if size is 0.  We just make sure it does.
L
liaogang 已提交
109
  if (size <= 0) return nullptr;
Y
Yu Yang 已提交
110

111
  paddle::platform::CUDADeviceGuard guard(gpu_id_);
Y
Yu Yang 已提交
112

113 114
  void* p;
  cudaError_t result = cudaMalloc(&p, size);
Y
Yu Yang 已提交
115

L
liaogang 已提交
116
  if (result == cudaSuccess) {
Y
Yi Wang 已提交
117
    *index = 0;
118
    gpu_alloc_size_ += size;
L
liaogang 已提交
119
    return p;
120
  } else {
121
    PADDLE_ENFORCE_NE(cudaGetLastError(), cudaSuccess);
122 123 124 125 126
    PADDLE_THROW_BAD_ALLOC(
        "Cannot malloc " + std::to_string(size / 1024.0 / 1024.0) +
        " MB GPU memory. Please shrink "
        "FLAGS_fraction_of_gpu_memory_to_use or "
        "FLAGS_initial_gpu_memory_in_mb or "
127
        "FLAGS_reallocate_gpu_memory_in_mb "
128 129 130 131 132 133 134
        "environment variable to a lower value. " +
        "Current FLAGS_fraction_of_gpu_memory_to_use value is " +
        std::to_string(FLAGS_fraction_of_gpu_memory_to_use) +
        ". Current FLAGS_initial_gpu_memory_in_mb value is " +
        std::to_string(FLAGS_initial_gpu_memory_in_mb) +
        ". Current FLAGS_reallocate_gpu_memory_in_mb value is " +
        std::to_string(FLAGS_reallocate_gpu_memory_in_mb));
L
liaogang 已提交
135
  }
136 137
}

void GPUAllocator::Free(void* p, size_t size, size_t index) {
139
  cudaError_t err;
140 141 142 143
  PADDLE_ENFORCE_EQ(index, 0);
  PADDLE_ENFORCE_GE(gpu_alloc_size_, size);
  gpu_alloc_size_ -= size;
  err = cudaFree(p);
144

145 146 147 148 149 150
  // Purposefully allow cudaErrorCudartUnloading, because
  // that is returned if you ever call cudaFree after the
  // driver has already shutdown. This happens only if the
  // process is terminating, in which case we don't care if
  // cudaFree succeeds.
  if (err != cudaErrorCudartUnloading) {
L
liaogang 已提交
151
    PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
152 153 154
  }
}

// GPUAllocator hands out device memory obtained with cudaMalloc.
bool GPUAllocator::UseGpu() const {
  return true;
}

// PINNED (page-locked) memory allows direct DMA transfers by the GPU to and
// from system memory. It's locked to a physical address.
void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) {
C
chengduoZH 已提交
160
  if (size <= 0) return nullptr;
C
chengduoZH 已提交
161

162
  // NOTE: here, we use CUDAPinnedMaxAllocSize as the maximum memory size
C
chengduoZH 已提交
163
  // of host pinned allocation. Allocates too much would reduce
C
chengduoZH 已提交
164
  // the amount of memory available to the underlying system for paging.
C
chengduoZH 已提交
165
  size_t usable =
166
      paddle::platform::CUDAPinnedMaxAllocSize() - cuda_pinnd_alloc_size_;
C
chengduoZH 已提交
167

C
chengduoZH 已提交
168 169 170 171 172 173
  if (size > usable) {
    LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
                 << " MB pinned memory."
                 << ", available " << usable / 1024.0 / 1024.0 << " MB";
    return nullptr;
  }
C
chengduoZH 已提交
174

C
chengduoZH 已提交
175
  void* p;
C
chengduoZH 已提交
176
  // PINNED memory is visible to all CUDA contexts.
D
Dun Liang 已提交
177
  cudaError_t result = cudaHostAlloc(&p, size, cudaHostAllocPortable);
C
chengduoZH 已提交
178

C
chengduoZH 已提交
179
  if (result == cudaSuccess) {
Y
Yi Wang 已提交
180
    *index = 1;  // PINNED memory
C
chengduoZH 已提交
181
    cuda_pinnd_alloc_size_ += size;
C
chengduoZH 已提交
182
    return p;
C
chengduoZH 已提交
183
  } else {
D
Dun Liang 已提交
184
    LOG(WARNING) << "cudaHostAlloc failed.";
C
chengduoZH 已提交
185
    return nullptr;
C
chengduoZH 已提交
186 187 188 189 190 191 192
  }

  return nullptr;
}

void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
  cudaError_t err;
193
  PADDLE_ENFORCE_EQ(index, 1);
C
chengduoZH 已提交
194

195
  PADDLE_ENFORCE_GE(cuda_pinnd_alloc_size_, size);
C
chengduoZH 已提交
196
  cuda_pinnd_alloc_size_ -= size;
C
chengduoZH 已提交
197 198 199
  err = cudaFreeHost(p);

  // Purposefully allow cudaErrorCudartUnloading, because
C
chengduoZH 已提交
200
  // that is returned if you ever call cudaFreeHost after the
C
chengduoZH 已提交
201 202
  // driver has already shutdown. This happens only if the
  // process is terminating, in which case we don't care if
C
chengduoZH 已提交
203
  // cudaFreeHost succeeds.
C
chengduoZH 已提交
204 205 206 207 208
  if (err != cudaErrorCudartUnloading) {
    PADDLE_ENFORCE(err, "cudaFreeHost failed in GPUPinnedAllocator::Free.");
  }
}

// Pinned memory comes from cudaHostAlloc, i.e. it is host memory, so this
// allocator reports that it does not hand out GPU memory.
bool CUDAPinnedAllocator::UseGpu() const {
  return false;
}

#endif

}  // namespace detail
}  // namespace memory
}  // namespace paddle