diff --git a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
index 51a5f50560eac7a8919053a91b69ada46c13ae00..b38cae829680b2e6f7638ec97e13db9dd045bebc 100644
--- a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
@@ -487,13 +487,13 @@ __global__ void KeBicubicInterpBw(T* in,
     T in_img_idy = align_corners
                        ? static_cast<T>(ratio_h * out_img_idy)
                        : static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
-    int input_y = floorf(in_img_idy);
+    int input_y = floorf(static_cast<float>(in_img_idy));
     using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     const T y_t = static_cast<T>(static_cast<MT>(in_img_idy) - input_y);
     T in_img_idx = align_corners
                        ? static_cast<T>(ratio_w * out_img_idx)
                        : static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
-    int input_x = floorf(in_img_idx);
+    int input_x = floorf(static_cast<float>(in_img_idx));
     const T x_t = static_cast<T>(static_cast<MT>(in_img_idx) - input_x);
 
     T x_coeffs[4];
@@ -1577,7 +1577,8 @@ PD_REGISTER_KERNEL(nearest_interp_grad,
                    phi::NearestInterpGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
 }
diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu
index 8135e73142fecbee48c501b187e7b76de060208e..07e113ef7aa8004284b35ecf14345c4dc0491261 100644
--- a/paddle/phi/kernels/gpu/interpolate_kernel.cu
+++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu
@@ -355,14 +355,14 @@ __global__ void KeBicubicInterpFw(const T* in,
     T in_img_idy = align_corners
                        ? static_cast<T>(ratio_h * out_img_idy)
                        : static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
-    int input_y = floorf(in_img_idy);
+    int input_y = floorf(static_cast<float>(in_img_idy));
     using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     const T y_t = static_cast<T>(static_cast<MT>(in_img_idy) - input_y);
 
     T in_img_idx = align_corners
                        ? static_cast<T>(ratio_w * out_img_idx)
                        : static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
-    int input_x = floorf(in_img_idx);
+    int input_x = floorf(static_cast<float>(in_img_idx));
     const T x_t = static_cast<T>(static_cast<MT>(in_img_idx) - input_x);
 
     T coefficients[4];
@@ -1468,6 +1468,7 @@ PD_REGISTER_KERNEL(nearest_interp,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int,
                    int64_t) {
   kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
diff --git a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
index 0d5f266d3d172721cd663a4703d9f003322ad4f0..5d1a92a3119bc427ef4d13d516125b6715b3e6fc 100644
--- a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
@@ -454,8 +454,14 @@ void ConvCudnnGradKernel(const Context& ctx,
 #ifdef PADDLE_WITH_HIP
   // HIP MIOPEN ONLY SUPPORT NCHW format
   auto compute_format = paddle::platform::DataLayout::kNCHW;
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+  const bool compute_in_nhwc =
+      (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) &&
+      IsVoltaOrLater(ctx);
 #else
   const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && IsVoltaOrLater(ctx);
+#endif
   auto compute_format = compute_in_nhwc && channel_last
                             ? paddle::platform::DataLayout::kNHWC
                             : paddle::platform::DataLayout::kNCHW;
diff --git a/paddle/phi/kernels/gpudnn/conv_kernel.cu b/paddle/phi/kernels/gpudnn/conv_kernel.cu
index 3e3b1fb198da93f7a12e0daf8ac0b584903d54b4..404405665316255ea48663ca7d7817e4e5b1bef7 100644
--- a/paddle/phi/kernels/gpudnn/conv_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_kernel.cu
@@ -373,10 +373,18 @@ void ConvCudnnKernel(const Context& ctx,
 #ifdef PADDLE_WITH_HIP
   // HIP MIOPEN ONLY SUPPORT NCHW format
   auto compute_format = paddle::platform::DataLayout::kNCHW;
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+  // Tensor Core introduced from Volta GPUs supports more faster conv op
+  // with FP16 or BF16 in NHWC data format.
+  const bool compute_in_nhwc =
+      (dtype == CUDNN_DATA_HALF || dtype == CUDNN_DATA_BFLOAT16) &&
+      IsVoltaOrLater(ctx);
 #else
   // Tensor Core introduced from Volta GPUs supports more faster conv op
-  // with FP16 in NHWC data format.
+  // with FP16 in NHWC data format. (BF16 require cudnn >= 8.1.0)
   const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && IsVoltaOrLater(ctx);
+#endif
   // We will only do data format conversion from NHWC to NCHW.
   // cudnn will convert NCHW to NHWC automatically on Tensor Core.
   auto compute_format = compute_in_nhwc && channel_last