Commit 79373dab authored by liaogang

TEST: Add test for system allocator and deleter

Parent b22dd128
@@ -18,107 +18,69 @@ limitations under the License. */
 #include <sys/mman.h>  // for mlock and munlock
 #include <cstdlib>     // for malloc and free
 
-#ifndef PADDLE_ONLY_CPU
-#include <thrust/system/cuda/error.h>
-#include <thrust/system_error.h>
-#endif  // PADDLE_ONLY_CPU
+#include <gflags/gflags.h>
+
+#include "paddle/platform/assert.h"
+#include "paddle/platform/cuda.h"
+
+DEFINE_bool(uses_pinned_memory, false,
+            "If set, allocate cpu/gpu pinned memory.");
 
 namespace paddle {
 namespace memory {
 namespace detail {
 
-class CPUDeleter {
- public:
-  CPUDeleter(void* ptr, size_t size, bool locked)
-      : ptr_(ptr), size_(size), locked_(locked) {}
-
-  void* Ptr() { return ptr_; }
-
-  void operator()(void* ptr) {
-    PADDLE_ASSERT(ptr == ptr_);
-    if (ptr_ != nullptr && locked_) {
-      munlock(ptr_, size_);
-    }
-    std::free(ptr_);
-  }
-
- private:
-  void* ptr_;
-  size_t size_;
-  bool locked_;
-};
-
-// CPUAllocator<lock_memory=true> calls mlock, which returns pinned
-// and locked memory as staging areas for data exchange between host
-// and device.  Allocating too much would reduce the amount of memory
-// available to the system for paging.  So, by default, we should use
-// CPUAllocator<lock_memory=false>.
-template <bool lock_memory>
+// If uses_pinned_memory is true, CPUAllocator calls mlock, which
+// returns pinned and locked memory as staging areas for data exchange
+// between host and device.  Allocating too much pinned memory would
+// reduce the amount of memory available to the system for paging, so
+// uses_pinned_memory defaults to false.
 class CPUAllocator {
  public:
-  static CPUDeleter Alloc(size_t size) {
+  static void* Alloc(size_t size) {
     void* p = std::malloc(size);
-    if (p != nullptr && lock_memory) {
+    if (p != nullptr && FLAGS_uses_pinned_memory) {
       mlock(p, size);
     }
-    return CPUDeleter(p, size, lock_memory);
+    return p;
+  }
+
+  static void Free(void* p, size_t size) {
+    if (p != nullptr && FLAGS_uses_pinned_memory) {
+      munlock(p, size);
+    }
+    std::free(p);
   }
 };
 
 #ifndef PADDLE_ONLY_CPU  // The following code is for CUDA.
 
-namespace {
-inline void throw_on_error(cudaError_t e, const char* message) {
-  if (e) {
-    throw thrust::system_error(e, thrust::cuda_category(), message);
-  }
-}
-}  // namespace
-
-class GPUDeleter {
- public:
-  GPUDeleter(void* ptr, size_t size, bool staging)
-      : ptr_(ptr), size_(size), staging_(staging) {}
-
-  void* Ptr() { return ptr_; }
-
-  void operator()(void* ptr) {
-    PADDLE_ASSERT(ptr == ptr_);
-    // Purposefully allow cudaErrorCudartUnloading, because
-    // that is returned if you ever call cudaFree after the
-    // driver has already shut down.  This happens only if the
-    // process is terminating, in which case we don't care if
-    // cudaFree succeeds.
-    cudaError_t err = staging_ ? cudaFreeHost(ptr) : cudaFree(ptr);
-    if (err != cudaErrorCudartUnloading) {
-      throw_on_error(err, "cudaFree{Host} failed");
-    }
-  }
-
- private:
-  void* ptr_;
-  size_t size_;
-  bool staging_;
-};
-
-// GPUAllocator<staging=true> calls cudaMallocHost, which returns
-// pinned and locked memory as staging areas for data exchange
-// between host and device.  Allocating too much would reduce the
-// amount of memory available to the system for paging.  So, by
-// default, we should use GPUAllocator<staging=false>.
-template <bool staging>
+// If uses_pinned_memory is true, GPUAllocator calls cudaMallocHost,
+// which returns pinned and locked memory as staging areas for data
+// exchange between host and device.  Allocating too much pinned memory
+// would reduce the amount of memory available to the system for paging,
+// so uses_pinned_memory defaults to false.
 class GPUAllocator {
  public:
-  static GPUDeleter Alloc(size_t size) {
+  static void* Alloc(size_t size) {
     void* p = 0;
-    cudaError_t result =
-        staging ? cudaMallocHost(&p, size) : cudaMalloc(&p, size);
+    cudaError_t result = FLAGS_uses_pinned_memory ? cudaMallocHost(&p, size)
+                                                  : cudaMalloc(&p, size);
     if (result != cudaSuccess) {
       cudaGetLastError();  // clear error if there is any.
     }
-    return GPUDeleter(result == cudaSuccess ? p : nullptr, size, staging);
+    return result == cudaSuccess ? p : nullptr;
+  }
+
+  static void Free(void* p, size_t size) {
+    // Purposefully allow cudaErrorCudartUnloading, because that is
+    // returned if you ever call cudaFree after the driver has already
+    // shut down.  This happens only if the process is terminating, in
+    // which case we don't care whether cudaFree succeeds.
+    cudaError_t err = FLAGS_uses_pinned_memory ? cudaFreeHost(p) : cudaFree(p);
+    if (err != cudaErrorCudartUnloading) {
+      platform::throw_on_error(err, "cudaFree{Host} failed");
+    }
   }
 };
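With the CPUDeleter/GPUDeleter classes removed, Alloc no longer hands back a deleter object; callers must remember the allocation size themselves and pass it to the matching static Free. A minimal caller-side sketch of how the new API can be paired with standard RAII (hypothetical helper, not part of this commit; the MakeBuffer name and the header path are assumptions):

#include <cstddef>     // for size_t
#include <functional>  // for std::function
#include <memory>      // for std::unique_ptr

#include "paddle/memory/detail/system_allocator.h"  // assumed header path

// Hypothetical wrapper: bind the allocation size into a custom deleter,
// restoring the scope-based cleanup that CPUDeleter/GPUDeleter provided.
template <typename Allocator>
std::unique_ptr<void, std::function<void(void*)>> MakeBuffer(size_t size) {
  void* p = Allocator::Alloc(size);
  return {p, [size](void* q) { Allocator::Free(q, size); }};
}

// Usage: a 1 KiB block of (possibly mlock-ed) CPU memory, freed on scope exit.
// auto buf = MakeBuffer<paddle::memory::detail::CPUAllocator>(1024);

Note that unique_ptr never invokes its deleter on a null pointer, so a failed Alloc is harmless here, and Free itself also tolerates nullptr.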
@@ -17,34 +17,44 @@ limitations under the License. */
 #include <memory>
 #include <vector>
 
+#include "glog/logging.h"
 #include "gtest/gtest.h"
 
 template <typename Allocator>
-void TestAllocator() {
-  {
-    auto d = Allocator::Alloc(sizeof(int));
-    EXPECT_NE(d.Ptr(), nullptr);
-    std::unique_ptr<int> p(static_cast<int*>(d.Ptr()), d);
-  }
-  {
-    auto d = Allocator::Alloc(0);
-    EXPECT_EQ(d.Ptr(), nullptr);
-    std::unique_ptr<int> p(static_cast<int*>(d.Ptr()), d);
-  }
+void TestAllocator(void* p) {
+  p = Allocator::Alloc(1024);
+
+  int* i = static_cast<int*>(p);
+  std::shared_ptr<int> ptr(i, [](int* p) { Allocator::Free(p, 1024); });
+
+  EXPECT_NE(p, nullptr);
 }
 
 TEST(CPUAllocator, NoLockMem) {
-  TestAllocator<paddle::memory::detail::CPUAllocator<false>>();
+  void* p = nullptr;
+  FLAGS_uses_pinned_memory = false;
+  TestAllocator<paddle::memory::detail::CPUAllocator>(p);
+  EXPECT_EQ(p, nullptr);
 }
 
 TEST(CPUAllocator, LockMem) {
-  TestAllocator<paddle::memory::detail::CPUAllocator<true>>();
+  void* p = nullptr;
+  FLAGS_uses_pinned_memory = true;
+  TestAllocator<paddle::memory::detail::CPUAllocator>(p);
+  EXPECT_EQ(p, nullptr);
 }
 
 #ifndef PADDLE_ONLY_CPU
 TEST(GPUAllocator, NoStaging) {
-  TestAllocator<paddle::memory::detail::GPUAllocator<false>>();
+  void* p = nullptr;
+  FLAGS_uses_pinned_memory = false;
+  TestAllocator<paddle::memory::detail::GPUAllocator>(p);
+  EXPECT_EQ(p, nullptr);
 }
 
 TEST(GPUAllocator, Staging) {
-  TestAllocator<paddle::memory::detail::GPUAllocator<true>>();
+  void* p = nullptr;
+  FLAGS_uses_pinned_memory = true;
+  TestAllocator<paddle::memory::detail::GPUAllocator>(p);
+  EXPECT_EQ(p, nullptr);
 }
 #endif  // PADDLE_ONLY_CPU
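One subtlety in the new test: TestAllocator receives p by value, so the assignment inside the function never reaches the caller's variable; the trailing EXPECT_EQ(p, nullptr) in each TEST asserts exactly that the caller's pointer stays null. A minimal standalone illustration of this pass-by-value behavior (hypothetical code, not from the commit):

#include <cassert>

// Reassigning a pointer parameter changes only the callee's local copy.
void Reassign(void* p) {
  int dummy = 0;
  p = &dummy;  // visible only inside Reassign
}

int main() {
  void* p = nullptr;
  Reassign(p);
  assert(p == nullptr);  // the caller's p is untouched, as EXPECT_EQ verifies
  return 0;
}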