未验证 提交 c14305f0 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #9380 from chengduoZH/feature/add_CUDAPinnedPlace

Add CUDAPinnedPlace
...@@ -45,11 +45,10 @@ class Tensor { ...@@ -45,11 +45,10 @@ class Tensor {
friend struct EigenVector; friend struct EigenVector;
public: public:
Tensor() : offset_(0), is_pinned_(false) {} Tensor() : offset_(0) {}
/*! Constructor with place should only be used in pybind. */ /*! Constructor with place should only be used in pybind. */
explicit Tensor(const platform::Place& place) explicit Tensor(const platform::Place& place) : offset_(0) {
: offset_(0), is_pinned_(false) {
holder_->set_place(place); holder_->set_place(place);
} }
...@@ -70,12 +69,11 @@ class Tensor { ...@@ -70,12 +69,11 @@ class Tensor {
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
inline T* mutable_data(platform::Place place, bool is_pinned = false); inline T* mutable_data(platform::Place place);
inline void* mutable_data(platform::Place place, std::type_index type, inline void* mutable_data(platform::Place place, std::type_index type);
bool is_pinned = false);
inline void* mutable_data(platform::Place place, bool is_pinned = false); inline void* mutable_data(platform::Place place);
/** /**
* @brief Return a pointer to mutable memory block. * @brief Return a pointer to mutable memory block.
...@@ -86,8 +84,7 @@ class Tensor { ...@@ -86,8 +84,7 @@ class Tensor {
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
inline T* mutable_data(DDim dims, platform::Place place, inline T* mutable_data(DDim dims, platform::Place place);
bool is_pinned = false);
/*! Return the dimensions of the memory block. */ /*! Return the dimensions of the memory block. */
inline const DDim& dims() const; inline const DDim& dims() const;
...@@ -95,9 +92,6 @@ class Tensor { ...@@ -95,9 +92,6 @@ class Tensor {
/*! Return the numel of the memory block. */ /*! Return the numel of the memory block. */
inline int64_t numel() const; inline int64_t numel() const;
/*! Return the numel of the memory block. */
inline bool isPinned() const;
/*! Resize the dimensions of the memory block. */ /*! Resize the dimensions of the memory block. */
inline Tensor& Resize(const DDim& dims); inline Tensor& Resize(const DDim& dims);
...@@ -152,14 +146,12 @@ class Tensor { ...@@ -152,14 +146,12 @@ class Tensor {
template <typename Place> template <typename Place>
struct PlaceholderImpl : public Placeholder { struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(Place place, size_t size, std::type_index type, PlaceholderImpl(Place place, size_t size, std::type_index type)
bool is_pinned = false) : ptr_(static_cast<uint8_t*>(memory::Alloc(place, size)),
: ptr_(static_cast<uint8_t*>(memory::Alloc(place, size, is_pinned)), memory::PODDeleter<uint8_t, Place>(place)),
memory::PODDeleter<uint8_t, Place>(place, is_pinned)),
place_(place), place_(place),
size_(size), size_(size),
type_(type), type_(type) {
is_pinned_(is_pinned) {
PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.", PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.",
(is_cpu_place(place_) ? "CPU" : "GPU")); (is_cpu_place(place_) ? "CPU" : "GPU"));
} }
...@@ -182,9 +174,6 @@ class Tensor { ...@@ -182,9 +174,6 @@ class Tensor {
/* the current type of memory */ /* the current type of memory */
std::type_index type_; std::type_index type_;
/*! use pinned memory or not. */
bool is_pinned_;
}; };
/*! holds the memory block if allocated. */ /*! holds the memory block if allocated. */
...@@ -219,7 +208,6 @@ class Tensor { ...@@ -219,7 +208,6 @@ class Tensor {
* PlaceHolder::ptr_ and where the tensor data really begins. * PlaceHolder::ptr_ and where the tensor data really begins.
*/ */
size_t offset_; size_t offset_;
bool is_pinned_;
}; };
inline void Tensor::switch_place(platform::Place new_place) { inline void Tensor::switch_place(platform::Place new_place) {
......
...@@ -101,21 +101,19 @@ inline T* Tensor::data() { ...@@ -101,21 +101,19 @@ inline T* Tensor::data() {
} }
template <typename T> template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place, inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
bool is_pinned) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims); Resize(dims);
return mutable_data<T>(place, is_pinned); return mutable_data<T>(place);
} }
template <typename T> template <typename T>
inline T* Tensor::mutable_data(platform::Place place, bool is_pinned) { inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(mutable_data(place, typeid(T), is_pinned)); return reinterpret_cast<T*>(mutable_data(place, typeid(T)));
} }
inline void* Tensor::mutable_data(platform::Place place, std::type_index type, inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
bool is_pinned) {
if (holder_ != nullptr) { if (holder_ != nullptr) {
holder_->set_type(type); holder_->set_type(type);
} }
...@@ -129,27 +127,26 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type, ...@@ -129,27 +127,26 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type,
holder_->size() < size + offset_) { holder_->size() < size + offset_) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CPUPlace>( holder_.reset(new PlaceholderImpl<platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size, type, is_pinned)); boost::get<platform::CPUPlace>(place), size, type));
} else if (platform::is_gpu_place(place)) { } else if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA #ifndef PADDLE_WITH_CUDA
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
} }
#else #else
holder_.reset(new PlaceholderImpl<platform::CUDAPlace>( holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(place), size, type, is_pinned)); boost::get<platform::CUDAPlace>(place), size, type));
} }
#endif #endif
offset_ = 0; offset_ = 0;
is_pinned_ = is_pinned;
} }
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
} }
inline void* Tensor::mutable_data(platform::Place place, bool is_pinned) { inline void* Tensor::mutable_data(platform::Place place) {
PADDLE_ENFORCE(this->holder_ != nullptr, PADDLE_ENFORCE(this->holder_ != nullptr,
"Cannot invoke mutable data if current hold nothing"); "Cannot invoke mutable data if current hold nothing");
return mutable_data(place, holder_->type(), is_pinned); return mutable_data(place, holder_->type());
} }
inline Tensor& Tensor::ShareDataWith(const Tensor& src) { inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
...@@ -191,8 +188,6 @@ inline const DDim& Tensor::dims() const { return dims_; } ...@@ -191,8 +188,6 @@ inline const DDim& Tensor::dims() const { return dims_; }
inline int64_t Tensor::numel() const { return product(dims_); } inline int64_t Tensor::numel() const { return product(dims_); }
inline bool Tensor::isPinned() const { return is_pinned_; }
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
Tensor res; Tensor res;
res.ShareDataWith(src); res.ShareDataWith(src);
......
...@@ -148,6 +148,11 @@ struct AnyVisitor : public boost::static_visitor<bool> { ...@@ -148,6 +148,11 @@ struct AnyVisitor : public boost::static_visitor<bool> {
const platform::CPUPlace& cpu) const { const platform::CPUPlace& cpu) const {
return *out.data<bool>(); return *out.data<bool>();
} }
bool GetResult(const framework::Tensor& out,
const platform::CUDAPinnedPlace& cpu) const {
return *out.data<bool>();
}
}; };
template <typename Predicate> template <typename Predicate>
......
...@@ -14,3 +14,7 @@ cc_library(paddle_memory ...@@ -14,3 +14,7 @@ cc_library(paddle_memory
system_allocator) system_allocator)
cc_test(memory_test SRCS memory_test.cc DEPS place paddle_memory) cc_test(memory_test SRCS memory_test.cc DEPS place paddle_memory)
#if (WITH_GPU)
# nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place paddle_memory)
#endif()
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
...@@ -134,21 +135,31 @@ bool GPUAllocator::UseGpu() const { return true; } ...@@ -134,21 +135,31 @@ bool GPUAllocator::UseGpu() const { return true; }
// memory. It’s locked to a physical address. // memory. It’s locked to a physical address.
void* CUDAPinnedAllocator::Alloc(size_t& index, size_t size) { void* CUDAPinnedAllocator::Alloc(size_t& index, size_t size) {
if (size <= 0) return nullptr; if (size <= 0) return nullptr;
void* p;
// NOTE: here, we use GpuMaxAllocSize() as the maximum memory size // NOTE: here, we use CUDAPinnedMaxAllocSize as the maximum memory size
// of host pinned allocation. Allocates too much would reduce // of host pinned allocation. Allocates too much would reduce
// the amount of memory available to the underlying system for paging. // the amount of memory available to the underlying system for paging.
size_t usable =
paddle::platform::CUDAPinnedMaxAllocSize() - cuda_pinnd_alloc_size_;
size_t usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_; if (size > usable) {
LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
if (size > usable) return nullptr; << " MB pinned memory."
<< ", available " << usable / 1024.0 / 1024.0 << " MB";
return nullptr;
}
void* p;
// PINNED memory is visible to all CUDA contexts. // PINNED memory is visible to all CUDA contexts.
cudaError_t result = cudaMallocHost(&p, size); cudaError_t result = cudaMallocHost(&p, size);
if (result == cudaSuccess) { if (result == cudaSuccess) {
index = 1; index = 1; // PINNED memory
fallback_alloc_size_ += size; cuda_pinnd_alloc_size_ += size;
return p; return p;
} else {
LOG(WARNING) << "cudaMallocHost failed.";
return nullptr;
} }
return nullptr; return nullptr;
...@@ -158,8 +169,8 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) { ...@@ -158,8 +169,8 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
cudaError_t err; cudaError_t err;
PADDLE_ASSERT(index == 1); PADDLE_ASSERT(index == 1);
PADDLE_ASSERT(fallback_alloc_size_ >= size); PADDLE_ASSERT(cuda_pinnd_alloc_size_ >= size);
fallback_alloc_size_ -= size; cuda_pinnd_alloc_size_ -= size;
err = cudaFreeHost(p); err = cudaFreeHost(p);
// Purposefully allow cudaErrorCudartUnloading, because // Purposefully allow cudaErrorCudartUnloading, because
...@@ -172,7 +183,7 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) { ...@@ -172,7 +183,7 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
} }
} }
bool CUDAPinnedAllocator::UseGpu() const { return true; } bool CUDAPinnedAllocator::UseGpu() const { return false; }
#endif #endif
......
...@@ -21,8 +21,9 @@ namespace memory { ...@@ -21,8 +21,9 @@ namespace memory {
namespace detail { namespace detail {
/** /**
* \brief SystemAllocator is the parent class of CPUAllocator and GPUAllocator. * \brief SystemAllocator is the parent class of CPUAllocator,
* A BuddyAllocator object uses a SystemAllocator* pointing to the * CUDAPinnedAllocator and GPUAllocator. A BuddyAllocator
* object uses a SystemAllocator* pointing to the
* underlying system allocator. * underlying system allocator.
*/ */
class SystemAllocator { class SystemAllocator {
...@@ -62,9 +63,7 @@ class CUDAPinnedAllocator : public SystemAllocator { ...@@ -62,9 +63,7 @@ class CUDAPinnedAllocator : public SystemAllocator {
virtual bool UseGpu() const; virtual bool UseGpu() const;
private: private:
size_t gpu_alloc_size_ = size_t cuda_pinnd_alloc_size_ = 0;
0; // TODO(zcd): how to define the upper limit of CUDAPinnedMemory?
size_t fallback_alloc_size_ = 0;
}; };
#endif #endif
......
...@@ -56,6 +56,45 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>( ...@@ -56,6 +56,45 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
} }
} }
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
platform::CPUPlace dst_place, void* dst,
platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
std::memcpy(dst, src, num);
}
template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
platform::CUDAPinnedPlace dst_place, void* dst,
platform::CPUPlace src_place, const void* src, size_t num) {
std::memcpy(dst, src, num);
}
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
platform::CUDAPinnedPlace dst_place, void* dst,
platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
std::memcpy(dst, src, num);
}
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
platform::CUDAPinnedPlace dst_place, void* dst,
platform::CUDAPlace src_place, const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
platform::CUDAPlace dst_place, void* dst,
platform::CUDAPinnedPlace src_place, const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}
#endif #endif
} // namespace memory } // namespace memory
......
...@@ -38,8 +38,7 @@ BuddyAllocator* GetCPUBuddyAllocator() { ...@@ -38,8 +38,7 @@ BuddyAllocator* GetCPUBuddyAllocator() {
} }
template <> template <>
void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size, void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size) {
bool is_pinned) {
VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place); VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place);
void* p = GetCPUBuddyAllocator()->Alloc(size); void* p = GetCPUBuddyAllocator()->Alloc(size);
VLOG(10) << " pointer=" << p; VLOG(10) << " pointer=" << p;
...@@ -47,8 +46,7 @@ void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size, ...@@ -47,8 +46,7 @@ void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size,
} }
template <> template <>
void Free<platform::CPUPlace>(platform::CPUPlace place, void* p, void Free<platform::CPUPlace>(platform::CPUPlace place, void* p) {
bool is_pinned) {
VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place); VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place);
GetCPUBuddyAllocator()->Free(p); GetCPUBuddyAllocator()->Free(p);
} }
...@@ -84,47 +82,15 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { ...@@ -84,47 +82,15 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
return as[gpu_id]; return as[gpu_id];
} }
BuddyAllocator* GetCUDAPinnedBuddyAllocator(int gpu_id) {
static BuddyAllocator** as = NULL;
if (as == NULL) {
int gpu_num = platform::GetCUDADeviceCount();
as = new BuddyAllocator*[gpu_num];
for (int gpu = 0; gpu < gpu_num; gpu++) {
as[gpu] = nullptr;
}
}
platform::SetDeviceId(gpu_id);
if (!as[gpu_id]) {
as[gpu_id] = new BuddyAllocator(new detail::CUDAPinnedAllocator,
platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize());
VLOG(10) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100
<< "% of GPU memory.\n"
<< "You can set GFlags environment variable '"
<< "FLAGS_fraction_of_gpu_memory_to_use"
<< "' to change the fraction of GPU usage.\n\n";
}
return as[gpu_id];
}
template <> template <>
size_t Used<platform::CUDAPlace>(platform::CUDAPlace place) { size_t Used<platform::CUDAPlace>(platform::CUDAPlace place) {
return GetGPUBuddyAllocator(place.device)->Used(); return GetGPUBuddyAllocator(place.device)->Used();
} }
template <> template <>
void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size, void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size) {
bool is_pinned) {
void* ptr;
if (is_pinned) {
auto* buddy_allocator = GetCUDAPinnedBuddyAllocator(place.device);
ptr = buddy_allocator->Alloc(size);
} else {
auto* buddy_allocator = GetGPUBuddyAllocator(place.device); auto* buddy_allocator = GetGPUBuddyAllocator(place.device);
ptr = buddy_allocator->Alloc(size); auto* ptr = buddy_allocator->Alloc(size);
}
if (ptr == nullptr) { if (ptr == nullptr) {
int cur_dev = platform::GetCurrentDeviceId(); int cur_dev = platform::GetCurrentDeviceId();
platform::SetDeviceId(place.device); platform::SetDeviceId(place.device);
...@@ -142,15 +108,42 @@ void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size, ...@@ -142,15 +108,42 @@ void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size,
} }
template <> template <>
void Free<platform::CUDAPlace>(platform::CUDAPlace place, void* p, void Free<platform::CUDAPlace>(platform::CUDAPlace place, void* p) {
bool is_pinned) {
if (is_pinned) {
GetCUDAPinnedBuddyAllocator(place.device)->Free(p);
} else {
GetGPUBuddyAllocator(place.device)->Free(p); GetGPUBuddyAllocator(place.device)->Free(p);
}
BuddyAllocator* GetCUDAPinnedBuddyAllocator() {
static BuddyAllocator* ba = NULL;
if (ba == NULL) {
ba = new BuddyAllocator(new detail::CUDAPinnedAllocator,
platform::CUDAPinnedMinChunkSize(),
platform::CUDAPinnedMaxChunkSize());
} }
return ba;
} }
template <>
size_t Used<platform::CUDAPinnedPlace>(platform::CUDAPinnedPlace place) {
return GetCUDAPinnedBuddyAllocator()->Used();
}
template <>
void* Alloc<platform::CUDAPinnedPlace>(platform::CUDAPinnedPlace place,
size_t size) {
auto* buddy_allocator = GetCUDAPinnedBuddyAllocator();
void* ptr = buddy_allocator->Alloc(size);
if (ptr == nullptr) {
LOG(WARNING) << "cudaMallocHost Cannot allocate " << size
<< " bytes in CUDAPinnedPlace";
}
return ptr;
}
template <>
void Free<platform::CUDAPinnedPlace>(platform::CUDAPinnedPlace place, void* p) {
GetCUDAPinnedBuddyAllocator()->Free(p);
}
#endif #endif
size_t Usage::operator()(const platform::CPUPlace& cpu) const { size_t Usage::operator()(const platform::CPUPlace& cpu) const {
...@@ -165,6 +158,14 @@ size_t Usage::operator()(const platform::CUDAPlace& gpu) const { ...@@ -165,6 +158,14 @@ size_t Usage::operator()(const platform::CUDAPlace& gpu) const {
#endif #endif
} }
size_t Usage::operator()(const platform::CUDAPinnedPlace& cuda_pinned) const {
#ifdef PADDLE_WITH_CUDA
return Used(cuda_pinned);
#else
PADDLE_THROW("'CUDAPinnedPlace' is not supported in CPU only device.");
#endif
}
size_t memory_usage(const platform::Place& p) { size_t memory_usage(const platform::Place& p) {
return boost::apply_visitor(Usage(), p); return boost::apply_visitor(Usage(), p);
} }
......
...@@ -33,7 +33,7 @@ namespace memory { ...@@ -33,7 +33,7 @@ namespace memory {
* address is valid or not. * address is valid or not.
*/ */
template <typename Place> template <typename Place>
void* Alloc(Place place, size_t size, bool is_pinned = false); void* Alloc(Place place, size_t size);
/** /**
* \brief Free memory block in one place. * \brief Free memory block in one place.
...@@ -43,7 +43,7 @@ void* Alloc(Place place, size_t size, bool is_pinned = false); ...@@ -43,7 +43,7 @@ void* Alloc(Place place, size_t size, bool is_pinned = false);
* *
*/ */
template <typename Place> template <typename Place>
void Free(Place place, void* ptr, bool is_pinned = false); void Free(Place place, void* ptr);
/** /**
* \brief Total size of used memory in one place. * \brief Total size of used memory in one place.
...@@ -57,6 +57,7 @@ size_t Used(Place place); ...@@ -57,6 +57,7 @@ size_t Used(Place place);
struct Usage : public boost::static_visitor<size_t> { struct Usage : public boost::static_visitor<size_t> {
size_t operator()(const platform::CPUPlace& cpu) const; size_t operator()(const platform::CPUPlace& cpu) const;
size_t operator()(const platform::CUDAPlace& gpu) const; size_t operator()(const platform::CUDAPlace& gpu) const;
size_t operator()(const platform::CUDAPinnedPlace& cuda_pinned) const;
}; };
size_t memory_usage(const platform::Place& p); size_t memory_usage(const platform::Place& p);
...@@ -74,13 +75,11 @@ class PODDeleter { ...@@ -74,13 +75,11 @@ class PODDeleter {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
public: public:
explicit PODDeleter(Place place, bool is_pinned = false) explicit PODDeleter(Place place) : place_(place) {}
: place_(place), is_pinned_(is_pinned) {} void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); }
void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr), is_pinned_); }
private: private:
Place place_; Place place_;
bool is_pinned_;
}; };
/** /**
......
...@@ -141,4 +141,59 @@ TEST(BuddyAllocator, GPUMultAlloc) { ...@@ -141,4 +141,59 @@ TEST(BuddyAllocator, GPUMultAlloc) {
} }
} }
size_t align(size_t size, paddle::platform::CUDAPinnedPlace place) {
size += sizeof(paddle::memory::detail::Metadata);
size_t alignment = paddle::platform::CUDAPinnedMinChunkSize();
size_t remaining = size % alignment;
return remaining == 0 ? size : size + (alignment - remaining);
}
TEST(BuddyAllocator, CUDAPinnedAllocator) {
void *p = nullptr;
EXPECT_EQ(p, nullptr);
paddle::platform::CUDAPinnedPlace cpu;
p = paddle::memory::Alloc(cpu, 4096);
EXPECT_NE(p, nullptr);
paddle::platform::Place place = cpu;
EXPECT_EQ(paddle::memory::Used(cpu), paddle::memory::memory_usage(place));
paddle::memory::Free(cpu, p);
}
TEST(BuddyAllocator, CUDAPinnedMultAllocator) {
paddle::platform::CUDAPinnedPlace cpu;
std::unordered_map<void *, size_t> ps;
size_t total_size = paddle::memory::Used(cpu);
EXPECT_EQ(total_size, 0UL);
for (auto size :
{0, 128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) {
ps[paddle::memory::Alloc(cpu, size)] = size;
// Buddy Allocator doesn't manage too large memory chunk
if (paddle::memory::Used(cpu) == total_size) continue;
size_t aligned_size = align(size, cpu);
total_size += aligned_size;
EXPECT_EQ(total_size, paddle::memory::Used(cpu));
}
for (auto p : ps) {
EXPECT_EQ(is_aligned(p.first), true);
paddle::memory::Free(cpu, p.first);
// Buddy Allocator doesn't manage too large memory chunk
if (paddle::memory::Used(cpu) == total_size) continue;
size_t aligned_size = align(p.second, cpu);
total_size -= aligned_size;
EXPECT_EQ(total_size, paddle::memory::Used(cpu));
}
}
#endif #endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <unordered_map>
#include "paddle/fluid/memory/detail/memory_block.h"
#include "paddle/fluid/memory/detail/meta_data.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
// This unit test is an example comparing the performance between using pinned
// memory and not. In general, using pinned memory will be faster.
template <typename T>
__global__ void Kernel(T* output, int dim) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < dim) {
output[tid] = output[tid] * output[tid] / 100;
}
}
template <typename Place>
float test_pinned_memory() {
Place cpu_place;
paddle::platform::CUDAPlace cuda_place;
const int data_size = 4096;
const int iteration = 10;
// create event start and end
cudaEvent_t start_e, stop_e, copying_e;
float elapsedTime = 0;
cudaEventCreate(&start_e);
cudaEventCreate(&stop_e);
cudaEventCreate(&copying_e);
// create computation stream, data copying stream
cudaStream_t computation_stream, copying_stream;
cudaStreamCreate(&computation_stream);
cudaStreamCreate(&copying_stream);
// create record event, pinned memory, gpu memory
std::vector<cudaEvent_t> record_event(iteration);
std::vector<float*> input_pinned_mem(iteration);
std::vector<float*> gpu_mem(iteration);
std::vector<float*> output_pinned_mem(iteration);
// initial data
for (int j = 0; j < iteration; ++j) {
cudaEventCreateWithFlags(&record_event[j], cudaEventDisableTiming);
cudaEventCreate(&(record_event[j]));
input_pinned_mem[j] = static_cast<float*>(
paddle::memory::Alloc(cpu_place, data_size * sizeof(float)));
output_pinned_mem[j] = static_cast<float*>(
paddle::memory::Alloc(cpu_place, data_size * sizeof(float)));
gpu_mem[j] = static_cast<float*>(
paddle::memory::Alloc(cuda_place, data_size * sizeof(float)));
for (int k = 0; k < data_size; ++k) {
input_pinned_mem[j][k] = k;
}
}
cudaEventRecord(start_e, computation_stream);
// computation
for (int m = 0; m < 30; ++m) {
for (int i = 0; i < iteration; ++i) {
// cpu -> GPU on computation stream.
// note: this operation is async for pinned memory.
paddle::memory::Copy(cuda_place, gpu_mem[i], cpu_place,
input_pinned_mem[i], data_size * sizeof(float),
computation_stream);
// call kernel on computation stream.
Kernel<<<4, 1024, 0, computation_stream>>>(gpu_mem[i], data_size);
// record event_computation on computation stream
cudaEventRecord(record_event[i], computation_stream);
// wait event_computation on copy stream.
// note: this operation is async.
cudaStreamWaitEvent(copying_stream, record_event[i], 0);
// copy data GPU->CPU, on copy stream.
// note: this operation is async for pinned memory.
paddle::memory::Copy(cpu_place, output_pinned_mem[i], cuda_place,
gpu_mem[i], data_size * sizeof(float),
copying_stream);
}
}
cudaEventRecord(copying_e, copying_stream);
cudaStreamWaitEvent(computation_stream, copying_e, 0);
cudaEventRecord(stop_e, computation_stream);
cudaEventSynchronize(start_e);
cudaEventSynchronize(stop_e);
cudaEventElapsedTime(&elapsedTime, start_e, stop_e);
// std::cout << cpu_place << " "
// << "time consume:" << elapsedTime / 30 << std::endl;
for (int l = 0; l < iteration; ++l) {
for (int k = 0; k < data_size; ++k) {
float temp = input_pinned_mem[l][k];
temp = temp * temp / 100;
EXPECT_FLOAT_EQ(temp, output_pinned_mem[l][k]);
}
}
// destroy resource
cudaEventDestroy(copying_e);
cudaEventDestroy(start_e);
cudaEventDestroy(stop_e);
for (int j = 0; j < 10; ++j) {
cudaEventDestroy((record_event[j]));
paddle::memory::Free(cpu_place, input_pinned_mem[j]);
paddle::memory::Free(cpu_place, output_pinned_mem[j]);
paddle::memory::Free(cuda_place, gpu_mem[j]);
}
return elapsedTime / 30;
}
TEST(CPUANDCUDAPinned, CPUAllocatorAndCUDAPinnedAllocator) {
// Generally speaking, operation on pinned_memory is faster than that on
// unpinned-memory, but if this unit test fails frequently, please close this
// test for the time being.
float time1 = test_pinned_memory<paddle::platform::CPUPlace>();
float time2 = test_pinned_memory<paddle::platform::CUDAPinnedPlace>();
EXPECT_GT(time1, time2);
}
...@@ -322,6 +322,14 @@ void set_constant_with_place<platform::CPUPlace>( ...@@ -322,6 +322,14 @@ void set_constant_with_place<platform::CPUPlace>(
TensorSetConstantCPU(tensor, value)); TensorSetConstantCPU(tensor, value));
} }
template <>
void set_constant_with_place<platform::CUDAPinnedPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstantCPU(tensor, value));
}
struct TensorSetConstantWithPlace : public boost::static_visitor<void> { struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
TensorSetConstantWithPlace(const platform::DeviceContext& context, TensorSetConstantWithPlace(const platform::DeviceContext& context,
framework::Tensor* tensor, float value) framework::Tensor* tensor, float value)
......
...@@ -27,6 +27,11 @@ DEFINE_double(fraction_of_cpu_memory_to_use, 1, ...@@ -27,6 +27,11 @@ DEFINE_double(fraction_of_cpu_memory_to_use, 1,
"Default use 100% of CPU memory for PaddlePaddle," "Default use 100% of CPU memory for PaddlePaddle,"
"reserve the rest for page tables, etc"); "reserve the rest for page tables, etc");
DEFINE_double(
fraction_of_cuda_pinned_memory_to_use, 0.5,
"Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
"reserve the rest for page tables, etc");
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -62,5 +67,22 @@ size_t CpuMaxChunkSize() { ...@@ -62,5 +67,22 @@ size_t CpuMaxChunkSize() {
return CpuMaxAllocSize() / 32; return CpuMaxAllocSize() / 32;
} }
size_t CUDAPinnedMaxAllocSize() {
// For distributed systems, it requires configuring and limiting
// the fraction of memory to use.
return FLAGS_fraction_of_cuda_pinned_memory_to_use * CpuTotalPhysicalMemory();
}
size_t CUDAPinnedMinChunkSize() {
// Allow to allocate the minimum chunk size is 64 KB.
return 1 << 16;
}
size_t CUDAPinnedMaxChunkSize() {
// Allow to allocate the maximum chunk size is roughly 1/256 of CUDA_PINNED
// memory.
return CUDAPinnedMaxAllocSize() / 256;
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -22,11 +22,20 @@ namespace platform { ...@@ -22,11 +22,20 @@ namespace platform {
//! Get the maximum allocation size for a machine. //! Get the maximum allocation size for a machine.
size_t CpuMaxAllocSize(); size_t CpuMaxAllocSize();
//! Get the maximum allocation size for a machine.
size_t CUDAPinnedMaxAllocSize();
//! Get the minimum chunk size for buddy allocator. //! Get the minimum chunk size for buddy allocator.
size_t CpuMinChunkSize(); size_t CpuMinChunkSize();
//! Get the maximum chunk size for buddy allocator. //! Get the maximum chunk size for buddy allocator.
size_t CpuMaxChunkSize(); size_t CpuMaxChunkSize();
//! Get the minimum chunk size for buddy allocator.
size_t CUDAPinnedMinChunkSize();
//! Get the maximum chunk size for buddy allocator.
size_t CUDAPinnedMaxChunkSize();
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -53,6 +53,16 @@ DeviceContextPool::DeviceContextPool( ...@@ -53,6 +53,16 @@ DeviceContextPool::DeviceContextPool(
PADDLE_THROW( PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU " "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
"option"); "option");
#endif
} else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
device_contexts_.emplace(
p,
PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
#else
PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif #endif
} }
} }
...@@ -186,6 +196,20 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } ...@@ -186,6 +196,20 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; }
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
: place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
return eigen_device_.get();
}
Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
......
...@@ -118,6 +118,25 @@ struct DefaultDeviceContextType<platform::CUDAPlace> { ...@@ -118,6 +118,25 @@ struct DefaultDeviceContextType<platform::CUDAPlace> {
using TYPE = CUDADeviceContext; using TYPE = CUDADeviceContext;
}; };
// Currently, CUDAPinnedDeviceContext is only used to data copying.
class CUDAPinnedDeviceContext : public DeviceContext {
public:
CUDAPinnedDeviceContext();
explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);
Place GetPlace() const override;
Eigen::DefaultDevice* eigen_device() const;
private:
CUDAPinnedPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
using TYPE = CUDAPinnedDeviceContext;
};
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
......
...@@ -26,6 +26,7 @@ class PlacePrinter : public boost::static_visitor<> { ...@@ -26,6 +26,7 @@ class PlacePrinter : public boost::static_visitor<> {
void operator()(const CUDAPlace &p) { void operator()(const CUDAPlace &p) {
os_ << "CUDAPlace(" << p.device << ")"; os_ << "CUDAPlace(" << p.device << ")";
} }
void operator()(const CUDAPinnedPlace &p) { os_ << "CUDAPinnedPlace"; }
private: private:
std::ostream &os_; std::ostream &os_;
...@@ -40,12 +41,19 @@ const Place &get_place() { return the_default_place; } ...@@ -40,12 +41,19 @@ const Place &get_place() { return the_default_place; }
const CUDAPlace default_gpu() { return CUDAPlace(0); } const CUDAPlace default_gpu() { return CUDAPlace(0); }
const CPUPlace default_cpu() { return CPUPlace(); } const CPUPlace default_cpu() { return CPUPlace(); }
const CUDAPinnedPlace default_cuda_pinned() { return CUDAPinnedPlace(); }
bool is_gpu_place(const Place &p) { bool is_gpu_place(const Place &p) {
return boost::apply_visitor(IsCUDAPlace(), p); return boost::apply_visitor(IsCUDAPlace(), p);
} }
bool is_cpu_place(const Place &p) { return !is_gpu_place(p); } bool is_cpu_place(const Place &p) {
return boost::apply_visitor(IsCPUPlace(), p);
}
bool is_cuda_pinned_place(const Place &p) {
return boost::apply_visitor(IsCUDAPinnedPlace(), p);
}
bool places_are_same_class(const Place &p1, const Place &p2) { bool places_are_same_class(const Place &p1, const Place &p2) {
return p1.which() == p2.which(); return p1.which() == p2.which();
...@@ -53,7 +61,7 @@ bool places_are_same_class(const Place &p1, const Place &p2) { ...@@ -53,7 +61,7 @@ bool places_are_same_class(const Place &p1, const Place &p2) {
bool is_same_place(const Place &p1, const Place &p2) { bool is_same_place(const Place &p1, const Place &p2) {
if (places_are_same_class(p1, p2)) { if (places_are_same_class(p1, p2)) {
if (is_cpu_place(p1)) { if (is_cpu_place(p1) || is_cuda_pinned_place(p1)) {
return true; return true;
} else { } else {
return boost::get<CUDAPlace>(p1) == boost::get<CUDAPlace>(p2); return boost::get<CUDAPlace>(p1) == boost::get<CUDAPlace>(p2);
......
...@@ -45,12 +45,33 @@ struct CUDAPlace { ...@@ -45,12 +45,33 @@ struct CUDAPlace {
int device; int device;
}; };
struct CUDAPinnedPlace {
CUDAPinnedPlace() {}
// needed for variant equality comparison
inline bool operator==(const CUDAPinnedPlace &) const { return true; }
inline bool operator!=(const CUDAPinnedPlace &) const { return false; }
};
struct IsCUDAPlace : public boost::static_visitor<bool> { struct IsCUDAPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; } bool operator()(const CPUPlace &) const { return false; }
bool operator()(const CUDAPlace &gpu) const { return true; } bool operator()(const CUDAPlace &gpu) const { return true; }
bool operator()(const CUDAPinnedPlace &) const { return false; }
}; };
typedef boost::variant<CUDAPlace, CPUPlace> Place; struct IsCPUPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &cpu) const { return true; }
bool operator()(const CUDAPlace &) const { return false; }
bool operator()(const CUDAPinnedPlace &) const { return false; }
};
struct IsCUDAPinnedPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; }
bool operator()(const CUDAPlace &) const { return false; }
bool operator()(const CUDAPinnedPlace &cuda_pinned) const { return true; }
};
typedef boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace> Place;
using PlaceList = std::vector<Place>; using PlaceList = std::vector<Place>;
...@@ -59,9 +80,11 @@ const Place &get_place(); ...@@ -59,9 +80,11 @@ const Place &get_place();
const CUDAPlace default_gpu(); const CUDAPlace default_gpu();
const CPUPlace default_cpu(); const CPUPlace default_cpu();
const CUDAPinnedPlace default_cuda_pinned();
bool is_gpu_place(const Place &); bool is_gpu_place(const Place &);
bool is_cpu_place(const Place &); bool is_cpu_place(const Place &);
bool is_cuda_pinned_place(const Place &);
bool places_are_same_class(const Place &, const Place &); bool places_are_same_class(const Place &, const Place &);
bool is_same_place(const Place &, const Place &); bool is_same_place(const Place &, const Place &);
...@@ -95,6 +118,16 @@ struct PlaceVisitorWrapper ...@@ -95,6 +118,16 @@ struct PlaceVisitorWrapper
#else #else
PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda device"); PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda device");
return typename Visitor::result_type(); return typename Visitor::result_type();
#endif
}
typename Visitor::result_type operator()(
const CUDAPinnedPlace &cuda_pinned) const {
#ifdef PADDLE_WITH_CUDA
return visitor_(cuda_pinned);
#else
PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda_pinned");
return typename Visitor::result_type();
#endif #endif
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册