From 00566eade8749566763af7e782224f3fed68bbdf Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Tue, 8 Mar 2022 16:47:20 +0800
Subject: [PATCH] Add exception throw for norm_conv when platform is not
 supported (#40166)

* Add throw for norm_conv when platform is not supported

* fix format
---
 .../operators/fused/cudnn_norm_conv_test.cc   | 42 ++++++++++++++++---
 1 file changed, 36 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc
index b3792a176fa..a80f590aa49 100644
--- a/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc
+++ b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc
@@ -405,8 +405,18 @@ TEST(CudnnNormConvFp16, K1S1) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 3, output_channels = input_channels
@@ -421,8 +431,18 @@ TEST(CudnnNormConvFp16, K3S1) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 1, output_channels = input_channels * 4
@@ -437,8 +457,18 @@ TEST(CudnnNormConvFp16, K1S1O4) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 1, stride = 2, output_channels = input_channels * 4
-- 
GitLab