提交 2d956b82 编写于 作者: Z zchen0211

deconv cudnn

上级 7e34b8e3
...@@ -29,7 +29,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; ...@@ -29,7 +29,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout; using DataLayout = platform::DataLayout;
using CUDADeviceContext = platform::CUDADeviceContext; using CUDADeviceContext = platform::CUDADeviceContext;
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024;
template <typename T> template <typename T>
class CudnnConvTransposeOpKernel : public framework::OpKernel<T> { class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
...@@ -71,7 +71,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -71,7 +71,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv workspace --------------------- // ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr; void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate. size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
if (user_workspace_size > 0) { if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024; workspace_size_limit = user_workspace_size * 1024 * 1024;
} }
...@@ -125,6 +125,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -125,6 +125,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
// cuDNN v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB"); int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
...@@ -153,7 +154,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -153,7 +154,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnnConvolutionBwdFilterAlgo_t filter_algo; cudnnConvolutionBwdFilterAlgo_t filter_algo;
size_t bwd_filter_ws_size, fwd_ws_size; size_t bwd_filter_ws_size, fwd_ws_size;
size_t workspace_size_in_bytes = 0; size_t workspace_size_in_bytes = 0;
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
if (user_workspace_size > 0) { if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024; workspace_size_limit = user_workspace_size * 1024 * 1024;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册