提交 5398e1a3 编写于 作者: F fengjiayi

fix bugs

上级 f79ca231
...@@ -157,19 +157,19 @@ class CudnnHolder { ...@@ -157,19 +157,19 @@ class CudnnHolder {
void RunFunc(const std::function<void(void*)>& cudnn_func, void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) { size_t required_workspace_len) {
framework::RWLockGuard lock_guard(&rw_lock_, std::lock_guard<std::mutex> lock(mtx_);
framework::RWLockGuard::Status::kRDLock);
if (required_workspace_len > workspace_len_) { if (required_workspace_len > workspace_len_) {
lock_guard.UnLock();
lock_guard.WRLock();
ReallocateWorkspace(required_workspace_len); ReallocateWorkspace(required_workspace_len);
lock_guard.UnLock();
lock_guard.RDLock();
} }
cudnn_func(workspace_); cudnn_func(workspace_);
} }
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } ~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (workspace_ != nullptr) {
paddle::memory::Free(place_, workspace_);
}
}
private: private:
void ReallocateWorkspace(size_t required_workspace_len) { void ReallocateWorkspace(size_t required_workspace_len) {
...@@ -194,7 +194,7 @@ class CudnnHolder { ...@@ -194,7 +194,7 @@ class CudnnHolder {
const cudaStream_t* stream_; // not owned; const cudaStream_t* stream_; // not owned;
const CUDAPlace place_; const CUDAPlace place_;
framework::RWLock rw_lock_; std::mutex mtx_;
}; };
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册