提交 55648b4f 编写于 作者: L liaogang

Merge remote-tracking branch 'wangkuiyi/memory_cpu_allocator' into cpu_mem

if(${WITH_GPU})
nv_test(system_allocator_test SRCS system_allocator_test.cc DEPS gflags glog)
nv_library(system_allocator SRCS system_allocator.cc DEPS gflags)
nv_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags)
else(${WITH_GPU})
cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS gflags glog)
cc_library(system_allocator SRCS system_allocator.cc DEPS gflags)
cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags)
endif(${WITH_GPU})
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/detail/system_allocator.h"
namespace paddle {
namespace memory {
namespace detail {
BuddyAllocator::BuddyAllocator(size_t pool_size, size_t max_pools,
SystemAllocator* system_allocator)
: pool_size_(pool_size),
max_pools_(max_pools),
system_allocator_(system_allocator) {
PADDLE_ASSERT(pool_size > 0);
PADDLE_ASSERT(max_pools > 0);
PADDLE_ASSERT(system_allocator != nullptr);
}
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
......@@ -20,34 +20,38 @@ namespace paddle {
namespace memory {
namespace detail {
template<typename Allocator>
class BuddyAllocator {
public:
// TODO(gangliao): This is a draft, add Buddy Allocator Algorithm soon
BuddyAllocator() {}
~BuddyAllocator() {}
BuddyAllocator(size_t pool_size, size_t max_pools,
SystemAllocator* system_allocator);
~BuddyAllocator();
public:
void* Alloc(size_t size) {
return Allocator::Alloc(size);
}
void Free(void*) {
// Because all info like size are stored in meta data,
// thus it's duplicate if add the parameter `size` in
// `Free(void*)` interface.
}
void* Alloc(size_t size);
void Free(void*);
size_t Used();
public:
BuddyAllocator(const BuddyAllocator&) = delete;
BuddyAllocator& operator=(const BuddyAllocator&) = delete;
private:
size_t min_alloc_size_;
size_t max_alloc_size_;
struct Block {
size_t size_;
Block* left_; // left buddy
Block* right_; // right buddy
};
// Initially, there is only one pool. If a Alloc founds not enough
// memory from that pool, and there has not been max_num_pools_,
// create a new pool by calling system_allocator_.Alloc(pool_size_).
std::vector<void*> pools_;
size_t pool_size_; // the size of each pool;
size_t max_num_pools_; // the size of all pools;
SystemAllocator* system_allocator_;
private:
std::mutex mutex_;
// Disable copy and assignment.
BuddyAllocator(const BuddyAllocator&) = delete;
BuddyAllocator& operator=(const BuddyAllocator&) = delete;
};
BuddyAllocator<CPUAllocator>* GetCPUBuddyAllocator() {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/detail/system_allocator.h"
#include <stdlib.h> // for malloc and free
#include <sys/mman.h> // for mlock and munlock
#include "gflags/gflags.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda.h"
// If use_pinned_memory is true, CPUAllocator calls mlock, which
// returns pinned and locked memory as staging areas for data exchange
// between host and device. Allocates too much would reduce the amount
// of memory available to the system for paging. So, by default, we
// should set false to use_pinned_memory.
DEFINE_bool(use_pinned_memory, false,
"If set, allocate cpu/gpu pinned memory.");
namespace paddle {
namespace memory {
namespace detail {
void* CPUAllocator::Alloc(size_t size) {
// According to http://www.cplusplus.com/reference/cstdlib/malloc/,
// malloc might not return nullptr if size is zero, but the returned
// pointer shall not be dereferenced -- so we make it nullptr.
if (size <= 0) return nullptr;
void* p = malloc(size);
if (p != nullptr && FLAGS_use_pinned_memory) {
mlock(p, size);
}
return p;
}
void CPUAllocator::Free(void* p, size_t size) {
if (p != nullptr && FLAGS_use_pinned_memory) {
munlock(p, size);
}
free(p);
}
#ifndef PADDLE_ONLY_CPU
void* GPUAllocator::Alloc(size_t size) {
// CUDA documentation doesn't explain if cudaMalloc returns nullptr
// if size is 0. We just make sure it does.
if (size <= 0) {
return nullptr;
}
void* p = 0;
cudaError_t result =
FLAGS_use_pinned_memory ? cudaMallocHost(&p, size) : cudaMalloc(&p, size);
if (result != cudaSuccess) {
cudaGetLastError(); // clear error if there is any.
}
return result == cudaSuccess ? p : nullptr;
}
void GPUAllocator::Free(void* p, size_t size) {
// Purposefully allow cudaErrorCudartUnloading, because
// that is returned if you ever call cudaFree after the
// driver has already shutdown. This happens only if the
// process is terminating, in which case we don't care if
// cudaFree succeeds.
cudaError_t err = FLAGS_use_pinned_memory ? cudaFreeHost(p) : cudaFree(p);
if (err != cudaErrorCudartUnloading) {
platform::throw_on_error(err, "cudaFree{Host} failed");
}
}
#endif // PADDLE_ONLY_CPU
} // namespace detail
} // namespace memory
} // namespace paddle
......@@ -15,75 +15,37 @@ limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <sys/mman.h> // for mlock and munlock
#include <cstdlib> // for malloc and free
#include <gflags/gflags.h>
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda.h"
DEFINE_bool(uses_pinned_memory, false,
"If set, allocate cpu/gpu pinned memory.");
namespace paddle {
namespace memory {
namespace detail {
// If uses_pinned_memory is true, CPUAllocator calls mlock, which
// returns pinned and locked memory as staging areas for data exchange
// between host and device. Allocates too much would reduce the amount
// of memory available to the system for paging. So, by default, we
// should set false to uses_pinned_memory.
class CPUAllocator {
// SystemAllocator is the parent class of CPUAllocator and
// GPUAllocator. A BuddyAllocator object uses a SystemAllocator*
// pointing to the underlying system allocator. An alternative to
// this class hierarchy is to pass a system allocator class to
// BuddyAllocator as a template parameter. This approach makes
// BuddyAllocator a class template, and it's very complicated
// algorithm would make the buddy_allocator.h messy.
class SystemAllocator {
public:
static void* Alloc(size_t size) {
void* p = std::malloc(size);
if (p != nullptr && FLAGS_uses_pinned_memory) {
mlock(p, size);
}
return p;
}
static void Free(void* p, size_t size) {
if (p != nullptr && FLAGS_uses_pinned_memory) {
munlock(p, size);
}
std::free(p);
}
virtual ~SystemAllocator() {}
virtual void* Alloc(size_t size) = 0;
virtual void Free(void* p, size_t size) = 0;
};
#ifndef PADDLE_ONLY_CPU // The following code are for CUDA.
// GPUAllocator<staging=true> calls cudaHostMalloc, which returns
// pinned and locked memory as staging areas for data exchange
// between host and device. Allocates too much would reduce the
// amount of memory available to the system for paging. So, by
// default, we should use GPUAllocator<staging=false>.
class GPUAllocator {
class CPUAllocator : public SystemAllocator {
public:
static void* Alloc(size_t size) {
void* p = 0;
cudaError_t result = FLAGS_uses_pinned_memory ? cudaMallocHost(&p, size)
: cudaMalloc(&p, size);
if (result != cudaSuccess) {
cudaGetLastError(); // clear error if there is any.
}
return result == cudaSuccess ? p : nullptr;
}
static void Free(void* p, size_t size) {
// Purposefully allow cudaErrorCudartUnloading, because
// that is returned if you ever call cudaFree after the
// driver has already shutdown. This happens only if the
// process is terminating, in which case we don't care if
// cudaFree succeeds.
cudaError_t err = FLAGS_uses_pinned_memory ? cudaFreeHost(p) : cudaFree(p);
if (err != cudaErrorCudartUnloading) {
platform::throw_on_error(err, "cudaFree{Host} failed");
}
}
virtual void* Alloc(size_t size);
virtual void Free(void* p, size_t size);
};
#ifndef PADDLE_ONLY_CPU
class GPUAllocator : public SystemAllocator {
public:
virtual void* Alloc(size_t size);
virtual void Free(void* p, size_t size);
};
#endif // PADDLE_ONLY_CPU
} // namespace detail
......
......@@ -17,44 +17,55 @@ limitations under the License. */
#include <memory>
#include <vector>
#include "glog/logging.h"
#include "gflags/gflags.h"
#include "gtest/gtest.h"
template <typename Allocator>
void TestAllocator(void* p) {
p = Allocator::Alloc(1024);
int* i = static_cast<int*>(p);
std::shared_ptr<int> ptr(i, [](int* p) { Allocator::Free(p, 1024); });
DECLARE_bool(use_pinned_memory);
void TestAllocator(paddle::memory::detail::SystemAllocator& a, size_t size) {
bool freed = false;
{
void* p = a.Alloc(size);
if (size > 0) {
EXPECT_NE(p, nullptr);
} else {
EXPECT_EQ(p, nullptr);
}
int* i = static_cast<int*>(p);
std::shared_ptr<int> ptr(i, [&](void* p) {
freed = true;
a.Free(p, size);
});
}
EXPECT_TRUE(freed);
}
TEST(CPUAllocator, NoLockMem) {
void* p = nullptr;
FLAGS_uses_pinned_memory = false;
TestAllocator<paddle::memory::detail::CPUAllocator>(p);
EXPECT_EQ(p, nullptr);
FLAGS_use_pinned_memory = false;
paddle::memory::detail::CPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
TEST(CPUAllocator, LockMem) {
void* p = nullptr;
FLAGS_uses_pinned_memory = true;
TestAllocator<paddle::memory::detail::CPUAllocator>(p);
EXPECT_EQ(p, nullptr);
FLAGS_use_pinned_memory = true;
paddle::memory::detail::CPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#ifndef PADDLE_ONLY_CPU
TEST(GPUAllocator, NoStaging) {
void* p = nullptr;
FLAGS_uses_pinned_memory = false;
TestAllocator<paddle::memory::detail::GPUAllocator>(p);
EXPECT_EQ(p, nullptr);
FLAGS_use_pinned_memory = false;
paddle::memory::detail::GPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
TEST(GPUAllocator, Staging) {
void* p = nullptr;
FLAGS_uses_pinned_memory = true;
TestAllocator<paddle::memory::detail::GPUAllocator>(p);
EXPECT_EQ(p, nullptr);
FLAGS_use_pinned_memory = true;
paddle::memory::detail::GPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#endif // PADDLE_ONLY_CPU
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册