Potential memcpy and malloc errors on CUDA
Created by: reyoung
According to the CUDA manual, `cudaMallocHost` is the recommended API for allocating CPU memory used for data exchange between GPU and CPU:
> Allocates size bytes of host memory that is page-locked and accessible to the device. The driver tracks the virtual memory ranges allocated with this function and automatically accelerates calls to functions such as `cudaMemcpy*()`. Since the memory can be accessed directly by the device, it can be read or written with much higher bandwidth than pageable memory obtained with functions such as `malloc()`. Allocating excessive amounts of memory with `cudaMallocHost()` may degrade system performance, since it reduces the amount of memory available to the system for paging. As a result, this function is best used sparingly to allocate staging areas for data exchange between host and device.
We should use `cudaMallocHost` for `CPUAllocator` because:
- We use `BuddyAllocator` to wrap the device allocator in the memory module. Even though `cudaMallocHost` is much slower than `malloc`, we would invoke `cudaMallocHost` only once to allocate a large memory buffer and let the buddy allocator manage all subsequent allocations.
- All CPU memory of tensors is already `mlock`ed to physical memory. This is our current memory strategy, and introducing `cudaMallocHost` will not change it: https://github.com/PaddlePaddle/Paddle/blob/9a8be9dacac4e1286482e3e0289985a19382beb9/paddle/fluid/memory/detail/system_allocator.cc#L59-L60
However, when we changed `malloc` to `cudaMallocHost`, we found several bugs in Fluid. The PR is #10136.
- There are many places in our code that invoke `cudaMemcpyAsync` but never actually wait or synchronize on the copy.
- When using `malloc` memory with `cudaMemcpyAsync`, the CUDA driver first copies the CPU memory into its own staging area and then performs the copy from that staging area. This staging copy is synchronous on the CPU and slow, so the missing waits did not show up until we switched to `cudaMallocHost`. If we want a truly asynchronous memcpy, we should not use `malloc` memory for data exchange.
- In our experiment, when we invoked `cudaMemcpyAsync` on a tiny data buffer and called `cudaStreamSynchronize()` immediately afterwards, the `memcpy` was sometimes not synchronized in a multi-thread, multi-stream configuration. This is strange, because `cudaStreamSynchronize()` should block until the copy has finished. However, since using `malloc` memory for data exchange is not best practice anyway, and the result is correct when using `cudaMallocHost`, we will use `cudaMallocHost` for `CPUAllocator` to avoid this bug.