From 2f2bf4e8872356c772baf665dc933114cbaced6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Mon, 6 Mar 2023 17:02:29 +0800 Subject: [PATCH] [AMP OP&Test] add bf16 fp16 type support for interpolate (#51153) * add bf16 fp16 type support for interpolate * add bf16 fp16 support for interpolate in phi on cpu --- .../kernels/cpu/interpolate_grad_kernel.cc | 66 +++--- paddle/phi/kernels/cpu/interpolate_kernel.cc | 194 +++++++++++------- .../kernels/gpu/interpolate_grad_kernel.cu | 68 +++--- paddle/phi/kernels/gpu/interpolate_kernel.cu | 61 +++--- .../phi/kernels/onednn/interpolate_kernel.cc | 4 +- 5 files changed, 232 insertions(+), 161 deletions(-) diff --git a/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc b/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc index 7ad69fbe159..c1e7b0129e5 100644 --- a/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/interpolate_grad_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/interpolate_function.h" @@ -36,6 +37,7 @@ static void LinearInterpolationGrad(const DenseTensor& output_grad, auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); bool align_flag = (align_mode == 0 && !align_corners); + using MT = typename phi::dtype::MPTypeTrait::Type; for (int l = 0; l < out_w; l++) { int x_w = align_flag ? static_cast(ratio_w * (l + 0.5) - 0.5) : static_cast(ratio_w * l); @@ -51,11 +53,11 @@ static void LinearInterpolationGrad(const DenseTensor& output_grad, for (int j = 0; j < c; j++) { // loop for channels // linear interpolation grad if (data_layout == DataLayout::kNCHW) { - const T grad = output_grad_t(i, j, l); + const MT grad = static_cast(output_grad_t(i, j, l)); input_grad_t(i, j, x_w) += static_cast(grad * d_e); input_grad_t(i, j, x_e) += static_cast(grad * d_w); } else { - const T grad = output_grad_t(i, l, j); + const MT grad = static_cast(output_grad_t(i, l, j)); input_grad_t(i, x_w, j) += static_cast(grad * d_e); input_grad_t(i, x_e, j) += static_cast(grad * d_w); } @@ -81,6 +83,9 @@ static void BilinearInterpolationGrad(const DenseTensor& output_grad, auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); bool align_flag = (align_mode == 0 && !align_corners); + + using MT = typename phi::dtype::MPTypeTrait::Type; + for (int k = 0; k < out_h; k++) { // loop for images int y_n = align_flag ? static_cast(ratio_h * (k + 0.5) - 0.5) : static_cast(ratio_h * k); @@ -105,13 +110,14 @@ static void BilinearInterpolationGrad(const DenseTensor& output_grad, for (int j = 0; j < c; j++) { // loop for channels // bilinear interpolation grad if (data_layout == DataLayout::kNCHW) { - const T grad = output_grad_t(i, j, k, l); + // const T grad = output_grad_t(i, j, k, l); + const MT grad = static_cast(output_grad_t(i, j, k, l)); input_grad_t(i, j, y_n, x_w) += static_cast(grad * d_s * d_e); input_grad_t(i, j, y_s, x_w) += static_cast(grad * d_n * d_e); input_grad_t(i, j, y_n, x_e) += static_cast(grad * d_s * d_w); input_grad_t(i, j, y_s, x_e) += static_cast(grad * d_n * d_w); } else { - const T grad = output_grad_t(i, k, l, j); + const MT grad = static_cast(output_grad_t(i, k, l, j)); input_grad_t(i, y_n, x_w, j) += static_cast(grad * d_s * d_e); input_grad_t(i, y_s, x_w, j) += static_cast(grad * d_n * d_e); input_grad_t(i, y_n, x_e, j) += static_cast(grad * d_s * d_w); @@ -173,24 +179,23 @@ static void BicubicInterpolationGrad(const DenseTensor& output_grad, const DataLayout data_layout) { auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); + using MT = typename phi::dtype::MPTypeTrait::Type; for (int k = 0; k < out_h; k++) { // loop for images - T y_n = align_corners ? static_cast(ratio_h * k) - : static_cast(ratio_h * (k + 0.5) - 0.5); + MT y_n = align_corners ? ratio_h * k : ratio_h * (k + 0.5) - 0.5; int input_y = floorf(y_n); - T y_t = y_n - input_y; + MT y_t = y_n - input_y; for (int l = 0; l < out_w; l++) { - T x_n = align_corners ? static_cast(ratio_w * l) - : static_cast(ratio_w * (l + 0.5) - 0.5); + MT x_n = align_corners ? ratio_w * l : ratio_w * (l + 0.5) - 0.5; int input_x = floorf(x_n); - T x_t = x_n - input_x; + MT x_t = x_n - input_x; - T x_coeffs[4]; - T y_coeffs[4]; + MT x_coeffs[4]; + MT y_coeffs[4]; - funcs::get_cubic_upsample_coefficients(x_coeffs, x_t); - funcs::get_cubic_upsample_coefficients(y_coeffs, y_t); + funcs::get_cubic_upsample_coefficients(x_coeffs, x_t); + funcs::get_cubic_upsample_coefficients(y_coeffs, y_t); for (int i = 0; i < n; i++) { // loop for batches for (int j = 0; j < c; j++) { // loop for channels @@ -202,13 +207,13 @@ static void BicubicInterpolationGrad(const DenseTensor& output_grad, int access_y = std::max(std::min(input_y - 1 + jj, in_h - 1), static_cast(0)); if (data_layout == DataLayout::kNCHW) { - T grad = output_grad_t(i, j, k, l); + MT grad = static_cast(output_grad_t(i, j, k, l)); input_grad_t(i, j, access_y, access_x) += - grad * y_coeffs[jj] * x_coeffs[ii]; + static_cast(grad * y_coeffs[jj] * x_coeffs[ii]); } else { - T grad = output_grad_t(i, k, l, j); + MT grad = static_cast(output_grad_t(i, k, l, j)); input_grad_t(i, access_y, access_x, j) += - grad * y_coeffs[jj] * x_coeffs[ii]; + static_cast(grad * y_coeffs[jj] * x_coeffs[ii]); } } } @@ -238,6 +243,7 @@ static void TrilinearInterpolationGrad(const DenseTensor& output_grad, auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); bool align_flag = (align_mode == 0 && !align_corners); + using MT = typename phi::dtype::MPTypeTrait::Type; for (int j = 0; j < out_d; j++) { // loop for D int t_f = align_flag ? static_cast(ratio_d * (j + 0.5) - 0.5) : static_cast(ratio_d * j); @@ -272,7 +278,7 @@ static void TrilinearInterpolationGrad(const DenseTensor& output_grad, for (int i = 0; i < c; i++) { // loop for channels // trilinear interpolation grad if (data_layout == DataLayout::kNCHW) { - const T grad = output_grad_t(b, i, j, k, l); + const MT grad = static_cast(output_grad_t(b, i, j, k, l)); input_grad_t(b, i, t_f, y_n, x_w) += static_cast(grad * d_b * d_s * d_e); input_grad_t(b, i, t_f, y_n, x_e) += @@ -290,7 +296,7 @@ static void TrilinearInterpolationGrad(const DenseTensor& output_grad, input_grad_t(b, i, t_b, y_s, x_e) += static_cast(grad * d_f * d_n * d_w); } else { - const T grad = output_grad_t(b, j, k, l, i); + const MT grad = static_cast(output_grad_t(b, j, k, l, i)); input_grad_t(b, t_f, y_n, x_w, i) += static_cast(grad * d_b * d_s * d_e); input_grad_t(b, t_f, y_n, x_e, i) += @@ -1038,7 +1044,9 @@ PD_REGISTER_KERNEL(bilinear_interp_grad, ALL_LAYOUT, phi::BilinearInterpGradKernel, float, - double) { + double, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1047,7 +1055,9 @@ PD_REGISTER_KERNEL(nearest_interp_grad, ALL_LAYOUT, phi::NearestInterpGradKernel, float, - double) { + double, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1056,7 +1066,9 @@ PD_REGISTER_KERNEL(trilinear_interp_grad, ALL_LAYOUT, phi::TrilinearInterpGradKernel, float, - double) { + double, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1065,7 +1077,9 @@ PD_REGISTER_KERNEL(linear_interp_grad, ALL_LAYOUT, phi::LinearInterpGradKernel, float, - double) { + double, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1074,7 +1088,9 @@ PD_REGISTER_KERNEL(bicubic_interp_grad, ALL_LAYOUT, phi::BicubicInterpGradKernel, float, - double) { + double, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } diff --git a/paddle/phi/kernels/cpu/interpolate_kernel.cc b/paddle/phi/kernels/cpu/interpolate_kernel.cc index 1cdde3a7b1e..35f4ae31cfe 100644 --- a/paddle/phi/kernels/cpu/interpolate_kernel.cc +++ b/paddle/phi/kernels/cpu/interpolate_kernel.cc @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/interpolate_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/interpolate_function.h" @@ -43,6 +44,7 @@ static void LinearInterpolation(const DenseTensor& input, auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); bool align_flag = (align_mode == 0 && !align_corners); + using MT = typename phi::dtype::MPTypeTrait::Type; std::vector vx_w, vx_e; std::vector vd_w, vd_e; @@ -80,12 +82,14 @@ static void LinearInterpolation(const DenseTensor& input, // linear interpolation T out_t; if (data_layout == DataLayout::kNCHW) { - out_t = input_t(i, j, vx_w[l]) * vd_e[l] + - input_t(i, j, vx_e[l]) * vd_w[l]; + out_t = + static_cast(static_cast(input_t(i, j, vx_w[l])) * vd_e[l] + + static_cast(input_t(i, j, vx_e[l])) * vd_w[l]); output_t(i, j, l) = out_t; } else { - out_t = input_t(i, vx_w[l], j) * vd_e[l] + - input_t(i, vx_e[l], j) * vd_w[l]; + out_t = + static_cast(static_cast(input_t(i, vx_w[l], j)) * vd_e[l] + + static_cast(input_t(i, vx_e[l], j)) * vd_w[l]); output_t(i, l, j) = out_t; } } @@ -110,6 +114,7 @@ static void BilinearInterpolation(const DenseTensor& input, auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); bool align_flag = (align_mode == 0 && !align_corners); + using MT = typename phi::dtype::MPTypeTrait::Type; std::vector vy_n, vy_s; std::vector vd_n, vd_s; @@ -174,17 +179,27 @@ static void BilinearInterpolation(const DenseTensor& input, // bilinear interpolation T out_t; if (data_layout == DataLayout::kNCHW) { - out_t = input_t(i, j, vy_n[k], vx_w[l]) * vd_s[k] * vd_e[l] + - input_t(i, j, vy_s[k], vx_w[l]) * vd_n[k] * vd_e[l] + - input_t(i, j, vy_n[k], vx_e[l]) * vd_s[k] * vd_w[l] + - input_t(i, j, vy_s[k], vx_e[l]) * vd_n[k] * vd_w[l]; + out_t = static_cast( + static_cast(input_t(i, j, vy_n[k], vx_w[l])) * vd_s[k] * + vd_e[l] + + static_cast(input_t(i, j, vy_s[k], vx_w[l])) * vd_n[k] * + vd_e[l] + + static_cast(input_t(i, j, vy_n[k], vx_e[l])) * vd_s[k] * + vd_w[l] + + static_cast(input_t(i, j, vy_s[k], vx_e[l])) * vd_n[k] * + vd_w[l]); output_t(i, j, k, l) = out_t; } else { - out_t = input_t(i, vy_n[k], vx_w[l], j) * vd_s[k] * vd_e[l] + - input_t(i, vy_s[k], vx_w[l], j) * vd_n[k] * vd_e[l] + - input_t(i, vy_n[k], vx_e[l], j) * vd_s[k] * vd_w[l] + - input_t(i, vy_s[k], vx_e[l], j) * vd_n[k] * vd_w[l]; + out_t = static_cast( + static_cast(input_t(i, vy_n[k], vx_w[l], j)) * vd_s[k] * + vd_e[l] + + static_cast(input_t(i, vy_s[k], vx_w[l], j)) * vd_n[k] * + vd_e[l] + + static_cast(input_t(i, vy_n[k], vx_e[l], j)) * vd_s[k] * + vd_w[l] + + static_cast(input_t(i, vy_s[k], vx_e[l], j)) * vd_n[k] * + vd_w[l]); output_t(i, k, l, j) = out_t; } } @@ -206,6 +221,7 @@ static void NearestNeighborInterpolate(const DenseTensor& input, const DataLayout& data_layout) { auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); + for (int k = 0; k < out_h; k++) { // loop for images int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) : static_cast(ratio_h * k); @@ -242,22 +258,23 @@ static void BicubicInterpolation(const DenseTensor& input, const DataLayout data_layout) { auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); + using MT = typename phi::dtype::MPTypeTrait::Type; for (int k = 0; k < out_h; k++) { // loop for images - T y_n = align_corners ? static_cast(ratio_h * k) - : static_cast(ratio_h * (k + 0.5) - 0.5); + MT y_n = align_corners ? static_cast(ratio_h * k) + : static_cast(ratio_h * (k + 0.5) - 0.5); int input_y = floorf(y_n); - const T y_t = y_n - input_y; + const MT y_t = y_n - input_y; for (int l = 0; l < out_w; l++) { - T x_n = align_corners ? static_cast(ratio_w * l) - : static_cast(ratio_w * (l + 0.5) - 0.5); + MT x_n = align_corners ? static_cast(ratio_w * l) + : static_cast(ratio_w * (l + 0.5) - 0.5); int input_x = floorf(x_n); - const T x_t = x_n - input_x; + const MT x_t = x_n - input_x; for (int i = 0; i < n; i++) { // loop for batches for (int j = 0; j < c; j++) { // loop for channels - T coefficients[4]; + MT coefficients[4]; // interp 4 times in x direction for (int ii = 0; ii < 4; ii++) { int access_y = std::max(std::min(input_y - 1 + ii, in_h - 1), @@ -271,35 +288,37 @@ static void BicubicInterpolation(const DenseTensor& input, int access_x_3 = std::max(std::min(input_x + 2, in_w - 1), static_cast(0)); if (data_layout == DataLayout::kNCHW) { - coefficients[ii] = - cubic_interp(input_t(i, j, access_y, access_x_0), - input_t(i, j, access_y, access_x_1), - input_t(i, j, access_y, access_x_2), - input_t(i, j, access_y, access_x_3), - x_t); + coefficients[ii] = cubic_interp( + static_cast(input_t(i, j, access_y, access_x_0)), + static_cast(input_t(i, j, access_y, access_x_1)), + static_cast(input_t(i, j, access_y, access_x_2)), + static_cast(input_t(i, j, access_y, access_x_3)), + x_t); } else { - coefficients[ii] = - cubic_interp(input_t(i, access_y, access_x_0, j), - input_t(i, access_y, access_x_1, j), - input_t(i, access_y, access_x_2, j), - input_t(i, access_y, access_x_3, j), - x_t); + coefficients[ii] = cubic_interp( + static_cast(input_t(i, access_y, access_x_0, j)), + static_cast(input_t(i, access_y, access_x_1, j)), + static_cast(input_t(i, access_y, access_x_2, j)), + static_cast(input_t(i, access_y, access_x_3, j)), + x_t); } } // interp y direction if (data_layout == DataLayout::kNCHW) { - output_t(i, j, k, l) = cubic_interp(coefficients[0], - coefficients[1], - coefficients[2], - coefficients[3], - y_t); + output_t(i, j, k, l) = + static_cast(cubic_interp(coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + y_t)); } else { - output_t(i, k, l, j) = cubic_interp(coefficients[0], - coefficients[1], - coefficients[2], - coefficients[3], - y_t); + output_t(i, k, l, j) = + static_cast(cubic_interp(coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + y_t)); } } } @@ -327,6 +346,7 @@ static void TrilinearInterpolation(const DenseTensor& input, auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); bool align_flag = (align_mode == 0 && !align_corners); + using MT = typename phi::dtype::MPTypeTrait::Type; std::vector vt_f, vt_b; std::vector vd_f, vd_b; @@ -417,40 +437,42 @@ static void TrilinearInterpolation(const DenseTensor& input, for (int l = 0; l < out_w; l++) { // trilinear interpolation if (data_layout == DataLayout::kNCHW) { - T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] * - vd_s[k] * vd_e[l] + - input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] * - vd_s[k] * vd_w[l] + - input_t(b, i, vt_f[j], vy_s[k], vx_w[l]) * vd_b[j] * - vd_n[k] * vd_e[l] + - input_t(b, i, vt_f[j], vy_s[k], vx_e[l]) * vd_b[j] * - vd_n[k] * vd_w[l] + - input_t(b, i, vt_b[j], vy_n[k], vx_w[l]) * vd_f[j] * - vd_s[k] * vd_e[l] + - input_t(b, i, vt_b[j], vy_n[k], vx_e[l]) * vd_f[j] * - vd_s[k] * vd_w[l] + - input_t(b, i, vt_b[j], vy_s[k], vx_w[l]) * vd_f[j] * - vd_n[k] * vd_e[l] + - input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] * - vd_n[k] * vd_w[l]; + T out_t = static_cast( + static_cast(input_t(b, i, vt_f[j], vy_n[k], vx_w[l])) * + vd_b[j] * vd_s[k] * vd_e[l] + + static_cast(input_t(b, i, vt_f[j], vy_n[k], vx_e[l])) * + vd_b[j] * vd_s[k] * vd_w[l] + + static_cast(input_t(b, i, vt_f[j], vy_s[k], vx_w[l])) * + vd_b[j] * vd_n[k] * vd_e[l] + + static_cast(input_t(b, i, vt_f[j], vy_s[k], vx_e[l])) * + vd_b[j] * vd_n[k] * vd_w[l] + + static_cast(input_t(b, i, vt_b[j], vy_n[k], vx_w[l])) * + vd_f[j] * vd_s[k] * vd_e[l] + + static_cast(input_t(b, i, vt_b[j], vy_n[k], vx_e[l])) * + vd_f[j] * vd_s[k] * vd_w[l] + + static_cast(input_t(b, i, vt_b[j], vy_s[k], vx_w[l])) * + vd_f[j] * vd_n[k] * vd_e[l] + + static_cast(input_t(b, i, vt_b[j], vy_s[k], vx_e[l])) * + vd_f[j] * vd_n[k] * vd_w[l]); output_t(b, i, j, k, l) = out_t; } else { - T out_t = input_t(b, vt_f[j], vy_n[k], vx_w[l], i) * vd_b[j] * - vd_s[k] * vd_e[l] + - input_t(b, vt_f[j], vy_n[k], vx_e[l], i) * vd_b[j] * - vd_s[k] * vd_w[l] + - input_t(b, vt_f[j], vy_s[k], vx_w[l], i) * vd_b[j] * - vd_n[k] * vd_e[l] + - input_t(b, vt_f[j], vy_s[k], vx_e[l], i) * vd_b[j] * - vd_n[k] * vd_w[l] + - input_t(b, vt_b[j], vy_n[k], vx_w[l], i) * vd_f[j] * - vd_s[k] * vd_e[l] + - input_t(b, vt_b[j], vy_n[k], vx_e[l], i) * vd_f[j] * - vd_s[k] * vd_w[l] + - input_t(b, vt_b[j], vy_s[k], vx_w[l], i) * vd_f[j] * - vd_n[k] * vd_e[l] + - input_t(b, vt_b[j], vy_s[k], vx_e[l], i) * vd_f[j] * - vd_n[k] * vd_w[l]; + T out_t = static_cast( + static_cast(input_t(b, vt_f[j], vy_n[k], vx_w[l], i)) * + vd_b[j] * vd_s[k] * vd_e[l] + + static_cast(input_t(b, vt_f[j], vy_n[k], vx_e[l], i)) * + vd_b[j] * vd_s[k] * vd_w[l] + + static_cast(input_t(b, vt_f[j], vy_s[k], vx_w[l], i)) * + vd_b[j] * vd_n[k] * vd_e[l] + + static_cast(input_t(b, vt_f[j], vy_s[k], vx_e[l], i)) * + vd_b[j] * vd_n[k] * vd_w[l] + + static_cast(input_t(b, vt_b[j], vy_n[k], vx_w[l], i)) * + vd_f[j] * vd_s[k] * vd_e[l] + + static_cast(input_t(b, vt_b[j], vy_n[k], vx_e[l], i)) * + vd_f[j] * vd_s[k] * vd_w[l] + + static_cast(input_t(b, vt_b[j], vy_s[k], vx_w[l], i)) * + vd_f[j] * vd_n[k] * vd_e[l] + + static_cast(input_t(b, vt_b[j], vy_s[k], vx_e[l], i)) * + vd_f[j] * vd_n[k] * vd_w[l]); output_t(b, j, k, l, i) = out_t; } } @@ -1190,7 +1212,9 @@ PD_REGISTER_KERNEL(bilinear_interp, phi::BilinearInterpKernel, float, double, - uint8_t) { + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1202,7 +1226,9 @@ PD_REGISTER_KERNEL(nearest_interp, double, int, int64_t, - uint8_t) { + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1212,7 +1238,9 @@ PD_REGISTER_KERNEL(trilinear_interp, phi::TrilinearInterpKernel, float, double, - uint8_t) { + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1222,12 +1250,20 @@ PD_REGISTER_KERNEL(linear_interp, phi::LinearInterpKernel, float, double, - uint8_t) { + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } -PD_REGISTER_KERNEL( - bicubic_interp, CPU, ALL_LAYOUT, phi::BicubicInterpKernel, float, double) { +PD_REGISTER_KERNEL(bicubic_interp, + CPU, + ALL_LAYOUT, + phi::BicubicInterpKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } diff --git a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu index 9b335786b26..115f22ab2a5 100644 --- a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu @@ -60,6 +60,7 @@ __global__ void KeLinearInterpBw(T* in, int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; bool align_flag = (align_mode == 0 && !align_corners); + using MT = typename phi::dtype::MPTypeTrait::Type; for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; @@ -79,13 +80,12 @@ __global__ void KeLinearInterpBw(T* in, : ratio_w * out_img_idx; in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id - using MT = typename phi::dtype::MPTypeTrait::Type; - T src_w = static_cast(ratio_w * (out_img_idx + 0.5) - 0.5); - src_w = (src_w > static_cast(0)) ? src_w : static_cast(0); - T w1lambda = align_flag - ? static_cast(static_cast(src_w) - in_img_idx) - : static_cast(ratio_w * out_img_idx - in_img_idx); - T w2lambda = static_cast(1.0) - w1lambda; + + MT src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w : 0; + MT w1lambda = + align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; + MT w2lambda = 1.0 - w1lambda; T* in_pos; if (data_layout == DataLayout::kNCHW) { @@ -96,11 +96,17 @@ __global__ void KeLinearInterpBw(T* in, const T* out_pos = &out[out_id_w]; if (data_layout == DataLayout::kNCHW) { - phi::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]); - phi::CudaAtomicAdd(&in_pos[w_id], w1lambda * out_pos[0]); + phi::CudaAtomicAdd( + &in_pos[0], static_cast(w2lambda * static_cast(out_pos[0]))); + phi::CudaAtomicAdd( + &in_pos[w_id], + static_cast(w1lambda * static_cast(out_pos[0]))); } else { - phi::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]); - phi::CudaAtomicAdd(&in_pos[w_id * num_channels], w1lambda * out_pos[0]); + phi::CudaAtomicAdd( + &in_pos[0], static_cast(w2lambda * static_cast(out_pos[0]))); + phi::CudaAtomicAdd( + &in_pos[w_id * num_channels], + static_cast(w1lambda * static_cast(out_pos[0]))); } } } @@ -469,6 +475,7 @@ __global__ void KeBicubicInterpBw(T* in, int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; + using MT = typename phi::dtype::MPTypeTrait::Type; for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; @@ -487,23 +494,21 @@ __global__ void KeBicubicInterpBw(T* in, channel_id = tid % num_channels; } - T in_img_idy = align_corners - ? static_cast(ratio_h * out_img_idy) - : static_cast(ratio_h * (out_img_idy + 0.5) - 0.5); + MT in_img_idy = align_corners ? ratio_h * out_img_idy + : ratio_h * (out_img_idy + 0.5) - 0.5; int input_y = floorf(static_cast(in_img_idy)); - using MT = typename phi::dtype::MPTypeTrait::Type; - const T y_t = static_cast(static_cast(in_img_idy) - input_y); - T in_img_idx = align_corners - ? static_cast(ratio_w * out_img_idx) - : static_cast(ratio_w * (out_img_idx + 0.5) - 0.5); + + const MT y_t = in_img_idy - input_y; + MT in_img_idx = align_corners ? ratio_w * out_img_idx + : ratio_w * (out_img_idx + 0.5) - 0.5; int input_x = floorf(static_cast(in_img_idx)); - const T x_t = static_cast(static_cast(in_img_idx) - input_x); + const MT x_t = in_img_idx - input_x; - T x_coeffs[4]; - T y_coeffs[4]; + MT x_coeffs[4]; + MT y_coeffs[4]; - funcs::get_cubic_upsample_coefficients(x_coeffs, x_t); - funcs::get_cubic_upsample_coefficients(y_coeffs, y_t); + funcs::get_cubic_upsample_coefficients(x_coeffs, x_t); + funcs::get_cubic_upsample_coefficients(y_coeffs, y_t); const T* out_pos = &out[out_id_h * output_w + out_id_w]; T* in_pos; @@ -524,7 +529,8 @@ __global__ void KeBicubicInterpBw(T* in, access_x * num_channels + channel_id]; } phi::CudaAtomicAdd(&in_pos[0], - (out_pos[0] * y_coeffs[j] * x_coeffs[i])); + static_cast(static_cast(out_pos[0]) * + y_coeffs[j] * x_coeffs[i])); } } } @@ -1568,7 +1574,8 @@ PD_REGISTER_KERNEL(bilinear_interp_grad, phi::BilinearInterpGradKernel, float, double, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1589,7 +1596,8 @@ PD_REGISTER_KERNEL(trilinear_interp_grad, phi::TrilinearInterpGradKernel, float, double, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1599,7 +1607,8 @@ PD_REGISTER_KERNEL(linear_interp_grad, phi::LinearInterpGradKernel, float, double, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } @@ -1609,7 +1618,8 @@ PD_REGISTER_KERNEL(bicubic_interp_grad, phi::BicubicInterpGradKernel, float, double, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); } diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu index 9aa5d55201c..39274446cb9 100644 --- a/paddle/phi/kernels/gpu/interpolate_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu @@ -82,26 +82,27 @@ __global__ void KeLinearInterpFw(const T* in, in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id using MT = typename phi::dtype::MPTypeTrait::Type; - T src_w = static_cast(ratio_w * (out_img_idx + 0.5) - 0.5); - src_w = (src_w > static_cast(0)) ? src_w : static_cast(0); - T w1lambda = align_flag - ? static_cast(static_cast(src_w) - in_img_idx) - : static_cast(ratio_w * out_img_idx - in_img_idx); - T w2lambda = static_cast(1.0) - w1lambda; + MT src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w : 0; + MT w1lambda = align_flag ? (src_w - in_img_idx) + : (ratio_w * out_img_idx - in_img_idx); + MT w2lambda = 1.0 - w1lambda; if (data_layout == DataLayout::kNCHW) { const T* in_pos = &in[out_id_h * out_id_w + channel_id * in_img_size + in_img_idx]; // linear interpolation out[out_id_h * output_w + out_id_w] = - w2lambda * in_pos[0] + w1lambda * in_pos[w_id]; + static_cast(w2lambda * static_cast(in_pos[0]) + + w1lambda * static_cast(in_pos[w_id])); } else { const T* in_pos = &in[out_id_h * input_w + in_img_idx * num_channels + channel_id]; // linear interpolation - out[out_id_h * output_w + out_id_w] = - w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]; + out[out_id_h * output_w + out_id_w] = static_cast( + w2lambda * static_cast(in_pos[0]) + + w1lambda * static_cast(in_pos[w_id * num_channels])); } } } @@ -308,15 +309,18 @@ __global__ void KeBilinearInterpNCHWFw(const T* in, template __device__ __forceinline__ static T Kecubic_interp( const T x0, const T x1, const T x2, const T x3, T t) { - T coeffs[4]; - T a = static_cast(-0.75); - T x_1 = t; - T x_2 = static_cast(1.0) - t; - coeffs[0] = funcs::CubicConvolution2(x_1 + static_cast(1.0), a); - coeffs[1] = funcs::CubicConvolution1(x_1, a); - coeffs[2] = funcs::CubicConvolution1(x_2, a); - coeffs[3] = funcs::CubicConvolution2(x_2 + static_cast(1.0), a); - return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; + using MT = typename phi::dtype::MPTypeTrait::Type; + MT coeffs[4]; + MT a = static_cast(-0.75); + MT x_1 = static_cast(t); + MT x_2 = static_cast(1.0) - static_cast(t); + coeffs[0] = funcs::CubicConvolution2(x_1 + static_cast(1.0), a); + coeffs[1] = funcs::CubicConvolution1(x_1, a); + coeffs[2] = funcs::CubicConvolution1(x_2, a); + coeffs[3] = funcs::CubicConvolution2(x_2 + static_cast(1.0), a); + return static_cast( + static_cast(x0) * coeffs[0] + static_cast(x1) * coeffs[1] + + static_cast(x2) * coeffs[2] + static_cast(x3) * coeffs[3]); } template @@ -338,6 +342,7 @@ __global__ void KeBicubicInterpFw(const T* in, int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; + using MT = typename phi::dtype::MPTypeTrait::Type; for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; @@ -357,18 +362,16 @@ __global__ void KeBicubicInterpFw(const T* in, channel_id = tid % num_channels; } - T in_img_idy = align_corners - ? static_cast(ratio_h * out_img_idy) - : static_cast(ratio_h * (out_img_idy + 0.5) - 0.5); + MT in_img_idy = align_corners ? ratio_h * out_img_idy + : ratio_h * (out_img_idy + 0.5) - 0.5; int input_y = floorf(static_cast(in_img_idy)); - using MT = typename phi::dtype::MPTypeTrait::Type; - const T y_t = static_cast(static_cast(in_img_idy) - input_y); - T in_img_idx = align_corners - ? static_cast(ratio_w * out_img_idx) - : static_cast(ratio_w * (out_img_idx + 0.5) - 0.5); + const T y_t = static_cast(in_img_idy - input_y); + + MT in_img_idx = align_corners ? ratio_w * out_img_idx + : ratio_w * (out_img_idx + 0.5) - 0.5; int input_x = floorf(static_cast(in_img_idx)); - const T x_t = static_cast(static_cast(in_img_idx) - input_x); + const T x_t = static_cast(in_img_idx - input_x); T coefficients[4]; const T* in_pos_0; @@ -1460,6 +1463,7 @@ PD_REGISTER_KERNEL(bilinear_interp, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int) { kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); @@ -1486,6 +1490,7 @@ PD_REGISTER_KERNEL(trilinear_interp, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int) { kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); @@ -1498,6 +1503,7 @@ PD_REGISTER_KERNEL(linear_interp, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int) { kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); @@ -1510,6 +1516,7 @@ PD_REGISTER_KERNEL(bicubic_interp, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int) { kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); diff --git a/paddle/phi/kernels/onednn/interpolate_kernel.cc b/paddle/phi/kernels/onednn/interpolate_kernel.cc index 7f6ded1958f..3f2fa5bf382 100644 --- a/paddle/phi/kernels/onednn/interpolate_kernel.cc +++ b/paddle/phi/kernels/onednn/interpolate_kernel.cc @@ -232,7 +232,8 @@ PD_REGISTER_KERNEL(bilinear_interp, ONEDNN, phi::BilinearInterpKernel, float, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::float16) {} PD_REGISTER_KERNEL(nearest_interp, OneDNN, @@ -240,5 +241,6 @@ PD_REGISTER_KERNEL(nearest_interp, phi::NearestInterpKernel, float, phi::dtype::bfloat16, + phi::dtype::float16, int8_t, uint8_t) {} -- GitLab