提交 26cd0bb5 编写于 作者: L liaogang

ENH: count allocated fallback size for performance

上级 464886bf
......@@ -39,22 +39,22 @@ void* CPUAllocator::Alloc(size_t& index, size_t size) {
// pointer shall not be dereferenced -- so we make it nullptr.
if (size <= 0) return nullptr;
if (FLAGS_use_pinned_memory) {
index = 0; // unlock memory
void* p = malloc(size);
if (p != nullptr) {
mlock(p, size);
if (FLAGS_use_pinned_memory) {
index = 1;
mlock(p, size); // lock memory
}
}
void* p = malloc(size);
if (p != nullptr && FLAGS_use_pinned_memory) {
mlock(p, size);
}
return p;
}
void CPUAllocator::Free(void* p, size_t size, size_t index) {
if (p != nullptr && FLAGS_use_pinned_memory) {
if (p != nullptr && index == 1) {
munlock(p, size);
}
free(p);
......@@ -73,26 +73,34 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
// Reserve memory for page tables, etc.
size_t reserving = capacity - paddle::platform::GpuMaxAllocSize();
size_t remaining = available > reserving ? available - reserving : 0;
size_t usable = available > reserving ? available - reserving : 0;
// If remaining size no less than expected size, using general
// cudaMalloc to allocate GPU memory.
void* p = 0;
if (size <= remaining) {
if (size <= usable) {
cudaError_t result = cudaMalloc(&p, size);
if (result == cudaSuccess) {
index = 0;
total_alloc_size_ += size;
gpu_alloc_size_ += size;
return p;
}
}
// If remaining size less than expected size or cudaMalloc failed,
// cudaMallocHost will be considered as a fallback allocator.
//
// NOTE: here, we use GpuMaxAllocSize() as the maximum memory size
// of host fallback allocation. Allocates too much would reduce
// the amount of memory available to the underlying system for paging.
usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_;
if (size > usable) return nullptr;
cudaError_t result = cudaMallocHost(&p, size);
if (result == cudaSuccess) {
index = 1;
total_alloc_size_ += size;
fallback_alloc_size_ += size;
return p;
}
......@@ -100,16 +108,26 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
}
// Frees memory previously handed out by GPUAllocator::Alloc. `index`
// selects the allocation kind recorded at Alloc time: 0 = device memory
// (cudaMalloc), 1 = pinned-host fallback memory (cudaMallocHost).
//
// NOTE(review): this span is a diff-view capture that interleaves the
// pre-change and post-change versions of this function without +/-
// markers; the lines flagged below as "removed in this commit" belong
// to the old version and would not compile alongside the new ones
// (`err` is declared twice). Reconcile against the actual repository
// before treating this as buildable source.
void GPUAllocator::Free(void* p, size_t size, size_t index) {
cudaError_t err;
// New accounting: decrement the per-kind counter and assert it never
// underflows before releasing the memory with the matching CUDA API.
if (index == 0) {
PADDLE_ASSERT(gpu_alloc_size_ >= size);
gpu_alloc_size_ -= size;
err = cudaFree(p);
} else {
PADDLE_ASSERT(fallback_alloc_size_ >= size);
fallback_alloc_size_ -= size;
err = cudaFreeHost(p);
}
// Purposefully allow cudaErrorCudartUnloading, because
// that is returned if you ever call cudaFree after the
// driver has already shutdown. This happens only if the
// process is terminating, in which case we don't care if
// cudaFree succeeds.
// NOTE(review): the next three statements are the OLD single-counter
// accounting/free path (removed in this commit; superseded by the
// if/else above).
PADDLE_ASSERT(total_alloc_size_ >= size);
total_alloc_size_ -= size;
cudaError_t err = index == 1 ? cudaFreeHost(p) : cudaFree(p);
if (err != cudaErrorCudartUnloading) {
// NOTE(review): the first throw_on_error call is the OLD message
// (removed); the two-line call below is the NEW replacement.
platform::throw_on_error(err, "cudaFree{Host} failed");
platform::throw_on_error(err,
"cudaFree{Host} failed in GPUAllocator::Free.");
}
}
......
......@@ -47,7 +47,8 @@ class GPUAllocator : public SystemAllocator {
virtual void Free(void* p, size_t size, size_t index);
private:
size_t total_alloc_size_ = 0;
size_t gpu_alloc_size_ = 0;
size_t fallback_alloc_size_ = 0;
};
#endif // PADDLE_ONLY_CPU
......
......@@ -1381,7 +1381,7 @@ def inputs(layers, *args):
if len(args) != 0:
layers.extend(args)
Inputs(* [l.name for l in layers])
Inputs(*[l.name for l in layers])
def outputs(layers, *args):
......@@ -1424,7 +1424,7 @@ def outputs(layers, *args):
assert len(layers) > 0
if HasInputsSet(): # input already set
Outputs(* [l.name for l in layers])
Outputs(*[l.name for l in layers])
return # just return outputs.
if len(layers) != 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册