提交 c501826f 编写于 作者: F fengjiayi

use framework::RWLock

上级 1f36a4c2
......@@ -56,5 +56,70 @@ struct RWLock {
};
#endif
// RAII guard over a framework::RWLock.
//
// The guard tracks in `status_` which kind of lock (if any) it currently
// holds and releases it in the destructor, so a scope cannot leak a read
// or write lock. Direct upgrade/downgrade between read and write is
// rejected (PADDLE_THROW): callers must UnLock() first, which avoids the
// deadlock a silent upgrade could cause.
class RWLockGuard {
 public:
  // Lock kind currently held by THIS guard (not the underlying RWLock).
  enum Status { kUnLock, kWRLock, kRDLock };

  // Acquires the requested lock immediately; kUnLock defers acquisition
  // until WRLock()/RDLock() is called explicitly.
  RWLockGuard(RWLock* rw_lock, Status init_status)
      : lock_(rw_lock), status_(Status::kUnLock) {
    switch (init_status) {
      case Status::kUnLock: {
        break;
      }
      case Status::kRDLock: {
        RDLock();
        break;
      }
      case Status::kWRLock: {
        WRLock();
        break;
      }
    }
  }

  // Acquires the write lock. No-op if already write-locked; throws if a
  // read lock is held (no silent read->write upgrade).
  void WRLock() {
    switch (status_) {
      case Status::kUnLock: {
        lock_->WRLock();
        // Record the held state so UnLock()/~RWLockGuard release it.
        status_ = Status::kWRLock;
        break;
      }
      case Status::kWRLock: {
        break;
      }
      case Status::kRDLock: {
        PADDLE_THROW(
            "Please unlock read lock first before invoking write lock.");
        break;
      }
    }
  }

  // Acquires the read lock. No-op if already read-locked; throws if a
  // write lock is held (no silent write->read downgrade).
  void RDLock() {
    switch (status_) {
      case Status::kUnLock: {
        lock_->RDLock();
        // Record the held state so UnLock()/~RWLockGuard release it.
        status_ = Status::kRDLock;
        break;
      }
      case Status::kRDLock: {
        break;
      }
      case Status::kWRLock: {
        PADDLE_THROW(
            "Please unlock write lock first before invoking read lock.");
        break;
      }
    }
  }

  // Releases whichever lock is held; safe to call when nothing is held.
  void UnLock() {
    if (status_ != Status::kUnLock) {
      lock_->UNLock();
      // Reset so a later WRLock()/RDLock() re-acquires, and so the
      // destructor does not unlock twice.
      status_ = Status::kUnLock;
    }
  }

  ~RWLockGuard() { UnLock(); }

 private:
  RWLock* lock_;   // not owned
  Status status_;  // lock kind this guard currently holds
};
} // namespace framework
} // namespace paddle
......@@ -15,12 +15,11 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include <boost/thread/thread.hpp>
#include "paddle/fluid/framework/rw_lock.h"
#endif
#include "paddle/fluid/memory/memory.h"
namespace paddle {
namespace platform {
......@@ -158,9 +157,14 @@ class CudnnHolder {
void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) {
boost::upgrade_lock<boost::shared_mutex> shared_lock(mtx_);
framework::RWLockGuard lock_guard(&rw_lock_,
framework::RWLockGuard::Status::kRDLock);
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len, &shared_lock);
lock_guard.UnLock();
lock_guard.WRLock();
ReallocateWorkspace(required_workspace_len);
lock_guard.UnLock();
lock_guard.RDLock();
}
cudnn_func(workspace_);
}
......@@ -168,9 +172,7 @@ class CudnnHolder {
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
private:
void ReallocateWorkspace(size_t required_workspace_len,
boost::upgrade_lock<boost::shared_mutex>* lock) {
boost::upgrade_to_unique_lock<boost::shared_mutex> unique_lock(*lock);
void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) {
return;
}
......@@ -192,7 +194,7 @@ class CudnnHolder {
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
boost::shared_mutex mtx_;
framework::RWLock rw_lock_;
};
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册