From 138a71b75049bca0f37b0575be13a964e0e7b9c7 Mon Sep 17 00:00:00 2001 From: AshburnLee <1578034415@qq.com> Date: Wed, 20 Jan 2021 18:07:57 +0800 Subject: [PATCH] Add tf32 switch for cuDNN (#29192) (#30574) This PR is cherry-picked from PR: #29192 Function: Added TF32 switch for cuDNN. Turned on as default, turned off when users set the switch as False --- paddle/fluid/operators/conv_cudnn_helper.h | 30 ++++++++++----- paddle/fluid/operators/conv_cudnn_op.cu | 21 ++++++---- .../operators/conv_transpose_cudnn_op.cu | 9 +++-- .../fluid/operators/fused/conv_fusion_op.cu | 7 ++++ .../fused/fusion_conv_inception_op.cu | 7 ++++ paddle/fluid/platform/cudnn_desc.h | 11 +++++- paddle/fluid/platform/device_context.cc | 10 +++++ paddle/fluid/platform/device_context.h | 4 ++ paddle/fluid/pybind/pybind.cc | 7 ++++ .../fluid/tests/unittests/test_tf32_cudnn.py | 38 +++++++++++++++++++ 10 files changed, 124 insertions(+), 20 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_tf32_cudnn.py diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index fe0150cca52..82c8aa50afc 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -210,16 +210,20 @@ struct SearchAlgorithm { #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + args.cdesc.desc(), CUDNN_DEFAULT_MATH)); + VLOG(5) << "NOT use cudnn_tensor_op_math"; if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), CUDNN_TENSOR_OP_MATH)); VLOG(5) << "use cudnn_tensor_op_math"; - } else { + } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) { +#if CUDA_VERSION >= 11000 PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), - CUDNN_DEFAULT_MATH)); - VLOG(5) << "NOT use cudnn_tensor_op_math"; + CUDNN_FMA_MATH)); +#endif // CUDA_VERSION >= 11000 } #endif @@ -340,16 +344,20 @@ struct SearchAlgorithm { algo_t algo; #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + args.cdesc.desc(), CUDNN_DEFAULT_MATH)); + VLOG(5) << "NOT use cudnn_tensor_op_math"; if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), CUDNN_TENSOR_OP_MATH)); VLOG(5) << "use cudnn_tensor_op_math"; - } else { + } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) { +#if CUDA_VERSION >= 11000 PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), - CUDNN_DEFAULT_MATH)); - VLOG(5) << "NOT use cudnn_tensor_op_math"; + CUDNN_FMA_MATH)); +#endif // CUDA_VERSION >= 11000 } #endif @@ -485,16 +493,20 @@ struct SearchAlgorithm { #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + args.cdesc.desc(), CUDNN_DEFAULT_MATH)); + VLOG(5) << "NOT use cudnn_tensor_op_math"; if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), CUDNN_TENSOR_OP_MATH)); VLOG(5) << "use 
cudnn_tensor_op_math"; - } else { + } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) { +#if CUDA_VERSION >= 11000 PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), - CUDNN_DEFAULT_MATH)); - VLOG(5) << "NOT use cudnn_tensor_op_math"; + CUDNN_FMA_MATH)); +#endif // CUDA_VERSION >= 11000 } #endif diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu index 5f469e6a0f5..5ef22b81869 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu +++ b/paddle/fluid/operators/conv_cudnn_op.cu @@ -240,7 +240,8 @@ class CUDNNConvOpKernel : public framework::OpKernel { auto layout_format = GetCudnnTensorFormat(layout); args.handle = handle; - args.cdesc.set(dtype, padding_common, strides, dilations); + args.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn()); #if CUDNN_VERSION_MIN(7, 0, 1) // cudnn 7 can support groups, no need to do it manually @@ -603,7 +604,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { args1.idesc.set(transformed_input_grad, layout_tensor); args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups); args1.odesc.set(transformed_output_grad_channel, layout_tensor); - args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups); + args1.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), c_groups); using search1 = SearchAlgorithm; data_algo = @@ -620,7 +622,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { args2.wdesc.set(transformed_filter_grad_channel, layout_tensor, iwo_groups); args2.odesc.set(transformed_output_grad_channel, layout_tensor); - args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups); + args2.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), c_groups); using search2 = SearchAlgorithm; filter_algo = @@ -980,7 +983,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args1.idesc.set(transformed_ddX, iwo_group); args1.wdesc.set(*W, layout, iwo_group); args1.odesc.set(transformed_ddO_channel, iwo_group); - args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); + args1.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), c_group); using search1 = SearchAlgorithm; fwd_algo1 = search1::Find(args1, exhaustive_search, false, ctx); @@ -995,7 +999,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args2.wdesc.set(*ddW, layout, iwo_group); args2.odesc.set(transformed_ddO_channel, iwo_group); - args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); + args2.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), c_group); using search2 = SearchAlgorithm; fwd_algo2 = search2::Find(args2, exhaustive_search, false, ctx); @@ -1012,7 +1017,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args3.odesc.set(transformed_dO_channel, iwo_group); - args3.cdesc.set(dtype, padding_common, strides, dilations, c_group); + args3.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), c_group); using search3 = SearchAlgorithm; filter_algo = @@ -1028,7 +1034,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args4.idesc.set(transformed_dX, iwo_group); args4.wdesc.set(*ddW, layout, iwo_group); args4.odesc.set(transformed_dO_channel, iwo_group); - args4.cdesc.set(dtype, padding_common, strides, dilations, c_group); + args4.cdesc.set(dtype, padding_common, strides, 
dilations, + platform::AllowTF32Cudnn(), c_group); using search4 = SearchAlgorithm; data_algo = diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu b/paddle/fluid/operators/conv_transpose_cudnn_op.cu index 94148109c73..a12629b7a49 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu @@ -232,7 +232,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { args.idesc.set(transformed_output, iwo_groups); args.wdesc.set(*filter, layout_tensor, iwo_groups); args.odesc.set(transformed_input, iwo_groups); - args.cdesc.set(dtype, padding_common, strides, dilations, c_groups); + args.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), c_groups); using search = SearchAlgorithm; algo = search::Find(args, false, deterministic, ctx); @@ -468,7 +469,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { args1.idesc.set(transformed_output_grad, iwo_groups); args1.wdesc.set(*filter, layout_tensor, iwo_groups); args1.odesc.set(input_transpose, iwo_groups); - args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups); + args1.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), c_groups); using search1 = SearchAlgorithm; data_algo = search1::Find(args1, false, deterministic, ctx); workspace_size = @@ -481,7 +483,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { args2.idesc.set(transformed_output_grad, iwo_groups); args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups); args2.odesc.set(input_transpose, iwo_groups); - args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups); + args2.cdesc.set(dtype, padding_common, strides, dilations, + platform::AllowTF32Cudnn(), c_groups); using search2 = SearchAlgorithm; filter_algo = search2::Find(args2, false, deterministic, ctx); workspace_size = std::max(workspace_size, diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index 49fded886a0..33d408582ff 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -200,6 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( cudnn_conv_desc, CUDNN_DEFAULT_MATH)); +#if CUDNN_VERSION >= 11000 + if (!platform::allow_tf32_cudnn) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc, + CUDNN_FMA_MATH)); + } +#endif // CUDA_VERSION >= 11000 auto x_dims = framework::vectorize(transformed_input.dims()); auto f_dims = framework::vectorize(filter->dims()); diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cu b/paddle/fluid/operators/fused/fusion_conv_inception_op.cu index 3529ff1f94a..c448c529f56 100644 --- a/paddle/fluid/operators/fused/fusion_conv_inception_op.cu +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cu @@ -153,6 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(conv_desc[i], CUDNN_DEFAULT_MATH)); +#if CUDNN_VERSION >= 11000 + if (!platform::allow_tf32_cudnn) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnSetConvolutionMathType(conv_desc[i], + CUDNN_FMA_MATH)); + } +#endif // CUDA_VERSION >= 11000 } in_dims[2][1] *= 2; in_strides[2][0] = oc * h * w; diff --git a/paddle/fluid/platform/cudnn_desc.h 
b/paddle/fluid/platform/cudnn_desc.h index 0e0218dcca3..05a431e731e 100644 --- a/paddle/fluid/platform/cudnn_desc.h +++ b/paddle/fluid/platform/cudnn_desc.h @@ -24,6 +24,7 @@ #include #include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/device_context.h" namespace paddle { namespace framework { @@ -229,7 +230,8 @@ class ConvolutionDescriptor { void set(cudnnDataType_t dtype, const std::vector& pads, const std::vector& strides, const std::vector& dilations, - const int groups = 1) { + bool allow_tf32, const int groups = 1) { + allow_tf32_ = allow_tf32; cudnnDataType_t compute_type = (dtype == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT; T* desc = desc_.get(); @@ -246,11 +248,18 @@ class ConvolutionDescriptor { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnSetConvolutionMathType(desc, CUDNN_TENSOR_OP_MATH)); + } else if (dtype == CUDNN_DATA_FLOAT && !allow_tf32) { +#if CUDA_VERSION >= 11000 + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnSetConvolutionMathType(desc, CUDNN_FMA_MATH)); +#endif // CUDA_VERSION >= 11000 } #endif #endif } + bool allow_tf32_; + private: std::unique_ptr desc_; }; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index c7a0cdde8d9..a8d56c0717d 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -70,6 +70,16 @@ AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) { namespace paddle { namespace platform { +#ifdef PADDLE_WITH_CUDA +bool allow_tf32_cublas = true; +void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; } +bool AllowTF32Cublas() { return allow_tf32_cublas; } + +bool allow_tf32_cudnn = true; +void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; } +bool AllowTF32Cudnn() { return allow_tf32_cudnn; } +#endif // PADDLE_WITH_CUDA + DeviceContextPool* DeviceContextPool::pool = nullptr; platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 9f2e5acfc61..68f901e8af7 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -67,6 +67,10 @@ namespace platform { void SetAllowTF32Cublas(bool active); /*Get the global variable allow_tf32_cublas value*/ bool AllowTF32Cublas(); +/*Set the value of the global variable allow_tf32_cudnn*/ +void SetAllowTF32Cudnn(bool active); +/*Get the global variable allow_tf32_cudnn value*/ +bool AllowTF32Cudnn(); #endif // PADDLE_WITH_CUDA enum DeviceType { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8782f903428..0d365f2b3a5 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1987,6 +1987,13 @@ All parameter, weight, gradient are variables in Paddle. 
m.def("size_of_dtype", framework::SizeOfType); +#ifdef PADDLE_WITH_CUDA + m.def("set_cublas_switch", platform::SetAllowTF32Cublas); + m.def("get_cublas_switch", platform::AllowTF32Cublas); + m.def("set_cudnn_switch", platform::SetAllowTF32Cudnn); + m.def("get_cudnn_switch", platform::AllowTF32Cudnn); +#endif // PADDLE_WITH_CUDA + using VarQuantScale = std::unordered_map>; diff --git a/python/paddle/fluid/tests/unittests/test_tf32_cudnn.py b/python/paddle/fluid/tests/unittests/test_tf32_cudnn.py new file mode 100644 index 00000000000..48127c2a90b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tf32_cudnn.py @@ -0,0 +1,38 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import six +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + + +class TestTF32Switch(unittest.TestCase): + def test_on_off(self): + if core.is_compiled_with_cuda(): + self.assertTrue(core.get_cudnn_switch()) # default + core.set_cudnn_switch(0) + self.assertFalse(core.get_cudnn_switch()) # turn off + core.set_cudnn_switch(1) + self.assertTrue(core.get_cudnn_switch()) # turn on + + core.set_cudnn_switch(1) # restore the switch + else: + pass + + +if __name__ == '__main__': + unittest.main() -- GitLab
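For reference, the switch introduced by this patch is driven from Python through the new pybind bindings. The following is a minimal usage sketch based on the bindings and the unit test included above; it assumes a CUDA build of Paddle and is not part of the patch itself:

    import paddle.fluid.core as core

    if core.is_compiled_with_cuda():
        # TF32 for cuDNN convolutions is allowed by default.
        assert core.get_cudnn_switch()

        # Turn the switch off: float32 convolutions are then set to
        # CUDNN_FMA_MATH (this branch only takes effect when Paddle is
        # built against CUDA >= 11.0, per the guards in this patch).
        core.set_cudnn_switch(0)
        assert not core.get_cudnn_switch()

        # Restore the default behaviour.
        core.set_cudnn_switch(1)

The analogous set_cublas_switch / get_cublas_switch bindings exposed in pybind.cc control the cuBLAS-side flag in the same way.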