未验证 提交 924aac22 编写于 作者: A AshburnLee 提交者: GitHub

Add tf32 switch for cuDNN (#29192)

上级 8ce2482b
......@@ -210,16 +210,20 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else {
} else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
CUDNN_FMA_MATH));
#endif // CUDA_VERSION >= 11000
}
#endif
......@@ -340,16 +344,20 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
algo_t algo;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else {
} else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
CUDNN_FMA_MATH));
#endif // CUDA_VERSION >= 11000
}
#endif
......@@ -485,16 +493,20 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math";
} else {
} else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
CUDNN_FMA_MATH));
#endif // CUDA_VERSION >= 11000
}
#endif
......
......@@ -240,7 +240,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
auto layout_format = GetCudnnTensorFormat(layout);
args.handle = handle;
args.cdesc.set(dtype, padding_common, strides, dilations);
args.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn());
#if CUDNN_VERSION_MIN(7, 0, 1)
// cudnn 7 can support groups, no need to do it manually
......@@ -603,7 +604,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
args1.idesc.set(transformed_input_grad, layout_tensor);
args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups);
args1.odesc.set(transformed_output_grad_channel, layout_tensor);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
args1.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo =
......@@ -620,7 +622,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
args2.wdesc.set(transformed_filter_grad_channel, layout_tensor,
iwo_groups);
args2.odesc.set(transformed_output_grad_channel, layout_tensor);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
args2.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
......@@ -980,7 +983,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args1.idesc.set(transformed_ddX, iwo_group);
args1.wdesc.set(*W, layout, iwo_group);
args1.odesc.set(transformed_ddO_channel, iwo_group);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
args1.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_group);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
......@@ -995,7 +999,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args2.wdesc.set(*ddW, layout, iwo_group);
args2.odesc.set(transformed_ddO_channel, iwo_group);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
args2.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_group);
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
......@@ -1012,7 +1017,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args3.odesc.set(transformed_dO_channel, iwo_group);
args3.cdesc.set(dtype, padding_common, strides, dilations, c_group);
args3.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_group);
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
......@@ -1028,7 +1034,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args4.idesc.set(transformed_dX, iwo_group);
args4.wdesc.set(*ddW, layout, iwo_group);
args4.odesc.set(transformed_dO_channel, iwo_group);
args4.cdesc.set(dtype, padding_common, strides, dilations, c_group);
args4.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_group);
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo =
......
......@@ -232,7 +232,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
args.idesc.set(transformed_output, iwo_groups);
args.wdesc.set(*filter, layout_tensor, iwo_groups);
args.odesc.set(transformed_input, iwo_groups);
args.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
args.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
algo = search::Find<T>(args, false, deterministic, ctx);
......@@ -468,7 +469,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
args1.idesc.set(transformed_output_grad, iwo_groups);
args1.wdesc.set(*filter, layout_tensor, iwo_groups);
args1.odesc.set(input_transpose, iwo_groups);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
args1.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
data_algo = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size =
......@@ -481,7 +483,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
args2.idesc.set(transformed_output_grad, iwo_groups);
args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups);
args2.odesc.set(input_transpose, iwo_groups);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
args2.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size = std::max(workspace_size,
......
......@@ -200,6 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION >= 11000
if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc,
CUDNN_FMA_MATH));
}
#endif // CUDA_VERSION >= 11000
auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims());
......
......@@ -153,6 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION >= 11000
if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
CUDNN_FMA_MATH));
}
#endif // CUDA_VERSION >= 11000
}
in_dims[2][1] *= 2;
in_strides[2][0] = oc * h * w;
......
......@@ -24,6 +24,7 @@
#include <vector>
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
......@@ -229,7 +230,8 @@ class ConvolutionDescriptor {
void set(cudnnDataType_t dtype, const std::vector<int>& pads,
const std::vector<int>& strides, const std::vector<int>& dilations,
const int groups = 1) {
bool allow_tf32, const int groups = 1) {
allow_tf32_ = allow_tf32;
cudnnDataType_t compute_type =
(dtype == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
T* desc = desc_.get();
......@@ -246,11 +248,18 @@ class ConvolutionDescriptor {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(desc,
CUDNN_TENSOR_OP_MATH));
} else if (dtype == CUDNN_DATA_FLOAT && !allow_tf32) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(desc, CUDNN_FMA_MATH));
#endif // CUDA_VERSION >= 11000
}
#endif
#endif
}
bool allow_tf32_;
private:
std::unique_ptr<T, Deleter> desc_;
};
......
......@@ -74,6 +74,10 @@ namespace platform {
bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; }
bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif // PADDLE_WITH_CUDA
DeviceContextPool* DeviceContextPool::pool = nullptr;
......
......@@ -67,6 +67,10 @@ namespace platform {
void SetAllowTF32Cublas(bool active);
/*Get the global variable allow_tf32_cublas value*/
bool AllowTF32Cublas();
/*Set the value of the global variable allow_tf32_cudnn*/
void SetAllowTF32Cudnn(bool active);
/*Get the global variable allow_tf32_cudnn value*/
bool AllowTF32Cudnn();
#endif // PADDLE_WITH_CUDA
enum DeviceType {
......
......@@ -1988,6 +1988,8 @@ All parameter, weight, gradient are variables in Paddle.
#ifdef PADDLE_WITH_CUDA
m.def("set_cublas_switch", platform::SetAllowTF32Cublas);
m.def("get_cublas_switch", platform::AllowTF32Cublas);
m.def("set_cudnn_switch", platform::SetAllowTF32Cudnn);
m.def("get_cudnn_switch", platform::AllowTF32Cudnn);
#endif // PADDLE_WITH_CUDA
using VarQuantScale =
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import six
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
class TestTF32Switch(unittest.TestCase):
def test_on_off(self):
if core.is_compiled_with_cuda():
self.assertTrue(core.get_cudnn_switch()) # default
core.set_cudnn_switch(0)
self.assertFalse(core.get_cudnn_switch()) # turn off
core.set_cudnn_switch(1)
self.assertTrue(core.get_cudnn_switch()) # turn on
core.set_cudnn_switch(1) # restore the switch
else:
pass
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册