提交 f79ca231 编写于 作者: F fengjiayi

Fix bugs: keep `status_` in sync in RWLockGuard (set it in WRLock/RDLock/UnLock and handle the kUnLock case in the destructor switch), and correct misnamed variables in the cuDNN code (`cudnn_workspace` lambda parameter, `required_workspace_len` in CudnnHolder).

上级 c501826f
...@@ -71,6 +71,9 @@ class RWLockGuard { ...@@ -71,6 +71,9 @@ class RWLockGuard {
WRLock(); WRLock();
break; break;
} }
case Status::kUnLock: {
break;
}
} }
} }
...@@ -78,6 +81,7 @@ class RWLockGuard { ...@@ -78,6 +81,7 @@ class RWLockGuard {
switch (status_) { switch (status_) {
case Status::kUnLock: { case Status::kUnLock: {
lock_->WRLock(); lock_->WRLock();
status_ = Status::kWRLock;
break; break;
} }
case Status::kWRLock: { case Status::kWRLock: {
...@@ -95,6 +99,7 @@ class RWLockGuard { ...@@ -95,6 +99,7 @@ class RWLockGuard {
switch (status_) { switch (status_) {
case Status::kUnLock: { case Status::kUnLock: {
lock_->RDLock(); lock_->RDLock();
status_ = Status::kRDLock;
break; break;
} }
case Status::kRDLock: { case Status::kRDLock: {
...@@ -111,6 +116,7 @@ class RWLockGuard { ...@@ -111,6 +116,7 @@ class RWLockGuard {
void UnLock() { void UnLock() {
if (status_ != Status::kUnLock) { if (status_ != Status::kUnLock) {
lock_->UNLock(); lock_->UNLock();
status_ = Status::kUnLock;
} }
} }
......
...@@ -230,7 +230,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -230,7 +230,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset filter_grad. // Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter // Gradient with respect to the filter
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_func) { auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc, output_grad_data + output_grad_offset * g, cudnn_input_desc,
......
...@@ -176,7 +176,7 @@ class CudnnHolder { ...@@ -176,7 +176,7 @@ class CudnnHolder {
if (required_workspace_len <= workspace_len_) { if (required_workspace_len <= workspace_len_) {
return; return;
} }
void* new_workspace = paddle::memory::Alloc(place_, required_len); void* new_workspace = paddle::memory::Alloc(place_, required_workspace_len);
if (workspace_ != nullptr) { if (workspace_ != nullptr) {
// Maybe someone is using the current workspace // Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
...@@ -184,7 +184,7 @@ class CudnnHolder { ...@@ -184,7 +184,7 @@ class CudnnHolder {
paddle::memory::Free(place_, workspace_); paddle::memory::Free(place_, workspace_);
} }
workspace_ = new_workspace; workspace_ = new_workspace;
workspace_len_ = required_len; workspace_len_ = required_workspace_len;
} }
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册