Commit 407ff0bd authored by F fengjiayi

use CudnnHolder in conv_cudnn_op

Parent 04bfd5c1
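This commit switches the cuDNN convolution kernels from allocating and freeing their workspace on every call to borrowing a buffer cached by the device context's CudnnHolder. Below is a minimal sketch of the grow-on-demand caching this relies on; only `workspace_`, `workspace_len_`, `place_`, and the `paddle::memory::Alloc`/`Free` calls appear in the diff, while the method name (mirroring `dev_ctx.cudnn_workspace`), the member defaults, and the constructor are assumptions for illustration, not the actual PaddlePaddle implementation (which also manages the cuDNN handle).

```cpp
// Minimal sketch (not the actual PaddlePaddle code) of the workspace cache
// that CudnnHolder provides: keep one buffer per device and only reallocate
// when a caller asks for more than is currently held.
class CudnnHolder {
 public:
  explicit CudnnHolder(const platform::CUDAPlace& place) : place_(place) {}

  // Name assumed to mirror dev_ctx.cudnn_workspace() used in the diff below.
  void* cudnn_workspace(size_t required_len) {
    if (required_len > workspace_len_) {
      void* new_workspace = paddle::memory::Alloc(place_, required_len);
      if (workspace_ != nullptr) {
        paddle::memory::Free(place_, workspace_);  // drop the smaller buffer
      }
      workspace_ = new_workspace;
      workspace_len_ = required_len;  // keep the cached size in sync
    }
    return workspace_;  // reused by subsequent kernel calls on this device
  }

 private:
  void* workspace_ = nullptr;
  size_t workspace_len_ = 0;
  platform::CUDAPlace place_;
};
```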
@@ -159,9 +159,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit");
- // Allocate on GPU memory
- platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
- cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
+ // Get cudnn workspace
+ cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
@@ -171,8 +170,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
}
- // Release the cudnn workspace
- paddle::memory::Free(gpu, cudnn_workspace);
}
};
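With the holder in place, the forward kernel no longer pairs every launch with an Alloc/Free of its workspace; it simply asks the device context for a buffer of the required size. A condensed sketch of the new calling pattern (the surrounding kernel code is abbreviated and assumed; only `dev_ctx.cudnn_workspace(...)` comes from the diff):

```cpp
// Assumed, abbreviated sketch of the kernel's workspace handling after this
// change; only dev_ctx.cudnn_workspace(...) is taken from the diff above.
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
// ... launch cudnnConvolutionForward with cudnn_workspace ...
// No paddle::memory::Free here: the buffer remains cached in the device context.
```

The trade-off is that one workspace-sized allocation stays resident per device between calls, in exchange for avoiding repeated device allocations on the hot path.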
@@ -315,10 +312,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
}
// ------------------- cudnn conv workspace ---------------------
- // Already on GPU
- void* cudnn_workspace = nullptr;
- platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
- cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
+ void* cudnn_workspace = dev_ctx.cudnn_workspace(workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
@@ -347,8 +341,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_grad_data + i * group_offset_filter));
}
}
- // Release the cudnn workspace
- paddle::memory::Free(gpu, cudnn_workspace);
}
};
......
@@ -162,6 +162,7 @@ class CudnnHolder {
paddle::memory::Free(place_, workspace_);
}
workspace_ = new_workspace;
+ workspace_len_ = required_len;
}
return workspace_;
}
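The newly added `workspace_len_ = required_len;` line is the bookkeeping that makes the cache behave as intended: without updating the recorded length after a reallocation, later calls would still compare against the old, smaller length and reallocate (and free) the buffer on every request, defeating the purpose of the cache. See the sketch after the commit header above for the full grow-on-demand flow this fragment belongs to.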
......