Commit 407ff0bd — authored by: fengjiayi

use CudnnHolder in conv_cudnn_op

Parent commit: 04bfd5c1
...@@ -159,9 +159,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -159,9 +159,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit"); "workspace_size to be allocated exceeds the limit");
// Allocate on GPU memory // Get cudnn workspace
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace()); cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward --------------------- // ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
...@@ -171,8 +170,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -171,8 +170,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out)); &beta, cudnn_output_desc, output_data + i * group_offset_out));
} }
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
} }
}; };
...@@ -315,10 +312,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -315,10 +312,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
} }
// ------------------- cudnn conv workspace --------------------- // ------------------- cudnn conv workspace ---------------------
// Already on GPU void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
void* cudnn_workspace = nullptr;
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data --------------------- // ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) { if (input_grad) {
...@@ -347,8 +341,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -347,8 +341,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_grad_data + i * group_offset_filter)); filter_grad_data + i * group_offset_filter));
} }
} }
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
} }
}; };
......
...@@ -162,6 +162,7 @@ class CudnnHolder { ...@@ -162,6 +162,7 @@ class CudnnHolder {
paddle::memory::Free(place_, workspace_); paddle::memory::Free(place_, workspace_);
} }
workspace_ = new_workspace; workspace_ = new_workspace;
workspace_len_ = required_len;
} }
return workspace_ return workspace_
} }
......
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.