未验证 提交 924aac22 编写于 作者: A AshburnLee 提交者: GitHub

Add tf32 switch for cuDNN (#29192)

上级 8ce2482b
...@@ -210,16 +210,20 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -210,16 +210,20 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH)); CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math"; VLOG(5) << "use cudnn_tensor_op_math";
} else { } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_DEFAULT_MATH)); CUDNN_FMA_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math"; #endif // CUDA_VERSION >= 11000
} }
#endif #endif
...@@ -340,16 +344,20 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -340,16 +344,20 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
algo_t algo; algo_t algo;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH)); CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math"; VLOG(5) << "use cudnn_tensor_op_math";
} else { } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_DEFAULT_MATH)); CUDNN_FMA_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math"; #endif // CUDA_VERSION >= 11000
} }
#endif #endif
...@@ -485,16 +493,20 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -485,16 +493,20 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math";
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_TENSOR_OP_MATH)); CUDNN_TENSOR_OP_MATH));
VLOG(5) << "use cudnn_tensor_op_math"; VLOG(5) << "use cudnn_tensor_op_math";
} else { } else if (dtype == CUDNN_DATA_FLOAT && !args.cdesc.allow_tf32_) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(), platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
CUDNN_DEFAULT_MATH)); CUDNN_FMA_MATH));
VLOG(5) << "NOT use cudnn_tensor_op_math"; #endif // CUDA_VERSION >= 11000
} }
#endif #endif
......
...@@ -240,7 +240,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -240,7 +240,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
auto layout_format = GetCudnnTensorFormat(layout); auto layout_format = GetCudnnTensorFormat(layout);
args.handle = handle; args.handle = handle;
args.cdesc.set(dtype, padding_common, strides, dilations); args.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn());
#if CUDNN_VERSION_MIN(7, 0, 1) #if CUDNN_VERSION_MIN(7, 0, 1)
// cudnn 7 can support groups, no need to do it manually // cudnn 7 can support groups, no need to do it manually
...@@ -603,7 +604,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -603,7 +604,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
args1.idesc.set(transformed_input_grad, layout_tensor); args1.idesc.set(transformed_input_grad, layout_tensor);
args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups); args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups);
args1.odesc.set(transformed_output_grad_channel, layout_tensor); args1.odesc.set(transformed_output_grad_channel, layout_tensor);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups); args1.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = data_algo =
...@@ -620,7 +622,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -620,7 +622,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
args2.wdesc.set(transformed_filter_grad_channel, layout_tensor, args2.wdesc.set(transformed_filter_grad_channel, layout_tensor,
iwo_groups); iwo_groups);
args2.odesc.set(transformed_output_grad_channel, layout_tensor); args2.odesc.set(transformed_output_grad_channel, layout_tensor);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups); args2.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = filter_algo =
...@@ -980,7 +983,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -980,7 +983,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args1.idesc.set(transformed_ddX, iwo_group); args1.idesc.set(transformed_ddX, iwo_group);
args1.wdesc.set(*W, layout, iwo_group); args1.wdesc.set(*W, layout, iwo_group);
args1.odesc.set(transformed_ddO_channel, iwo_group); args1.odesc.set(transformed_ddO_channel, iwo_group);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); args1.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_group);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx); fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
...@@ -995,7 +999,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -995,7 +999,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args2.wdesc.set(*ddW, layout, iwo_group); args2.wdesc.set(*ddW, layout, iwo_group);
args2.odesc.set(transformed_ddO_channel, iwo_group); args2.odesc.set(transformed_ddO_channel, iwo_group);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); args2.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_group);
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx); fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
...@@ -1012,7 +1017,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -1012,7 +1017,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args3.odesc.set(transformed_dO_channel, iwo_group); args3.odesc.set(transformed_dO_channel, iwo_group);
args3.cdesc.set(dtype, padding_common, strides, dilations, c_group); args3.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_group);
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = filter_algo =
...@@ -1028,7 +1034,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -1028,7 +1034,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args4.idesc.set(transformed_dX, iwo_group); args4.idesc.set(transformed_dX, iwo_group);
args4.wdesc.set(*ddW, layout, iwo_group); args4.wdesc.set(*ddW, layout, iwo_group);
args4.odesc.set(transformed_dO_channel, iwo_group); args4.odesc.set(transformed_dO_channel, iwo_group);
args4.cdesc.set(dtype, padding_common, strides, dilations, c_group); args4.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_group);
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = data_algo =
......
...@@ -232,7 +232,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -232,7 +232,8 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
args.idesc.set(transformed_output, iwo_groups); args.idesc.set(transformed_output, iwo_groups);
args.wdesc.set(*filter, layout_tensor, iwo_groups); args.wdesc.set(*filter, layout_tensor, iwo_groups);
args.odesc.set(transformed_input, iwo_groups); args.odesc.set(transformed_input, iwo_groups);
args.cdesc.set(dtype, padding_common, strides, dilations, c_groups); args.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
algo = search::Find<T>(args, false, deterministic, ctx); algo = search::Find<T>(args, false, deterministic, ctx);
...@@ -468,7 +469,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -468,7 +469,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
args1.idesc.set(transformed_output_grad, iwo_groups); args1.idesc.set(transformed_output_grad, iwo_groups);
args1.wdesc.set(*filter, layout_tensor, iwo_groups); args1.wdesc.set(*filter, layout_tensor, iwo_groups);
args1.odesc.set(input_transpose, iwo_groups); args1.odesc.set(input_transpose, iwo_groups);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups); args1.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
data_algo = search1::Find<T>(args1, false, deterministic, ctx); data_algo = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size = workspace_size =
...@@ -481,7 +483,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -481,7 +483,8 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
args2.idesc.set(transformed_output_grad, iwo_groups); args2.idesc.set(transformed_output_grad, iwo_groups);
args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups); args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups);
args2.odesc.set(input_transpose, iwo_groups); args2.odesc.set(input_transpose, iwo_groups);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups); args2.cdesc.set(dtype, padding_common, strides, dilations,
platform::AllowTF32Cudnn(), c_groups);
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = search2::Find<T>(args2, false, deterministic, ctx); filter_algo = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size = std::max(workspace_size, workspace_size = std::max(workspace_size,
......
...@@ -200,6 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -200,6 +200,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH)); cudnn_conv_desc, CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION >= 11000
if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc,
CUDNN_FMA_MATH));
}
#endif // CUDA_VERSION >= 11000
auto x_dims = framework::vectorize(transformed_input.dims()); auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims()); auto f_dims = framework::vectorize(filter->dims());
......
...@@ -153,6 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> { ...@@ -153,6 +153,13 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(conv_desc[i], platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
CUDNN_DEFAULT_MATH)); CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION >= 11000
if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(conv_desc[i],
CUDNN_FMA_MATH));
}
#endif // CUDA_VERSION >= 11000
} }
in_dims[2][1] *= 2; in_dims[2][1] *= 2;
in_strides[2][0] = oc * h * w; in_strides[2][0] = oc * h * w;
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -229,7 +230,8 @@ class ConvolutionDescriptor { ...@@ -229,7 +230,8 @@ class ConvolutionDescriptor {
void set(cudnnDataType_t dtype, const std::vector<int>& pads, void set(cudnnDataType_t dtype, const std::vector<int>& pads,
const std::vector<int>& strides, const std::vector<int>& dilations, const std::vector<int>& strides, const std::vector<int>& dilations,
const int groups = 1) { bool allow_tf32, const int groups = 1) {
allow_tf32_ = allow_tf32;
cudnnDataType_t compute_type = cudnnDataType_t compute_type =
(dtype == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT; (dtype == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
T* desc = desc_.get(); T* desc = desc_.get();
...@@ -246,11 +248,18 @@ class ConvolutionDescriptor { ...@@ -246,11 +248,18 @@ class ConvolutionDescriptor {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(desc, platform::dynload::cudnnSetConvolutionMathType(desc,
CUDNN_TENSOR_OP_MATH)); CUDNN_TENSOR_OP_MATH));
} else if (dtype == CUDNN_DATA_FLOAT && !allow_tf32) {
#if CUDA_VERSION >= 11000
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionMathType(desc, CUDNN_FMA_MATH));
#endif // CUDA_VERSION >= 11000
} }
#endif #endif
#endif #endif
} }
bool allow_tf32_;
private: private:
std::unique_ptr<T, Deleter> desc_; std::unique_ptr<T, Deleter> desc_;
}; };
......
...@@ -74,6 +74,10 @@ namespace platform { ...@@ -74,6 +74,10 @@ namespace platform {
bool allow_tf32_cublas = true; bool allow_tf32_cublas = true;
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; } void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }
bool AllowTF32Cublas() { return allow_tf32_cublas; } bool AllowTF32Cublas() { return allow_tf32_cublas; }
bool allow_tf32_cudnn = true;
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
DeviceContextPool* DeviceContextPool::pool = nullptr; DeviceContextPool* DeviceContextPool::pool = nullptr;
......
...@@ -67,6 +67,10 @@ namespace platform { ...@@ -67,6 +67,10 @@ namespace platform {
void SetAllowTF32Cublas(bool active); void SetAllowTF32Cublas(bool active);
/*Get the global variable allow_tf32_cublas value*/ /*Get the global variable allow_tf32_cublas value*/
bool AllowTF32Cublas(); bool AllowTF32Cublas();
/*Set the value of the global variable allow_tf32_cudnn*/
void SetAllowTF32Cudnn(bool active);
/*Get the global variable allow_tf32_cudnn value*/
bool AllowTF32Cudnn();
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
enum DeviceType { enum DeviceType {
......
...@@ -1988,6 +1988,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1988,6 +1988,8 @@ All parameter, weight, gradient are variables in Paddle.
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
m.def("set_cublas_switch", platform::SetAllowTF32Cublas); m.def("set_cublas_switch", platform::SetAllowTF32Cublas);
m.def("get_cublas_switch", platform::AllowTF32Cublas); m.def("get_cublas_switch", platform::AllowTF32Cublas);
m.def("set_cudnn_switch", platform::SetAllowTF32Cudnn);
m.def("get_cudnn_switch", platform::AllowTF32Cudnn);
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
using VarQuantScale = using VarQuantScale =
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import six
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
class TestTF32Switch(unittest.TestCase):
def test_on_off(self):
if core.is_compiled_with_cuda():
self.assertTrue(core.get_cudnn_switch()) # default
core.set_cudnn_switch(0)
self.assertFalse(core.get_cudnn_switch()) # turn off
core.set_cudnn_switch(1)
self.assertTrue(core.get_cudnn_switch()) # turn on
core.set_cudnn_switch(1) # restore the switch
else:
pass
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册