From 407ff0bdbc4a18583634b11c3d80f624459607e0 Mon Sep 17 00:00:00 2001
From: fengjiayi <fengjiayi@baidu.com>
Date: Thu, 30 Aug 2018 15:57:25 +0800
Subject: [PATCH] use CudnnHolder in conv_cudnn_op

---
 paddle/fluid/operators/conv_cudnn_op.cu.cc | 14 +++-----------
 paddle/fluid/platform/device_context.cc    |  1 +
 2 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc
index 22cbf680c06..92435d7c417 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -159,9 +159,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
                       "workspace_size to be allocated exceeds the limit");
 
-    // Allocate on GPU memory
-    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
+    // Get cudnn workspace
+    cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
     // ------------------- cudnn conv forward ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     for (int i = 0; i < groups; i++) {
@@ -171,8 +170,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
           cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
           &beta, cudnn_output_desc, output_data + i * group_offset_out));
     }
-    // Release the cudnn workspace
-    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
 
@@ -315,10 +312,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
     }
     // ------------------- cudnn conv workspace ---------------------
-    // Already on GPU
-    void* cudnn_workspace = nullptr;
-    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
+    void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
@@ -347,8 +341,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
             filter_grad_data + i * group_offset_filter));
       }
     }
-    // Release the cudnn workspace
-    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
 
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 01fa9301d61..3f8da69fc2f 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -162,6 +162,7 @@ class CudnnHolder {
         paddle::memory::Free(place_, workspace_);
       }
       workspace_ = new_workspace;
+      workspace_len_ = required_len;
     }
    return workspace_;
   }
-- 
GitLab