/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stddef.h>    // for size_t
#include <sys/mman.h>  // for mlock and munlock
#include <cstdlib>     // for malloc and free
#ifndef PADDLE_ONLY_CPU
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif  // PADDLE_ONLY_CPU

namespace paddle {
namespace memory {
namespace detail {

// Interface for allocators that obtain memory directly from the system.
//
// Alloc returns nullptr on failure. Free must be passed the same `size`
// that was given to the Alloc call which produced `p`: implementations
// such as CPUAllocator<true> need it to undo per-range state (munlock).
class SystemAllocator {
 public:
  // Virtual so that deleting a derived allocator through a
  // SystemAllocator* is well-defined.
  virtual ~SystemAllocator() {}

  virtual void* Alloc(size_t size) = 0;

  // Previously declared as `void* Free(void* p)`, which no subclass
  // overrode (CPUAllocator defines Free(void*, size_t)), leaving every
  // subclass abstract. The signature now matches the implementations.
  virtual void Free(void* p, size_t size) = 0;
};

// CPUAllocator<lock_memory=true> calls mlock, which pins and locks the
// allocated pages so they can serve as staging areas for data exchange
// between host and device.  Allocating too much locked memory reduces the
// amount of memory available to the system for paging.  So, by default,
// we should use CPUAllocator<lock_memory=false>.
template <bool lock_memory>
class CPUAllocator : public SystemAllocator {
 public:
  // Allocates `size` bytes with std::malloc; returns nullptr on failure.
  // When lock_memory is true the range is additionally mlock'ed. Locking
  // is best effort: mlock's result is not checked (it can fail, e.g. when
  // RLIMIT_MEMLOCK is exceeded), so the memory may not actually be locked.
  void* Alloc(size_t size) override {
    void* p = std::malloc(size);
    if (p != nullptr && lock_memory) {
      static_cast<void>(mlock(p, size));
    }
    return p;
  }

  // Releases memory obtained from Alloc(size). `p` may be nullptr.
  // When lock_memory is true the range is munlock'ed first (best effort).
  void Free(void* p, size_t size) override {
    if (p != nullptr && lock_memory) {
      static_cast<void>(munlock(p, size));
    }
    std::free(p);
  }
};

#ifndef PADDLE_ONLY_CPU  // The following code are for CUDA.

namespace {
L
liaogang 已提交
63 64 65 66 67
inline void throw_on_error(cudaError_t e, const char* message) {
  if (e) {
    throw thrust::system_error(e, thrust::cuda_category(), message);
  }
}
68
}  // namespace

// GPUAllocator<staging=true> calls cudaHostMalloc, which returns
// pinned and locked memory as staging areas for data exchange
// between host and device.  Allocates too much would reduce the
// amount of memory available to the system for paging.  So, by
// default, we should use GPUAllocator<staging=false>.
template <bool staging>
class GPUAllocator {
77
 public:
L
liaogang 已提交
78 79
  void* Alloc(size_t size) {
    void* p = 0;
80 81
    cudaError_t result =
        staging ? cudaMallocHost(&p, size) : cudaMalloc(&p, size);
L
liaogang 已提交
82 83 84 85 86 87 88 89 90 91 92 93 94
    if (result == cudaSuccess) {
      return p;
    }
    // clear last error
    cudaGetLastError();
    return nullptr;
  }

  void Free(void* p, size_t size) {
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
95 96
    // cudaFree succeeds.
    auto err = staging ? cudaFreeHost(p) : cudaFree(p);
L
liaogang 已提交
97
    if (err != cudaErrorCudartUnloading) {
98
      throw_on_error(err, "cudaFree failed");
L
liaogang 已提交
99 100 101 102
    }
  }
};

#endif  // PADDLE_ONLY_CPU

}  // namespace detail
}  // namespace memory
}  // namespace paddle