提交 c4eced98 编写于 作者: C chengduozh

fix thread-safety bug

test=develop
上级 358e657f
...@@ -92,26 +92,24 @@ platform::TemporaryAllocator& DeviceTemporaryAllocator::Get( ...@@ -92,26 +92,24 @@ platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::Place& place, const cudaStream_t& stream) { const platform::Place& place, const cudaStream_t& stream) {
PADDLE_ENFORCE(platform::is_gpu_place(place)); PADDLE_ENFORCE(platform::is_gpu_place(place));
auto place_stream = std::make_pair(place, stream); auto place_stream = std::make_pair(place, stream);
{ std::unique_lock<std::mutex> lock(mtx_);
std::unique_lock<std::mutex> lock(mtx_); auto it = device_allocator_.find(place_stream);
if (!device_allocator_.count(place_stream)) { if (it == device_allocator_.end()) {
device_allocator_[place_stream].reset(new TemporaryAllocator(place)); auto tmp_allocator = new TemporaryAllocator(place);
device_allocator_[place_stream]->SetCallback([stream]() { tmp_allocator->SetCallback([stream]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
PADDLE_ENFORCE(cudaGetLastError()); PADDLE_ENFORCE(cudaGetLastError());
}); });
} device_allocator_[place_stream].reset(tmp_allocator);
return *tmp_allocator;
} else {
return *it->second;
} }
return *device_allocator_.at(place_stream);
} }
template <>
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
    const platform::CUDADeviceContext& dev_ctx) {
  // Convenience overload: forward to the (place, stream) variant keyed by
  // this context's place and CUDA stream, so all locking and creation logic
  // lives in one place.
  return Get(dev_ctx.GetPlace(), dev_ctx.stream());
}
#endif #endif
...@@ -325,7 +323,7 @@ Place CUDADeviceContext::GetPlace() const { return place_; } ...@@ -325,7 +323,7 @@ Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const { void CUDADeviceContext::Wait() const {
auto& allocator = auto& allocator =
DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this); DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this);
allocator.Release([=]() { allocator.Release([this]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError()); PADDLE_ENFORCE(cudaGetLastError());
}); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册