Unverified commit 00566ead, authored by Zhang Zheng, committed by GitHub

Add exception throw for norm_conv when platform is not supported (#40166)

* Add throw for norm_conv when platform is not supported

* fix format
Parent 73583f86
@@ -405,8 +405,18 @@ TEST(CudnnNormConvFp16, K1S1) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 3, output_channels = input_channels
@@ -421,8 +431,18 @@ TEST(CudnnNormConvFp16, K3S1) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 1, output_channels = input_channels * 4
@@ -437,8 +457,18 @@ TEST(CudnnNormConvFp16, K1S1O4) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3, true);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
+  }
 }
 
 // test for fp16, kernel = 1, stride = 2, output_channels = input_channels * 4
......
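For context, the updated tests expect paddle::platform::EnforceNotMet to be thrown whenever the GPU's compute capability is 70 or lower. A minimal sketch of an operator-side check that would produce that behavior is shown below; the helper name CheckNormConvPlatform, the include set, and the error message wording are illustrative assumptions and not the exact code merged in #40166.

// Hypothetical sketch: reject the fused norm_conv path on GPUs the tests
// treat as unsupported (compute capability <= 70). PADDLE_ENFORCE_* macros
// throw paddle::platform::EnforceNotMet, which is what the tests assert.
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

inline void CheckNormConvPlatform(const platform::CUDADeviceContext &ctx) {
  // Require a compute capability strictly greater than 70.
  PADDLE_ENFORCE_GT(
      ctx.GetComputeCapability(), 70,
      platform::errors::PreconditionNotMet(
          "cudnn_norm_conv is not supported on GPUs with compute capability "
          "less than or equal to 70, but the current device reports %d.",
          ctx.GetComputeCapability()));
}

}  // namespace operators
}  // namespace paddle

With a check like this in the forward/backward paths, CheckForward and CheckBackward throw on pre-Volta-class devices and run normally otherwise, which matches the ASSERT_THROW / ASSERT_NO_THROW branches added above.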