未验证 提交 825d4957 编写于 作者: A AshburnLee 提交者: GitHub

Correct typos (#32288)

上级 90133d24
...@@ -200,13 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -200,13 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH)); cudnn_conv_desc, CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION >= 11000 #if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
if (!platform::allow_tf32_cudnn) { if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc, platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc,
CUDNN_FMA_MATH)); CUDNN_FMA_MATH));
} }
#endif // CUDA_VERSION >= 11000 #endif // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
auto x_dims = framework::vectorize(transformed_input.dims()); auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims()); auto f_dims = framework::vectorize(filter->dims());
......
...@@ -153,13 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> { ...@@ -153,13 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(conv_desc[i], platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
CUDNN_DEFAULT_MATH)); CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION >= 11000 #if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
if (!platform::allow_tf32_cudnn) { if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(conv_desc[i], platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
CUDNN_FMA_MATH)); CUDNN_FMA_MATH));
} }
#endif // CUDA_VERSION >= 11000 #endif // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
} }
in_dims[2][1] *= 2; in_dims[2][1] *= 2;
in_strides[2][0] = oc * h * w; in_strides[2][0] = oc * h * w;
......
...@@ -85,6 +85,7 @@ namespace platform { ...@@ -85,6 +85,7 @@ namespace platform {
void SetAllowTF32Cublas(bool active); void SetAllowTF32Cublas(bool active);
/*Get the global variable allow_tf32_cublas value*/ /*Get the global variable allow_tf32_cublas value*/
bool AllowTF32Cublas(); bool AllowTF32Cublas();
extern bool allow_tf32_cudnn;
/*Set the value of the global variable allow_tf32_cudnn*/ /*Set the value of the global variable allow_tf32_cudnn*/
void SetAllowTF32Cudnn(bool active); void SetAllowTF32Cudnn(bool active);
/*Get the global variable allow_tf32_cudnn value*/ /*Get the global variable allow_tf32_cudnn value*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册