提交 bc921927 编写于 作者: D Dun Liang

Fix Pr #15296

test=develop
上级 db8fcf6b
...@@ -257,7 +257,7 @@ void *Alloc<platform::CUDAPinnedPlace>(const platform::CUDAPinnedPlace &place, ...@@ -257,7 +257,7 @@ void *Alloc<platform::CUDAPinnedPlace>(const platform::CUDAPinnedPlace &place,
void *ptr = buddy_allocator->Alloc(size); void *ptr = buddy_allocator->Alloc(size);
if (ptr == nullptr) { if (ptr == nullptr) {
LOG(WARNING) << "cudaMallocHost Cannot allocate " << size LOG(WARNING) << "cudaHostAlloc Cannot allocate " << size
<< " bytes in CUDAPinnedPlace"; << " bytes in CUDAPinnedPlace";
} }
if (FLAGS_init_allocated_mem) { if (FLAGS_init_allocated_mem) {
......
...@@ -32,7 +32,7 @@ Allocation *CPUPinnedAllocator::AllocateImpl(size_t size, ...@@ -32,7 +32,7 @@ Allocation *CPUPinnedAllocator::AllocateImpl(size_t size,
// "CPUPinnedAllocator should be used for Cross-Device Communication"); // "CPUPinnedAllocator should be used for Cross-Device Communication");
void *ptr; void *ptr;
PADDLE_ENFORCE(cudaMallocHost(&ptr, size)); PADDLE_ENFORCE(cudaHostAlloc(&ptr, size, cudaHostAllocPortable));
return new CPUPinnedAllocation(ptr, size); return new CPUPinnedAllocation(ptr, size);
} }
} // namespace allocation } // namespace allocation
......
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
// Allocator uses `cudaMallocHost` // Allocator uses `cudaHostAlloc`
class CPUPinnedAllocation : public Allocation { class CPUPinnedAllocation : public Allocation {
public: public:
CPUPinnedAllocation(void *ptr, size_t size) CPUPinnedAllocation(void *ptr, size_t size)
......
...@@ -173,14 +173,14 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) { ...@@ -173,14 +173,14 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) {
void* p; void* p;
// PINNED memory is visible to all CUDA contexts. // PINNED memory is visible to all CUDA contexts.
cudaError_t result = cudaMallocHost(&p, size); cudaError_t result = cudaHostAlloc(&p, size, cudaHostAllocPortable);
if (result == cudaSuccess) { if (result == cudaSuccess) {
*index = 1; // PINNED memory *index = 1; // PINNED memory
cuda_pinnd_alloc_size_ += size; cuda_pinnd_alloc_size_ += size;
return p; return p;
} else { } else {
LOG(WARNING) << "cudaMallocHost failed."; LOG(WARNING) << "cudaHostAlloc failed.";
return nullptr; return nullptr;
} }
......
...@@ -29,6 +29,7 @@ BufferedReader::~BufferedReader() { ...@@ -29,6 +29,7 @@ BufferedReader::~BufferedReader() {
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device); platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaStreamDestroy(stream)); PADDLE_ENFORCE(cudaStreamDestroy(stream));
for (auto &event : events) PADDLE_ENFORCE(cudaEventDestroy(event));
} }
#endif #endif
} }
...@@ -43,7 +44,14 @@ BufferedReader::BufferedReader( ...@@ -43,7 +44,14 @@ BufferedReader::BufferedReader(
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device); platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaStreamCreate(&stream)); compute_stream =
((platform::CUDADeviceContext *)(platform::DeviceContextPool::Instance()
.Get(place_)))
->stream();
events.resize(buffer_size);
for (auto &event : events)
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
} }
#endif #endif
cpu_buffer_.resize(buffer_size); cpu_buffer_.resize(buffer_size);
...@@ -59,6 +67,12 @@ void BufferedReader::ReadTillBufferFullAsync() { ...@@ -59,6 +67,12 @@ void BufferedReader::ReadTillBufferFullAsync() {
} }
void BufferedReader::ReadAsync(size_t i) { void BufferedReader::ReadAsync(size_t i) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaEventRecord(events[i], compute_stream));
}
#endif
position_.emplace(thread_pool_.enqueue([this, i]() -> size_t { position_.emplace(thread_pool_.enqueue([this, i]() -> size_t {
TensorVec &cpu = cpu_buffer_[i]; TensorVec &cpu = cpu_buffer_[i];
reader_->ReadNext(&cpu); reader_->ReadNext(&cpu);
...@@ -71,6 +85,8 @@ void BufferedReader::ReadAsync(size_t i) { ...@@ -71,6 +85,8 @@ void BufferedReader::ReadAsync(size_t i) {
// NOTE(liangdun): using async copy instead of TensorCopySync // NOTE(liangdun): using async copy instead of TensorCopySync
// TensorCopySync would block other stream // TensorCopySync would block other stream
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, events[i], 0));
TensorVec &gpu = gpu_buffer_[i]; TensorVec &gpu = gpu_buffer_[i];
gpu.resize(cpu.size()); gpu.resize(cpu.size());
for (size_t i = 0; i < cpu.size(); ++i) { for (size_t i = 0; i < cpu.size(); ++i) {
......
...@@ -64,6 +64,8 @@ class BufferedReader : public framework::DecoratedReader { ...@@ -64,6 +64,8 @@ class BufferedReader : public framework::DecoratedReader {
size_t prev_pos_{-1UL}; size_t prev_pos_{-1UL};
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
cudaStream_t stream; cudaStream_t stream;
cudaStream_t compute_stream;
std::vector<cudaEvent_t> events;
#endif #endif
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册