未验证 提交 825d4957 编写于 作者: A AshburnLee 提交者: GitHub

Correct typos (#32288)

上级 90133d24
......@@ -200,13 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION >= 11000
#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc,
CUDNN_FMA_MATH));
}
#endif // CUDA_VERSION >= 11000
#endif // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims());
......
......@@ -153,13 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION >= 11000
#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
CUDNN_FMA_MATH));
}
#endif // CUDA_VERSION >= 11000
#endif // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
}
in_dims[2][1] *= 2;
in_strides[2][0] = oc * h * w;
......
......@@ -85,6 +85,7 @@ namespace platform {
void SetAllowTF32Cublas(bool active);
/*Get the global variable allow_tf32_cublas value*/
bool AllowTF32Cublas();
extern bool allow_tf32_cudnn;
/*Set the value of the global variable allow_tf32_cudnn*/
void SetAllowTF32Cudnn(bool active);
/*Get the global variable allow_tf32_cudnn value*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册