/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stddef.h>    // for size_t
#include <sys/mman.h>  // for mlock and munlock
#include <cstdlib>     // for malloc and free

#ifndef PADDLE_ONLY_CPU
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif  // PADDLE_ONLY_CPU

#include "paddle/platform/assert.h"

namespace paddle {
namespace memory {
namespace detail {

// Deleter for host buffers handed out by CPUAllocator.  It records the
// buffer's address, its size in bytes, and whether its pages were mlock-ed,
// so that it can munlock them before returning the memory with std::free.
class CPUDeleter {
 public:
  CPUDeleter(void* ptr, size_t size, bool locked)
      : ptr_(ptr), size_(size), locked_(locked) {}

  // Address of the managed buffer.  CPUAllocator passes nullptr here when
  // std::malloc failed.
  void* Ptr() { return ptr_; }

  // Releases the managed buffer.  `ptr` must be the address this deleter
  // was constructed with.
  void operator()(void* ptr) {
    PADDLE_ASSERT(ptr == ptr_);
    const bool needs_unlock = locked_ && (ptr_ != nullptr);
    if (needs_unlock) {
      munlock(ptr_, size_);
    }
    std::free(ptr_);  // std::free(nullptr) is a no-op.
  }

 private:
  void* ptr_;    // address of the buffer
  size_t size_;  // length in bytes; munlock needs it
  bool locked_;  // whether munlock must run before the buffer is freed
};

// CPUAllocator<lock_memory=true> calls mlock, which returns pinned
// and locked memory as staging areas for data exchange between host
// and device.  Allocates too much would reduce the amount of memory
// available to the system for paging.  So, by default, we should use
// CPUAllocator<staging=false>.
template <bool lock_memory>
Y
Yi Wang 已提交
59
class CPUAllocator {
60
 public:
Y
Yi Wang 已提交
61
  static CPUDeleter Alloc(size_t size) {
62 63 64 65
    void* p = std::malloc(size);
    if (p != nullptr && lock_memory) {
      mlock(p, size);
    }
Y
Yi Wang 已提交
66
    return CPUDeleter(p, size, lock_memory);
67 68 69 70 71 72
  }
};

#ifndef PADDLE_ONLY_CPU  // The following code are for CUDA.

namespace {
L
liaogang 已提交
73 74 75 76 77
inline void throw_on_error(cudaError_t e, const char* message) {
  if (e) {
    throw thrust::system_error(e, thrust::cuda_category(), message);
  }
}
78
}  // namespace

// Deleter for buffers handed out by GPUAllocator.  `staging` selects how the
// buffer is released: cudaFreeHost for pinned host memory (cudaMallocHost),
// cudaFree for device memory (cudaMalloc).
class GPUDeleter {
 public:
  GPUDeleter(void* ptr, size_t size, bool staging)
      : ptr_(ptr), size_(size), staging_(staging) {}

  // Address of the managed buffer.  GPUAllocator passes nullptr here when the
  // CUDA allocation failed.
  void* Ptr() { return ptr_; }

  // Releases the managed buffer.  `ptr` must be the address this deleter was
  // constructed with.  Throws thrust::system_error (via throw_on_error) if
  // the CUDA free call fails with anything other than
  // cudaErrorCudartUnloading.
  void operator()(void* ptr) {
    PADDLE_ASSERT(ptr == ptr_);
    // Free the stored address, as CPUDeleter does, so the pointer that is
    // released never depends on the caller-supplied argument.
    cudaError_t err = staging_ ? cudaFreeHost(ptr_) : cudaFree(ptr_);
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
    // cudaFree succeeds.
    if (err != cudaErrorCudartUnloading) {
      throw_on_error(err, "cudaFree{Host} failed");
    }
  }

 private:
  void* ptr_;      // address of the buffer
  size_t size_;    // requested size in bytes
  bool staging_;   // true: pinned host memory; false: device memory
};

// GPUAllocator<staging=true> calls cudaHostMalloc, which returns
// pinned and locked memory as staging areas for data exchange
// between host and device.  Allocates too much would reduce the
// amount of memory available to the system for paging.  So, by
// default, we should use GPUAllocator<staging=false>.
template <bool staging>
class GPUAllocator {
113
 public:
Y
Yi Wang 已提交
114
  static GPUDeleter Alloc(size_t size) {
L
liaogang 已提交
115
    void* p = 0;
116 117
    cudaError_t result =
        staging ? cudaMallocHost(&p, size) : cudaMalloc(&p, size);
Y
Yi Wang 已提交
118 119
    if (result != cudaSuccess) {
      cudaGetLastError();  // clear error if there is any.
L
liaogang 已提交
120
    }
Y
Yi Wang 已提交
121
    return GPUDeleter(result == cudaSuccess ? p : nullptr, size, staging);
L
liaogang 已提交
122 123 124
  }
};

#endif  // PADDLE_ONLY_CPU

}  // namespace detail
}  // namespace memory
}  // namespace paddle