system_allocator.cc 5.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14
#define GLOG_NO_ABBREVIATED_SEVERITIES
15

Y
Yi Wang 已提交
16
#include "paddle/fluid/memory/detail/system_allocator.h"
17

D
dzhwinter 已提交
18 19 20
#ifdef _WIN32
#include <malloc.h>
#else
21
#include <sys/mman.h>  // for mlock and munlock
D
dzhwinter 已提交
22 23 24
#endif
#include <stdlib.h>   // for malloc and free
#include <algorithm>  // for std::max
25 26

#include "gflags/gflags.h"
Y
Yi Wang 已提交
27 28 29 30
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
31

S
sneaxiy 已提交
32
DECLARE_bool(use_pinned_memory);
33
DECLARE_double(fraction_of_gpu_memory_to_use);
34 35 36 37
namespace paddle {
namespace memory {
namespace detail {

D
dzhwinter 已提交
38
void* AlignedMalloc(size_t size) {
G
gongweibao 已提交
39
  void* p = nullptr;
D
dzhwinter 已提交
40
  size_t alignment = 32ul;
T
tensor-tang 已提交
41
#ifdef PADDLE_WITH_MKLDNN
42 43
  // refer to https://github.com/01org/mkl-dnn/blob/master/include/mkldnn.hpp
  // memory alignment
D
dzhwinter 已提交
44 45 46 47
  alignment = 4096ul;
#endif
#ifdef _WIN32
  p = _aligned_malloc(size, alignment);
48
#else
D
dzhwinter 已提交
49
  PADDLE_ENFORCE_EQ(posix_memalign(&p, alignment, size), 0, "Alloc %ld error!",
G
gongweibao 已提交
50
                    size);
51 52
#endif
  PADDLE_ENFORCE(p, "Fail to allocate CPU memory: size = %d .", size);
D
dzhwinter 已提交
53 54 55 56 57 58 59 60 61 62 63 64
  return p;
}

void* CPUAllocator::Alloc(size_t* index, size_t size) {
  // According to http://www.cplusplus.com/reference/cstdlib/malloc/,
  // malloc might not return nullptr if size is zero, but the returned
  // pointer shall not be dereferenced -- so we make it nullptr.
  if (size <= 0) return nullptr;

  *index = 0;  // unlock memory

  void* p = AlignedMalloc(size);
65 66 67

  if (p != nullptr) {
    if (FLAGS_use_pinned_memory) {
Y
Yi Wang 已提交
68
      *index = 1;
D
dzhwinter 已提交
69 70 71
#ifdef _WIN32
      VirtualLock(p, size);
#else
72
      mlock(p, size);  // lock memory
D
dzhwinter 已提交
73
#endif
74
    }
75
  }
76

77 78 79
  return p;
}

L
liaogang 已提交
80
// Returns a block obtained from CPUAllocator::Alloc to the system.
// An `index` of 1 means Alloc flagged the block as page-locked, so it is
// unlocked first. The block is then released with the deallocator that
// matches AlignedMalloc. A null `p` is tolerated: it is never unlocked, and
// _aligned_free / free ignore null.
void CPUAllocator::Free(void* p, size_t size, size_t index) {
  const bool must_unlock = (p != nullptr) && (index == 1);
#ifdef _WIN32
  if (must_unlock) {
    VirtualUnlock(p, size);
  }
  _aligned_free(p);
#else
  if (must_unlock) {
    munlock(p, size);
  }
  free(p);
#endif
}

L
liaogang 已提交
95
// CPUAllocator hands out host memory only; it never allocates on a GPU.
bool CPUAllocator::UseGpu() const {
  return false;
}
L
liaogang 已提交
96

97
#ifdef PADDLE_WITH_CUDA
98

Y
Yi Wang 已提交
99
void* GPUAllocator::Alloc(size_t* index, size_t size) {
100 101
  // CUDA documentation doesn't explain if cudaMalloc returns nullptr
  // if size is 0.  We just make sure it does.
L
liaogang 已提交
102
  if (size <= 0) return nullptr;
103
  void* p;
Y
Yu Yang 已提交
104 105 106 107 108 109
  int prev_id;
  cudaGetDevice(&prev_id);
  if (prev_id != gpu_id_) {
    cudaSetDevice(gpu_id_);
  }

110
  cudaError_t result = cudaMalloc(&p, size);
Y
Yu Yang 已提交
111 112 113 114 115

  if (prev_id != gpu_id_) {
    cudaSetDevice(prev_id);
  }

L
liaogang 已提交
116
  if (result == cudaSuccess) {
Y
Yi Wang 已提交
117
    *index = 0;
118
    gpu_alloc_size_ += size;
L
liaogang 已提交
119
    return p;
120 121 122 123 124 125 126
  } else {
    LOG(WARNING)
        << "Cannot malloc " << size / 1024.0 / 1024.0
        << " MB GPU memory. Please shrink FLAGS_fraction_of_gpu_memory_to_use "
           "environment variable to a lower value. Current value is "
        << FLAGS_fraction_of_gpu_memory_to_use;
    return nullptr;
L
liaogang 已提交
127
  }
128 129
}

L
liaogang 已提交
130
void GPUAllocator::Free(void* p, size_t size, size_t index) {
131 132 133 134 135 136 137 138 139 140 141 142
  cudaError_t err;

  if (index == 0) {
    PADDLE_ASSERT(gpu_alloc_size_ >= size);
    gpu_alloc_size_ -= size;
    err = cudaFree(p);
  } else {
    PADDLE_ASSERT(fallback_alloc_size_ >= size);
    fallback_alloc_size_ -= size;
    err = cudaFreeHost(p);
  }

143 144 145 146 147 148
  // Purposefully allow cudaErrorCudartUnloading, because
  // that is returned if you ever call cudaFree after the
  // driver has already shutdown. This happens only if the
  // process is terminating, in which case we don't care if
  // cudaFree succeeds.
  if (err != cudaErrorCudartUnloading) {
L
liaogang 已提交
149
    PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
150 151 152
  }
}

L
liaogang 已提交
153
// GPUAllocator allocates device memory, so it always reports GPU usage.
bool GPUAllocator::UseGpu() const {
  return true;
}
L
liaogang 已提交
154

C
chengduoZH 已提交
155 156
// PINNED memory allows direct DMA transfers by the GPU to and from system
// memory. It’s locked to a physical address.
Y
Yi Wang 已提交
157
void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) {
C
chengduoZH 已提交
158
  if (size <= 0) return nullptr;
C
chengduoZH 已提交
159

160
  // NOTE: here, we use CUDAPinnedMaxAllocSize as the maximum memory size
C
chengduoZH 已提交
161
  // of host pinned allocation. Allocates too much would reduce
C
chengduoZH 已提交
162
  // the amount of memory available to the underlying system for paging.
C
chengduoZH 已提交
163
  size_t usable =
164
      paddle::platform::CUDAPinnedMaxAllocSize() - cuda_pinnd_alloc_size_;
C
chengduoZH 已提交
165

C
chengduoZH 已提交
166 167 168 169 170 171
  if (size > usable) {
    LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
                 << " MB pinned memory."
                 << ", available " << usable / 1024.0 / 1024.0 << " MB";
    return nullptr;
  }
C
chengduoZH 已提交
172

C
chengduoZH 已提交
173
  void* p;
C
chengduoZH 已提交
174
  // PINNED memory is visible to all CUDA contexts.
C
chengduoZH 已提交
175
  cudaError_t result = cudaMallocHost(&p, size);
C
chengduoZH 已提交
176

C
chengduoZH 已提交
177
  if (result == cudaSuccess) {
Y
Yi Wang 已提交
178
    *index = 1;  // PINNED memory
C
chengduoZH 已提交
179
    cuda_pinnd_alloc_size_ += size;
C
chengduoZH 已提交
180
    return p;
C
chengduoZH 已提交
181 182 183
  } else {
    LOG(WARNING) << "cudaMallocHost failed.";
    return nullptr;
C
chengduoZH 已提交
184 185 186 187 188 189 190 191 192
  }

  return nullptr;
}

void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
  cudaError_t err;
  PADDLE_ASSERT(index == 1);

C
chengduoZH 已提交
193 194
  PADDLE_ASSERT(cuda_pinnd_alloc_size_ >= size);
  cuda_pinnd_alloc_size_ -= size;
C
chengduoZH 已提交
195 196 197
  err = cudaFreeHost(p);

  // Purposefully allow cudaErrorCudartUnloading, because
C
chengduoZH 已提交
198
  // that is returned if you ever call cudaFreeHost after the
C
chengduoZH 已提交
199 200
  // driver has already shutdown. This happens only if the
  // process is terminating, in which case we don't care if
C
chengduoZH 已提交
201
  // cudaFreeHost succeeds.
C
chengduoZH 已提交
202 203 204 205 206
  if (err != cudaErrorCudartUnloading) {
    PADDLE_ENFORCE(err, "cudaFreeHost failed in GPUPinnedAllocator::Free.");
  }
}

C
chengduoZH 已提交
207
// Pinned memory is host (system) memory, so this allocator does not report
// GPU usage even though the GPU can DMA to and from it.
bool CUDAPinnedAllocator::UseGpu() const {
  return false;
}
C
chengduoZH 已提交
208

L
Luo Tao 已提交
209
#endif
210 211 212 213

}  // namespace detail
}  // namespace memory
}  // namespace paddle