未验证 提交 fb3d5f07 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix test_cudnn_norm_conv and test_cudnn_bn_add_relu in CUDA11.2 (#42405)

* Fix test_cudnn_norm_conv and test_cudnn_bn_add_relu in CUDA11.2

* no throw in V100 for some cases
上级 a3d56a9c
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
......@@ -33,6 +34,7 @@ namespace op = paddle::operators;
using Tensor = paddle::framework::Tensor;
USE_OP_ITSELF(batch_norm);
PD_DECLARE_KERNEL(batch_norm, GPU, ALL_LAYOUT);
USE_CUDA_ONLY_OP(fused_bn_add_activation);
USE_CUDA_ONLY_OP(fused_bn_add_activation_grad);
......
......@@ -164,6 +164,7 @@ void ComputeConv2DBackward(const platform::CUDADeviceContext &ctx,
attrs.insert({"groups", groups});
attrs.insert({"exhaustive_search", exhaustive_search});
attrs.insert({"use_addto", use_addto});
attrs.insert({"workspace_size_MB", 512});
auto op = framework::OpRegistry::CreateOp(
"conv2d_grad", {{"Input", {"Input"}},
......@@ -408,7 +409,7 @@ TEST(CudnnNormConvFp16, K1S1) {
platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
if (ctx->GetComputeCapability() <= 70) {
if (ctx->GetComputeCapability() < 70) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3, true),
......@@ -434,7 +435,7 @@ TEST(CudnnNormConvFp16, K3S1) {
platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
if (ctx->GetComputeCapability() <= 70) {
if (ctx->GetComputeCapability() < 70) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3, true),
......@@ -460,7 +461,7 @@ TEST(CudnnNormConvFp16, K1S1O4) {
platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
if (ctx->GetComputeCapability() <= 70) {
if (ctx->GetComputeCapability() < 70) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3, true),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册