Commit f149d183 authored by Yi Wang

Add system_allocator

Parent f7530e89
paddle/memory/detail/CMakeLists.txt
-cc_test(system_allocator_test SRCS system_allocator_test.cc)
+if(${WITH_GPU})
+  nv_test(system_allocator_test SRCS system_allocator_test.cc)
+else(${WITH_GPU})
+  cc_test(system_allocator_test SRCS system_allocator_test.cc)
+endif(${WITH_GPU})
paddle/memory/detail/system_allocator.h
@@ -23,14 +23,31 @@ limitations under the License. */
 #include <thrust/system_error.h>
 #endif  // PADDLE_ONLY_CPU

+#include "paddle/platform/assert.h"
+
 namespace paddle {
 namespace memory {
 namespace detail {

-class SystemAllocator {
+class CPUDeleter {
  public:
-  virtual void* Alloc(size_t size) = 0;
-  virtual void* Free(void* p) = 0;
+  CPUDeleter(void* ptr, size_t size, bool locked)
+      : ptr_(ptr), size_(size), locked_(locked) {}
+
+  void* Ptr() { return ptr_; }
+
+  void operator()(void* ptr) {
+    PADDLE_ASSERT(ptr == ptr_);
+    if (ptr_ != nullptr && locked_) {
+      munlock(ptr_, size_);
+    }
+    std::free(ptr_);
+  }
+
+ private:
+  void* ptr_;
+  size_t size_;
+  bool locked_;
 };

 // CPUAllocator<lock_memory=true> calls mlock, which returns pinned
@@ -39,21 +56,14 @@ class SystemAllocator {
 // available to the system for paging. So, by default, we should use
 // CPUAllocator<lock_memory=false>.
 template <bool lock_memory>
-class CPUAllocator : public SystemAllocator {
+class CPUAllocator {
  public:
-  virtual void* Alloc(size_t size) {
+  static CPUDeleter Alloc(size_t size) {
     void* p = std::malloc(size);
     if (p != nullptr && lock_memory) {
       mlock(p, size);
     }
-    return p;
-  }
-
-  virtual void Free(void* p, size_t size) {
-    if (p != nullptr && lock_memory) {
-      munlock(p, size);
-    }
-    std::free(p);
+    return CPUDeleter(p, size, lock_memory);
   }
 };
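Note how the commit replaces the virtual Alloc/Free pair with a static Alloc that returns the deleter itself: the CPUDeleter remembers the pointer, the size, and whether the memory was mlock-ed, so a caller can hand both to std::unique_ptr and can never mismatch the arguments of the free call. A minimal usage sketch (illustrative, not part of the commit; the function name Example is made up):

    #include <memory>

    using paddle::memory::detail::CPUAllocator;
    using paddle::memory::detail::CPUDeleter;

    void Example() {
      // Alloc returns the deleter; the raw pointer lives inside it.
      CPUDeleter d = CPUAllocator</*lock_memory=*/true>::Alloc(4096);
      // unique_ptr copies the deleter; when p goes out of scope, the
      // deleter munlocks (because lock_memory was true) and std::free()s.
      std::unique_ptr<char, CPUDeleter> p(static_cast<char*>(d.Ptr()), d);
    }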
@@ -67,6 +77,32 @@ inline void throw_on_error(cudaError_t e, const char* message) {
 }
 }  // namespace

+class GPUDeleter {
+ public:
+  GPUDeleter(void* ptr, size_t size, bool staging)
+      : ptr_(ptr), size_(size), staging_(staging) {}
+
+  void* Ptr() { return ptr_; }
+
+  void operator()(void* ptr) {
+    PADDLE_ASSERT(ptr == ptr_);
+    // Purposefully allow cudaErrorCudartUnloading, because
+    // that is returned if you ever call cudaFree after the
+    // driver has already shut down. This happens only if the
+    // process is terminating, in which case we don't care if
+    // cudaFree succeeds.
+    cudaError_t err = staging_ ? cudaFreeHost(ptr) : cudaFree(ptr);
+    if (err != cudaErrorCudartUnloading) {
+      throw_on_error(err, "cudaFree{Host} failed");
+    }
+  }
+
+ private:
+  void* ptr_;
+  size_t size_;
+  bool staging_;
+};
+
 // GPUAllocator<staging=true> calls cudaMallocHost, which returns
 // pinned and locked memory as staging areas for data exchange
 // between host and device. Allocating too much would reduce the
@@ -75,28 +111,14 @@ inline void throw_on_error(cudaError_t e, const char* message) {
 template <bool staging>
 class GPUAllocator {
  public:
-  void* Alloc(size_t size) {
+  static GPUDeleter Alloc(size_t size) {
     void* p = 0;
     cudaError_t result =
         staging ? cudaMallocHost(&p, size) : cudaMalloc(&p, size);
-    if (result == cudaSuccess) {
-      return p;
-    }
-    // clear last error
-    cudaGetLastError();
-    return nullptr;
-  }
-
-  void Free(void* p, size_t size) {
-    // Purposefully allow cudaErrorCudartUnloading, because
-    // that is returned if you ever call cudaFree after the
-    // driver has already shutdown. This happens only if the
-    // process is terminating, in which case we don't care if
-    // cudaFree succeeds.
-    auto err = staging ? cudaFreeHost(p) : cudaFree(p);
-    if (err != cudaErrorCudartUnloading) {
-      throw_on_error(err, "cudaFree failed");
+    if (result != cudaSuccess) {
+      cudaGetLastError();  // clear error if there is any.
     }
+    return GPUDeleter(result == cudaSuccess ? p : nullptr, size, staging);
   }
 };
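On the GPU side, Alloc no longer returns nullptr on failure; it clears CUDA's sticky error state and hands back a GPUDeleter holding nullptr, so success is tested through Ptr(). A sketch of that failure path (illustrative, not part of the commit; TryAlloc is a made-up name, and cudaFree(nullptr) is a documented no-op that returns cudaSuccess):

    using paddle::memory::detail::GPUAllocator;

    bool TryAlloc(size_t bytes) {
      auto d = GPUAllocator</*staging=*/false>::Alloc(bytes);
      bool ok = (d.Ptr() != nullptr);
      // Running the deleter is safe either way: for a null pointer,
      // cudaFree(nullptr) succeeds and throw_on_error does not throw.
      d(d.Ptr());
      return ok;
    }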
paddle/memory/detail/system_allocator_test.cc
@@ -13,36 +13,38 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/memory/detail/system_allocator.h"
+
+#include <memory>
+#include <vector>
+
 #include "gtest/gtest.h"

-TEST(CPUAllocator, NoLockMem) {
-  paddle::memory::detail::CPUAllocator<false> a;
-  void* p = a.Alloc(4096);
-  EXPECT_NE(p, nullptr);
-  a.Free(p, 4096);
+template <typename Allocator>
+void TestAllocator() {
+  {
+    auto d = Allocator::Alloc(sizeof(int));
+    EXPECT_NE(d.Ptr(), nullptr);
+    std::unique_ptr<int, decltype(d)> p(static_cast<int*>(d.Ptr()), d);
+  }
+  {
+    auto d = Allocator::Alloc(0);
+    EXPECT_EQ(d.Ptr(), nullptr);
+    std::unique_ptr<int, decltype(d)> p(static_cast<int*>(d.Ptr()), d);
+  }
+}
+
+TEST(CPUAllocator, NoLockMem) {
+  TestAllocator<paddle::memory::detail::CPUAllocator<false>>();
 }

 TEST(CPUAllocator, LockMem) {
-  paddle::memory::detail::CPUAllocator<true> a;
-  void* p = a.Alloc(4096);
-  EXPECT_NE(p, nullptr);
-  a.Free(p, 4096);
+  TestAllocator<paddle::memory::detail::CPUAllocator<true>>();
 }

 #ifndef PADDLE_ONLY_CPU
-TEST(GPUAllocator, NonStaging) {
-  paddle::memory::detail::GPUAllocator<false> a;
-  void* p = a.Alloc(4096);
-  EXPECT_NE(p, nullptr);
-  a.Free(p, 4096);
+TEST(GPUAllocator, NoStaging) {
+  TestAllocator<paddle::memory::detail::GPUAllocator<false>>();
 }

 TEST(GPUAllocator, Staging) {
-  paddle::memory::detail::GPUAllocator<true> a;
-  void* p = a.Alloc(4096);
-  EXPECT_NE(p, nullptr);
-  a.Free(p, 4096);
+  TestAllocator<paddle::memory::detail::GPUAllocator<true>>();
 }
 #endif  // PADDLE_ONLY_CPU
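Putting the two allocators together: as the comment in the header explains, GPUAllocator<staging=true> returns pinned, page-locked host memory meant as a staging area for host-device transfers. A hedged sketch of that pattern (illustrative only; CopyToDevice is a made-up name, and it assumes the CUDA runtime declarations arrive transitively via system_allocator.h):

    #include <cstring>
    #include <memory>

    #include "paddle/memory/detail/system_allocator.h"

    using paddle::memory::detail::GPUAllocator;
    using paddle::memory::detail::GPUDeleter;

    void CopyToDevice(const float* src, size_t n) {
      const size_t bytes = n * sizeof(float);

      // Pinned (page-locked) host buffer: a DMA-friendly staging area.
      GPUDeleter h = GPUAllocator</*staging=*/true>::Alloc(bytes);
      // Plain device buffer.
      GPUDeleter g = GPUAllocator</*staging=*/false>::Alloc(bytes);

      std::unique_ptr<float, GPUDeleter> host(static_cast<float*>(h.Ptr()), h);
      std::unique_ptr<float, GPUDeleter> dev(static_cast<float*>(g.Ptr()), g);
      if (!host || !dev) return;  // allocation failed; nothing to clean up.

      std::memcpy(host.get(), src, bytes);  // host -> pinned staging buffer
      cudaMemcpy(dev.get(), host.get(), bytes,
                 cudaMemcpyHostToDevice);   // staging buffer -> device
    }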