提交 c4eced98 编写于 作者: C chengduozh

fix thread-safety bug

test=develop
上级 358e657f
......@@ -92,26 +92,24 @@ platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::Place& place, const cudaStream_t& stream) {
PADDLE_ENFORCE(platform::is_gpu_place(place));
auto place_stream = std::make_pair(place, stream);
{
std::unique_lock<std::mutex> lock(mtx_);
if (!device_allocator_.count(place_stream)) {
device_allocator_[place_stream].reset(new TemporaryAllocator(place));
device_allocator_[place_stream]->SetCallback([stream]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
PADDLE_ENFORCE(cudaGetLastError());
});
}
std::unique_lock<std::mutex> lock(mtx_);
auto it = device_allocator_.find(place_stream);
if (it == device_allocator_.end()) {
auto tmp_allocator = new TemporaryAllocator(place);
tmp_allocator->SetCallback([stream]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
PADDLE_ENFORCE(cudaGetLastError());
});
device_allocator_[place_stream].reset(tmp_allocator);
return *tmp_allocator;
} else {
return *it->second;
}
return *device_allocator_.at(place_stream);
}
// Explicit specialization of DeviceTemporaryAllocator::Get for a
// CUDADeviceContext: returns the TemporaryAllocator keyed by the context's
// (place, stream) pair.
template <>
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::CUDADeviceContext& dev_ctx) {
auto place_stream = std::make_pair(dev_ctx.GetPlace(), dev_ctx.stream());
// NOTE(review): this lookup reads device_allocator_ without acquiring mtx_,
// whereas the (place, stream) overload takes the lock before touching the
// map. This page renders a diff without +/- markers, so this check is
// presumably the code the commit removes -- confirm against the repository.
if (device_allocator_.count(place_stream)) {
return *device_allocator_.at(place_stream);
}
// Delegate to the (place, stream) overload, which looks the entry up and
// creates it when it is missing.
return Get(dev_ctx.GetPlace(), dev_ctx.stream());
}
#endif
......@@ -325,7 +323,7 @@ Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const {
auto& allocator =
DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this);
allocator.Release([=]() {
allocator.Release([this]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError());
});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册