From c4eced9881bd6a5c4f47739f4a83c422620f11ac Mon Sep 17 00:00:00 2001 From: chengduozh Date: Fri, 11 Jan 2019 12:52:03 +0800 Subject: [PATCH] fix thread safe bug test=develop --- paddle/fluid/platform/device_context.cc | 28 ++++++++++++------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 09f3d3de5..8f80a2d78 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -92,26 +92,24 @@ platform::TemporaryAllocator& DeviceTemporaryAllocator::Get( const platform::Place& place, const cudaStream_t& stream) { PADDLE_ENFORCE(platform::is_gpu_place(place)); auto place_stream = std::make_pair(place, stream); - { - std::unique_lock lock(mtx_); - if (!device_allocator_.count(place_stream)) { - device_allocator_[place_stream].reset(new TemporaryAllocator(place)); - device_allocator_[place_stream]->SetCallback([stream]() { - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - PADDLE_ENFORCE(cudaGetLastError()); - }); - } + std::unique_lock lock(mtx_); + auto it = device_allocator_.find(place_stream); + if (it == device_allocator_.end()) { + auto tmp_allocator = new TemporaryAllocator(place); + tmp_allocator->SetCallback([stream]() { + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + PADDLE_ENFORCE(cudaGetLastError()); + }); + device_allocator_[place_stream].reset(tmp_allocator); + return *tmp_allocator; + } else { + return *it->second; } - return *device_allocator_.at(place_stream); } template <> platform::TemporaryAllocator& DeviceTemporaryAllocator::Get( const platform::CUDADeviceContext& dev_ctx) { - auto place_stream = std::make_pair(dev_ctx.GetPlace(), dev_ctx.stream()); - if (device_allocator_.count(place_stream)) { - return *device_allocator_.at(place_stream); - } return Get(dev_ctx.GetPlace(), dev_ctx.stream()); } #endif @@ -325,7 +323,7 @@ Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { auto& allocator = DeviceTemporaryAllocator::Instance().Get(*this); - allocator.Release([=]() { + allocator.Release([this]() { PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaGetLastError()); }); -- GitLab