Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
db8c52da
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
db8c52da
编写于
11月 07, 2018
作者:
Q
qingqing01
提交者:
GitHub
11月 07, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert " Exhaustive search for cuDNN conv. (#14043)"
This reverts commit
ce7d9b07
.
上级
ce7d9b07
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
74 addition
and
381 deletion
+74
-381
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+0
-1
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+0
-2
paddle/fluid/inference/api/helper.h
paddle/fluid/inference/api/helper.h
+1
-2
paddle/fluid/inference/io.cc
paddle/fluid/inference/io.cc
+1
-2
paddle/fluid/operators/add_position_encoding_op.h
paddle/fluid/operators/add_position_encoding_op.h
+3
-4
paddle/fluid/operators/conv_cudnn_op.cu.cc
paddle/fluid/operators/conv_cudnn_op.cu.cc
+19
-185
paddle/fluid/operators/conv_cudnn_op_cache.h
paddle/fluid/operators/conv_cudnn_op_cache.h
+0
-90
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+1
-10
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+1
-4
paddle/fluid/platform/dynload/cudnn.h
paddle/fluid/platform/dynload/cudnn.h
+45
-48
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+1
-16
python/paddle/fluid/tests/unittests/test_conv2d_op.py
python/paddle/fluid/tests/unittests/test_conv2d_op.py
+1
-9
python/paddle/fluid/tests/unittests/test_conv3d_op.py
python/paddle/fluid/tests/unittests/test_conv3d_op.py
+0
-6
未找到文件。
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
db8c52da
...
...
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <array>
#include <string>
#include <vector>
...
...
paddle/fluid/inference/api/analysis_predictor.h
浏览文件 @
db8c52da
...
...
@@ -13,8 +13,6 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/naive_executor.h"
...
...
paddle/fluid/inference/api/helper.h
浏览文件 @
db8c52da
...
...
@@ -16,14 +16,13 @@
#include <glog/logging.h>
#include <sys/time.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/string/printf.h"
#include "paddle_inference_api.h"
namespace
paddle
{
namespace
inference
{
...
...
paddle/fluid/inference/io.cc
浏览文件 @
db8c52da
...
...
@@ -59,8 +59,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
bool
IsPersistable
(
const
framework
::
VarDesc
*
var
)
{
if
(
var
->
Persistable
()
&&
var
->
GetType
()
!=
framework
::
proto
::
VarType
::
FEED_MINIBATCH
&&
var
->
GetType
()
!=
framework
::
proto
::
VarType
::
FETCH_LIST
&&
var
->
GetType
()
!=
framework
::
proto
::
VarType
::
RAW
)
{
var
->
GetType
()
!=
framework
::
proto
::
VarType
::
FETCH_LIST
)
{
return
true
;
}
return
false
;
...
...
paddle/fluid/operators/add_position_encoding_op.h
浏览文件 @
db8c52da
...
...
@@ -66,10 +66,9 @@ class AddPositionEncodingKernel : public framework::OpKernel<T> {
x_lod
.
empty
()
?
max_seq_len
:
x_lod
[
0
][
i
+
1
]
-
x_lod
[
0
][
i
];
for
(
int
j
=
0
;
j
<
max_length
;
++
j
)
{
for
(
int
k
=
0
;
k
<
half_size
;
++
k
)
{
const
double
val
=
(
half_size
>
1
)
?
j
/
pow
(
10000.0
,
static_cast
<
double
>
(
k
)
/
(
half_size
-
1
))
:
j
/
10000.0
;
const
double
val
=
(
half_size
>
1
)
?
j
/
pow
(
10000.0
,
double
(
k
)
/
(
half_size
-
1
))
:
j
/
10000.0
;
dst_ptr
[
k
]
=
src_ptr
[
k
]
*
alpha
+
sin
(
val
)
*
beta
;
dst_ptr
[
half_size
+
k
]
=
src_ptr
[
half_size
+
k
]
*
alpha
+
cos
(
val
)
*
beta
;
...
...
paddle/fluid/operators/conv_cudnn_op.cu.cc
浏览文件 @
db8c52da
...
...
@@ -15,22 +15,15 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_bool
(
cudnn_deterministic
,
false
,
"Whether allow using an autotuning algorithm for convolution "
"operator. The autotuning algorithm may be non-deterministic. If "
"true, the algorithm is deterministic."
);
DEFINE_uint64
(
conv_workspace_size_limit
,
4096
,
"cuDNN convolution workspace limit in MB unit."
);
DEFINE_bool
(
cudnn_exhaustive_search
,
false
,
"Whether enable exhaustive search for cuDNN convolution or "
"not, defalut is False."
);
namespace
paddle
{
namespace
operators
{
...
...
@@ -43,25 +36,13 @@ using DataLayout = platform::DataLayout;
template
<
typename
T
>
using
ScalingParamType
=
typename
platform
::
CudnnDataType
<
T
>::
ScalingParamType
;
static
constexpr
char
kCUDNNFwdAlgoCache
[]
=
"kCUDNNFwdAlgoCache"
;
static
constexpr
char
kCUDNNBwdDataAlgoCache
[]
=
"kCUDNNBwdDataAlgoCache"
;
static
constexpr
char
kCUDNNBwdFilterAlgoCache
[]
=
"kCUDNNBwdFilterAlgoCache"
;
static
constexpr
size_t
kCONV_CUDNN_WORKSPACE_LIMIT_BYTES
=
static_cast
<
size_t
>
(
1024
)
*
1024
*
1024
;
static
constexpr
size_t
kNUM_CUDNN_FWD_ALGS
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT
;
static
constexpr
size_t
kNUM_CUDNN_BWD_FILTER_ALGS
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT
;
static
constexpr
size_t
kNUM_CUDNN_BWD_DATA_ALGS
=
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT
;
template
<
typename
T
>
class
CUDNNConvOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace."
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
...
...
@@ -74,8 +55,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
int64_t
user_workspace_size
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int
>
(
"workspace_size_MB"
));
bool
exhaustive_search
=
FLAGS_cudnn_exhaustive_search
||
ctx
.
Attr
<
bool
>
(
"exhaustive_search"
);
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
filter_data
=
filter
->
data
<
T
>
();
...
...
@@ -141,18 +120,19 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv workspace ---------------------
size_t
workspace_size_in_bytes
;
// final workspace to allocate.
size_t
workspace_size_limit
=
kCONV_CUDNN_WORKSPACE_LIMIT_BYTES
;
if
(
FLAGS_conv_workspace_size_limit
>
0
||
user_workspace_size
>
0
)
{
int64_t
max_user_size
=
std
::
max
(
static_cast
<
int64_t
>
(
FLAGS_conv_workspace_size_limit
),
user_workspace_size
);
workspace_size_limit
=
max_user_size
*
1024
*
1024
;
if
(
user_workspace_size
>
0
)
{
workspace_size_limit
=
user_workspace_size
*
1024
*
1024
;
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t
algo
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
bool
half_float
=
false
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
cudnn_output_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
workspace_size_limit
,
&
algo
));
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
// Tensor core is supported since the volta GPU and
// is only enabled when input and filter data are float16
...
...
@@ -163,65 +143,12 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc
,
CUDNN_TENSOR_OP_MATH
));
// Currently tensor core is only enabled using this algo
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
;
half_float
=
true
;
}
else
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionMathType
(
cudnn_conv_desc
,
CUDNN_DEFAULT_MATH
));
}
#endif
auto
x_dims
=
framework
::
vectorize
(
input
->
dims
());
auto
f_dims
=
framework
::
vectorize
(
filter
->
dims
());
if
((
!
exhaustive_search
)
&&
(
!
half_float
))
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
cudnn_output_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
workspace_size_limit
,
&
algo
));
VLOG
(
3
)
<<
"cuDNN forward algo "
<<
algo
;
}
else
if
(
exhaustive_search
&&
(
!
half_float
))
{
AlgorithmsCache
<
cudnnConvolutionFwdAlgo_t
>*
algo_cache
=
nullptr
;
if
(
ctx
.
scope
().
FindVar
(
kCUDNNFwdAlgoCache
))
{
algo_cache
=
ctx
.
scope
()
.
FindVar
(
kCUDNNFwdAlgoCache
)
->
GetMutable
<
AlgorithmsCache
<
cudnnConvolutionFwdAlgo_t
>>
();
}
else
{
algo_cache
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
(
kCUDNNFwdAlgoCache
)
->
GetMutable
<
AlgorithmsCache
<
cudnnConvolutionFwdAlgo_t
>>
();
}
algo
=
algo_cache
->
GetAlgorithm
(
x_dims
,
f_dims
,
strides
,
paddings
,
dilations
,
0
,
[
&
]()
{
int
returned_algo_count
;
std
::
array
<
cudnnConvolutionFwdAlgoPerf_t
,
kNUM_CUDNN_FWD_ALGS
>
fwd_perf_stat
;
auto
cudnn_find_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnFindConvolutionForwardAlgorithmEx
(
handle
,
cudnn_input_desc
,
input_data
,
cudnn_filter_desc
,
filter_data
,
cudnn_conv_desc
,
cudnn_output_desc
,
output_data
,
kNUM_CUDNN_FWD_ALGS
,
&
returned_algo_count
,
fwd_perf_stat
.
data
(),
cudnn_workspace
,
workspace_size_limit
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_find_func
,
workspace_size_limit
);
VLOG
(
3
)
<<
"Perf result: (algo: stat, time, memory)"
;
for
(
int
i
=
0
;
i
<
returned_algo_count
;
++
i
)
{
const
auto
&
stat
=
fwd_perf_stat
[
i
];
VLOG
(
3
)
<<
stat
.
algo
<<
": "
<<
stat
.
status
<<
" "
<<
stat
.
time
<<
" "
<<
stat
.
memory
;
}
return
fwd_perf_stat
[
0
].
algo
;
});
VLOG
(
3
)
<<
"choose algo "
<<
algo
;
}
else
{
PADDLE_ENFORCE
(
half_float
,
"cuDNN exhaustive search doesn't support half float."
);
}
// get workspace size able to allocate
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardWorkspaceSize
(
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
...
...
@@ -251,7 +178,6 @@ template <typename T>
class
CUDNNConvGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace."
);
auto
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
...
...
@@ -270,13 +196,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
int64_t
user_workspace_size
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int
>
(
"workspace_size_MB"
));
bool
exhaustive_search
=
FLAGS_cudnn_exhaustive_search
||
ctx
.
Attr
<
bool
>
(
"exhaustive_search"
);
if
(
exhaustive_search
&&
FLAGS_cudnn_deterministic
)
{
PADDLE_THROW
(
"Cann't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time."
);
}
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor
input_desc
;
...
...
@@ -344,65 +263,14 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnnConvolutionBwdFilterAlgo_t
filter_algo
;
size_t
workspace_size_in_bytes
=
0
,
tmp_size
=
0
;
size_t
workspace_size_limit
=
kCONV_CUDNN_WORKSPACE_LIMIT_BYTES
;
if
(
FLAGS_conv_workspace_size_limit
>
0
||
user_workspace_size
>
0
)
{
int64_t
max_user_size
=
std
::
max
(
static_cast
<
int64_t
>
(
FLAGS_conv_workspace_size_limit
),
user_workspace_size
);
workspace_size_limit
=
max_user_size
*
1024
*
1024
;
if
(
user_workspace_size
>
0
)
{
workspace_size_limit
=
user_workspace_size
*
1024
*
1024
;
}
auto
x_dims
=
framework
::
vectorize
(
input
->
dims
());
auto
f_dims
=
framework
::
vectorize
(
filter
->
dims
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
if
(
input_grad
)
{
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
exhaustive_search
)
{
AlgorithmsCache
<
cudnnConvolutionBwdDataAlgo_t
>*
data_algo_cache
;
if
(
ctx
.
scope
().
FindVar
(
kCUDNNBwdDataAlgoCache
))
{
data_algo_cache
=
ctx
.
scope
()
.
FindVar
(
kCUDNNBwdDataAlgoCache
)
->
GetMutable
<
AlgorithmsCache
<
cudnnConvolutionBwdDataAlgo_t
>>
();
}
else
{
data_algo_cache
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
(
kCUDNNBwdDataAlgoCache
)
->
GetMutable
<
AlgorithmsCache
<
cudnnConvolutionBwdDataAlgo_t
>>
();
}
data_algo
=
data_algo_cache
->
GetAlgorithm
(
x_dims
,
f_dims
,
strides
,
paddings
,
dilations
,
0
,
[
&
]()
{
int
returned_algo_count
;
std
::
array
<
cudnnConvolutionBwdDataAlgoPerf_t
,
kNUM_CUDNN_BWD_DATA_ALGS
>
data_perf_stat
;
auto
cudnn_find_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnFindConvolutionBackwardDataAlgorithmEx
(
handle
,
cudnn_filter_desc
,
filter_data
,
cudnn_output_grad_desc
,
output_grad_data
,
cudnn_conv_desc
,
cudnn_input_desc
,
input_grad_data
,
kNUM_CUDNN_BWD_DATA_ALGS
,
&
returned_algo_count
,
data_perf_stat
.
data
(),
cudnn_workspace
,
workspace_size_limit
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_find_func
,
workspace_size_limit
);
VLOG
(
3
)
<<
"Perf result: (algo: stat, time, memory)"
;
for
(
int
i
=
0
;
i
<
returned_algo_count
;
++
i
)
{
const
auto
&
stat
=
data_perf_stat
[
i
];
VLOG
(
3
)
<<
stat
.
algo
<<
": "
<<
stat
.
status
<<
" "
<<
stat
.
time
<<
" "
<<
stat
.
memory
;
}
return
data_perf_stat
[
0
].
algo
;
});
VLOG
(
3
)
<<
"cuDNN backward data algo "
<<
data_algo
;
}
else
if
(
FLAGS_cudnn_deterministic
)
{
data_algo
=
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1
;
}
else
{
if
(
!
FLAGS_cudnn_deterministic
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
handle
,
cudnn_filter_desc
,
...
...
@@ -415,7 +283,10 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_input_desc
,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
,
workspace_size_limit
,
&
data_algo
));
}
else
{
data_algo
=
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1
;
}
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataWorkspaceSize
(
handle
,
cudnn_filter_desc
,
cudnn_output_grad_desc
,
...
...
@@ -424,54 +295,17 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
}
if
(
filter_grad
)
{
T
*
filter_grad_data
=
filter_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
exhaustive_search
)
{
AlgorithmsCache
<
cudnnConvolutionBwdFilterAlgo_t
>*
f_algo_cache
;
if
(
ctx
.
scope
().
FindVar
(
kCUDNNBwdFilterAlgoCache
))
{
f_algo_cache
=
ctx
.
scope
()
.
FindVar
(
kCUDNNBwdFilterAlgoCache
)
->
GetMutable
<
AlgorithmsCache
<
cudnnConvolutionBwdFilterAlgo_t
>>
();
}
else
{
f_algo_cache
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
(
kCUDNNBwdFilterAlgoCache
)
->
GetMutable
<
AlgorithmsCache
<
cudnnConvolutionBwdFilterAlgo_t
>>
();
}
filter_algo
=
f_algo_cache
->
GetAlgorithm
(
x_dims
,
f_dims
,
strides
,
paddings
,
dilations
,
0
,
[
&
]()
{
int
returned_algo_count
;
std
::
array
<
cudnnConvolutionBwdFilterAlgoPerf_t
,
kNUM_CUDNN_BWD_FILTER_ALGS
>
filter_perf_stat
;
auto
cudnn_find_f_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnFindConvolutionBackwardFilterAlgorithmEx
(
handle
,
cudnn_input_desc
,
input_data
,
cudnn_output_grad_desc
,
output_grad_data
,
cudnn_conv_desc
,
cudnn_filter_desc
,
filter_grad_data
,
kNUM_CUDNN_BWD_FILTER_ALGS
,
&
returned_algo_count
,
filter_perf_stat
.
data
(),
cudnn_workspace
,
workspace_size_limit
));
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_find_f_func
,
workspace_size_limit
);
return
filter_perf_stat
[
0
].
algo
;
});
VLOG
(
3
)
<<
"cuDNN backward filter algo "
<<
filter_algo
;
}
else
if
(
FLAGS_cudnn_deterministic
)
{
filter_algo
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1
;
}
else
{
if
(
!
FLAGS_cudnn_deterministic
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
handle
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
cudnn_conv_desc
,
cudnn_filter_desc
,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
,
workspace_size_limit
,
&
filter_algo
));
}
else
{
filter_algo
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1
;
}
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
handle
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
cudnn_conv_desc
,
...
...
paddle/fluid/operators/conv_cudnn_op_cache.h
已删除
100644 → 0
浏览文件 @
ce7d9b07
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <unordered_map>
#include <vector>
namespace
paddle
{
namespace
operators
{
template
<
typename
TAlgorithm
>
class
AlgorithmsCache
{
public:
// Caches the best algorithm for a given
// combination of tensor dimensions & compute data type.
TAlgorithm
GetAlgorithm
(
const
std
::
vector
<
int64_t
>&
dims1
,
const
std
::
vector
<
int64_t
>&
dims2
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
int
algorithmFlags
,
// can set for different data type
std
::
function
<
TAlgorithm
()
>
gen_func
);
private:
std
::
unordered_map
<
int64_t
,
TAlgorithm
>
hash_
;
std
::
mutex
mutex_
;
};
template
<
typename
TAlgorithm
>
TAlgorithm
AlgorithmsCache
<
TAlgorithm
>::
GetAlgorithm
(
const
std
::
vector
<
int64_t
>&
dims1
,
const
std
::
vector
<
int64_t
>&
dims2
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
int
algorithmFlags
,
std
::
function
<
TAlgorithm
()
>
gen_func
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
int64_t
seed
=
0
;
// Hash all of the inputs, use to try and look up a previously
// discovered algorithm, or fall back to generating a new one.
std
::
hash
<
int64_t
>
hashFn
;
// do hash like boost
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
for
(
const
auto
num
:
dims1
)
{
seed
^=
hashFn
(
num
)
+
0x9e3779b9
+
(
seed
<<
6
)
+
(
seed
>>
2
);
}
for
(
const
auto
num
:
dims2
)
{
seed
^=
hashFn
(
num
)
+
0x9e3779b9
+
(
seed
<<
6
)
+
(
seed
>>
2
)
+
1
;
}
for
(
const
auto
num
:
strides
)
{
seed
^=
hashFn
(
static_cast
<
int64_t
>
(
num
))
+
0x9e3779b9
+
(
seed
<<
6
)
+
(
seed
>>
2
)
+
2
;
}
for
(
const
auto
num
:
paddings
)
{
seed
^=
hashFn
(
static_cast
<
int64_t
>
(
num
))
+
0x9e3779b9
+
(
seed
<<
6
)
+
(
seed
>>
2
)
+
3
;
}
for
(
const
auto
num
:
dilations
)
{
seed
^=
hashFn
(
static_cast
<
int64_t
>
(
num
))
+
0x9e3779b9
+
(
seed
<<
6
)
+
(
seed
>>
2
)
+
4
;
}
seed
^=
hashFn
(
static_cast
<
int64_t
>
(
algorithmFlags
))
+
0x9e3779b9
+
(
seed
<<
6
)
+
(
seed
>>
2
)
+
5
;
if
(
seed
==
0
)
return
gen_func
();
if
(
hash_
.
find
(
seed
)
==
hash_
.
end
())
{
TAlgorithm
value
=
gen_func
();
hash_
[
seed
]
=
value
;
}
return
hash_
[
seed
];
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/conv_op.cc
浏览文件 @
db8c52da
...
...
@@ -189,11 +189,6 @@ void Conv2DOpMaker::Make() {
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully."
)
.
SetDefault
(
4096
);
AddAttr
<
bool
>
(
"exhaustive_search"
,
"(bool, default false) cuDNN has many algorithm to calculation "
"convolution, whether enable exhaustive search "
,
"for cuDNN convolution or not, defalut is False."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Convolution Operator.
...
...
@@ -288,11 +283,7 @@ void Conv3DOpMaker::Make() {
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully."
)
.
SetDefault
(
4096
);
AddAttr
<
bool
>
(
"exhaustive_search"
,
"(bool, default false) cuDNN has many algorithm to calculation "
"convolution, whether enable exhaustive search "
,
"for cuDNN convolution or not, defalut is False."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Convolution3D Operator.
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
db8c52da
...
...
@@ -204,10 +204,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
<<
"."
<<
(
driver_version_
%
100
)
/
10
<<
", Runtime Version: "
<<
runtime_version_
/
1000
<<
"."
<<
(
runtime_version_
%
100
)
/
10
;
size_t
cudnn_dso_ver
=
dynload
::
cudnnGetVersion
();
LOG
(
INFO
)
<<
"device: "
<<
place_
.
device
<<
", cuDNN Version: "
<<
cudnn_dso_ver
/
1000
<<
"."
<<
(
cudnn_dso_ver
%
100
)
/
10
<<
"."
;
callback_manager_
.
reset
(
new
StreamCallbackManager
(
stream_
));
}
...
...
paddle/fluid/platform/dynload/cudnn.h
浏览文件 @
db8c52da
...
...
@@ -65,54 +65,51 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
* include all needed cudnn functions in HPPL
* different cudnn version has different interfaces
**/
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
__macro(cudnnFindConvolutionForwardAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
__macro(cudnnGetErrorString);
CUDNN_DNN_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_CUDNN_WRAP
)
...
...
python/paddle/fluid/__init__.py
浏览文件 @
db8c52da
...
...
@@ -127,8 +127,7 @@ def __bootstrap__():
if
core
.
is_compiled_with_cuda
():
read_env_flags
+=
[
'fraction_of_gpu_memory_to_use'
,
'cudnn_deterministic'
,
'conv_workspace_size_limit'
,
'cudnn_exhaustive_search'
'fraction_of_gpu_memory_to_use'
,
'cudnn_deterministic'
]
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
[
"--tryfromenv="
+
","
.
join
(
read_env_flags
)])
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
db8c52da
...
...
@@ -27,7 +27,6 @@ from .tensor import concat
from
.
import
utils
from
..
import
unique_name
from
functools
import
reduce
from
..
import
core
__all__
=
[
'fc'
,
...
...
@@ -1665,20 +1664,6 @@ def conv2d(input,
pre_bias
=
helper
.
create_variable_for_type_inference
(
dtype
)
if
use_cudnn
:
helper
.
create_variable
(
name
=
"kCUDNNFwdAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
helper
.
create_variable
(
name
=
"kCUDNNBwdDataAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
helper
.
create_variable
(
name
=
"kCUDNNBwdFilterAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
helper
.
append_op
(
type
=
l_type
,
inputs
=
{
...
...
@@ -1692,7 +1677,7 @@ def conv2d(input,
'dilations'
:
dilation
,
'groups'
:
groups
,
'use_cudnn'
:
use_cudnn
,
'use_mkldnn'
:
False
,
'use_mkldnn'
:
False
})
pre_act
=
helper
.
append_bias_op
(
pre_bias
,
dim_start
=
1
,
dim_end
=
2
)
...
...
python/paddle/fluid/tests/unittests/test_conv2d_op.py
浏览文件 @
db8c52da
...
...
@@ -67,7 +67,6 @@ class TestConv2dOp(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"conv2d"
self
.
use_cudnn
=
False
self
.
exhaustive_search
=
False
self
.
use_cuda
=
False
self
.
use_mkldnn
=
False
self
.
data_format
=
"AnyLayout"
...
...
@@ -99,8 +98,7 @@ class TestConv2dOp(OpTest):
'dilations'
:
self
.
dilations
,
'use_cudnn'
:
self
.
use_cudnn
,
'use_mkldnn'
:
self
.
use_mkldnn
,
'data_format'
:
self
.
data_format
,
'exhaustive_search'
:
self
.
exhaustive_search
'data_format'
:
self
.
data_format
}
self
.
outputs
=
{
'Output'
:
output
}
...
...
@@ -394,12 +392,6 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp):
self
.
op_type
=
"depthwise_conv2d"
class
TestCUDNNExhaustiveSearch
(
TestCUDNN
):
def
init_kernel_type
(
self
):
self
.
use_cudnn
=
True
self
.
exhaustive_search
=
True
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
...
...
python/paddle/fluid/tests/unittests/test_conv3d_op.py
浏览文件 @
db8c52da
...
...
@@ -335,12 +335,6 @@ class TestFP16WithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
self
.
check_output_with_place
(
place
,
atol
=
2e-2
)
class
TestCUDNNExhaustiveSearch
(
TestCUDNN
):
def
init_kernel_type
(
self
):
self
.
use_cudnn
=
True
self
.
exhaustive_search
=
True
# FIXME(typhoonzero): find a way to determine if
# using cudnn > 6 in python
# class TestWithDilationCUDNN(TestWithDilation):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录