未验证 提交 ca2d6d3c 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #11224 from dzhwinter/fix/cudnn

fix cudnn version issue
...@@ -81,6 +81,27 @@ enum class PoolingMode { ...@@ -81,6 +81,27 @@ enum class PoolingMode {
kMaximumDeterministic, kMaximumDeterministic,
}; };
#if CUDNN_VERSION < 6000
#pragma message "CUDNN version under 6.0 is supported at best effort."
#pragma message "We strongly encourage you to move to 6.0 and above."
#pragma message "This message is intended to annoy you enough to update."
#pragma message \
"please see https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/"
inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch (mode) {
case PoolingMode::kMaximumDeterministic:
return CUDNN_POOLING_MAX;
case PoolingMode::kAverage:
return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
case PoolingMode::kMaximum:
return CUDNN_POOLING_MAX;
default:
PADDLE_THROW("Unexpected pooling mode.");
}
}
#else
inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) { inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch (mode) { switch (mode) {
case PoolingMode::kMaximumDeterministic: case PoolingMode::kMaximumDeterministic:
...@@ -93,6 +114,7 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) { ...@@ -93,6 +114,7 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
PADDLE_THROW("Unexpected pooling mode."); PADDLE_THROW("Unexpected pooling mode.");
} }
} }
#endif // CUDNN_VERSION < 6000
template <typename T> template <typename T>
class CudnnDataType; class CudnnDataType;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册