未验证 提交 b12c27eb 编写于 作者: Y Yuanle Liu 提交者: GitHub

interpolate (forward grad) op support fp16 on gpu (#45061)

上级 cbf26bb1
...@@ -28,26 +28,28 @@ namespace funcs { ...@@ -28,26 +28,28 @@ namespace funcs {
template <typename T> template <typename T>
HOSTDEVICE inline T CubicConvolution1(T x, T A) { HOSTDEVICE inline T CubicConvolution1(T x, T A) {
return ((A + 2) * x - (A + 3)) * x * x + 1; return ((A + static_cast<T>(2)) * x - (A + static_cast<T>(3))) * x * x +
static_cast<T>(1);
} }
template <typename T> template <typename T>
HOSTDEVICE inline T CubicConvolution2(T x, T A) { HOSTDEVICE inline T CubicConvolution2(T x, T A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; return ((A * x - static_cast<T>(5) * A) * x + static_cast<T>(8) * A) * x -
static_cast<T>(4) * A;
} }
template <typename T> template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) { HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
T A = -0.75; T A = static_cast<T>(-0.75);
T x1 = t; T x1 = t;
coeffs[0] = CubicConvolution2<T>(x1 + 1.0, A); coeffs[0] = CubicConvolution2<T>(x1 + static_cast<T>(1.0), A);
coeffs[1] = CubicConvolution1<T>(x1, A); coeffs[1] = CubicConvolution1<T>(x1, A);
// opposite coefficients // opposite coefficients
T x2 = 1.0 - t; T x2 = static_cast<T>(1.0) - t;
coeffs[2] = CubicConvolution1<T>(x2, A); coeffs[2] = CubicConvolution1<T>(x2, A);
coeffs[3] = CubicConvolution2<T>(x2 + 1.0, A); coeffs[3] = CubicConvolution2<T>(x2 + static_cast<T>(1.0), A);
} }
inline void ExtractNCDWH(const DDim& dims, inline void ExtractNCDWH(const DDim& dims,
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h" #include "paddle/phi/kernels/funcs/interpolate_function.h"
...@@ -34,11 +35,12 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex( ...@@ -34,11 +35,12 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
T* lambda2, T* lambda2,
T src_x, T src_x,
const int in_img_x) { const int in_img_x) {
src_x = (src_x > 0) ? src_x : 0.f; src_x = (src_x > static_cast<T>(0)) ? src_x : static_cast<T>(0);
*in_img_idx = static_cast<int>(src_x); *in_img_idx = static_cast<int>(src_x);
*x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0; *x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
*lambda1 = src_x - *in_img_idx; using MT = typename phi::dtype::MPTypeTrait<T>::Type;
*lambda2 = 1.f - *lambda1; *lambda1 = static_cast<T>(static_cast<MT>(src_x) - *in_img_idx);
*lambda2 = static_cast<T>(1.0) - *lambda1;
} }
template <typename T> template <typename T>
...@@ -50,7 +52,7 @@ __global__ void KeLinearInterpBw(T* in, ...@@ -50,7 +52,7 @@ __global__ void KeLinearInterpBw(T* in,
const size_t output_h, const size_t output_h,
const size_t output_w, const size_t output_w,
const size_t num_channels, const size_t num_channels,
const T ratio_w, const float ratio_w,
const bool align_corners, const bool align_corners,
const int align_mode, const int align_mode,
const DataLayout data_layout) { const DataLayout data_layout) {
...@@ -77,12 +79,13 @@ __global__ void KeLinearInterpBw(T* in, ...@@ -77,12 +79,13 @@ __global__ void KeLinearInterpBw(T* in,
: ratio_w * out_img_idx; : ratio_w * out_img_idx;
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; T src_w = static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
src_w = (src_w > 0) ? src_w : 0; src_w = (src_w > static_cast<T>(0)) ? src_w : static_cast<T>(0);
T w1lambda = T w1lambda = align_flag
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; ? static_cast<T>(static_cast<MT>(src_w) - in_img_idx)
T w2lambda = 1.f - w1lambda; : static_cast<T>(ratio_w * out_img_idx - in_img_idx);
T w2lambda = static_cast<T>(1.0) - w1lambda;
T* in_pos; T* in_pos;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
...@@ -245,7 +248,7 @@ __global__ void KeBilinearInterpBwShareMemory(T* in, ...@@ -245,7 +248,7 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
const int num_channels, const int num_channels,
float ratio_h, float ratio_h,
float ratio_w, float ratio_w,
const T align_type_value, const float align_type_value,
bool is_nchw) { bool is_nchw) {
__shared__ T s_data[2][1024]; __shared__ T s_data[2][1024];
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -267,8 +270,10 @@ __global__ void KeBilinearInterpBwShareMemory(T* in, ...@@ -267,8 +270,10 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
int in_img_idx, in_img_idy, w_id, h_id; int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda; T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; align_type_value);
T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
PreCalculatorForLinearInterpInputIndex( PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w); &in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w);
...@@ -283,8 +288,8 @@ __global__ void KeBilinearInterpBwShareMemory(T* in, ...@@ -283,8 +288,8 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
int bot_right_index = input_index + h_id * in_w + w_id; int bot_right_index = input_index + h_id * in_w + w_id;
int in_top_min_index, in_bot_min_index; int in_top_min_index, in_bot_min_index;
s_data[0][threadIdx.x] = 0.f; s_data[0][threadIdx.x] = static_cast<T>(0);
s_data[1][threadIdx.x] = 0.f; s_data[1][threadIdx.x] = static_cast<T>(0);
int remain = nthreads - (tid & (-blockDim.x)); int remain = nthreads - (tid & (-blockDim.x));
int in_top_max_index = int in_top_max_index =
phi::funcs::blockReduceMax(top_right_index, FINAL_MASK); phi::funcs::blockReduceMax(top_right_index, FINAL_MASK);
...@@ -353,7 +358,7 @@ __global__ void KeBilinearInterpNCHWBw(T* in, ...@@ -353,7 +358,7 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
float ratio_h, float ratio_h,
float ratio_w, float ratio_w,
const T* __restrict__ out, const T* __restrict__ out,
const T align_type_value) { const float align_type_value) {
int index = threadIdx.x + blockDim.x * blockIdx.x; int index = threadIdx.x + blockDim.x * blockIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
int num_out = n * num_channels * out_h * out_w; int num_out = n * num_channels * out_h * out_w;
...@@ -368,13 +373,15 @@ __global__ void KeBilinearInterpNCHWBw(T* in, ...@@ -368,13 +373,15 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
int h1, y_id; int h1, y_id;
T h1lambda, h0lambda; T h1lambda, h0lambda;
T src_y = ratio_h * (h2 + align_type_value) - align_type_value; T src_y =
static_cast<T>(ratio_h * (h2 + align_type_value) - align_type_value);
PreCalculatorForLinearInterpInputIndex( PreCalculatorForLinearInterpInputIndex(
&h1, &y_id, &h1lambda, &h0lambda, src_y, in_h); &h1, &y_id, &h1lambda, &h0lambda, src_y, in_h);
int w1, x_id; int w1, x_id;
T w1lambda, w0lambda; T w1lambda, w0lambda;
T src_x = ratio_w * (w2 + align_type_value) - align_type_value; T src_x =
static_cast<T>(ratio_w * (w2 + align_type_value) - align_type_value);
PreCalculatorForLinearInterpInputIndex( PreCalculatorForLinearInterpInputIndex(
&w1, &x_id, &w1lambda, &w0lambda, src_x, in_w); &w1, &x_id, &w1lambda, &w0lambda, src_x, in_w);
...@@ -406,7 +413,7 @@ __global__ void KeBilinearInterpBw(T* in, ...@@ -406,7 +413,7 @@ __global__ void KeBilinearInterpBw(T* in,
const int num_channels, const int num_channels,
float ratio_h, float ratio_h,
float ratio_w, float ratio_w,
const T align_type_value, const float align_type_value,
funcs::FastDivModForInterpolate divmods) { funcs::FastDivModForInterpolate divmods) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
...@@ -426,8 +433,10 @@ __global__ void KeBilinearInterpBw(T* in, ...@@ -426,8 +433,10 @@ __global__ void KeBilinearInterpBw(T* in,
int in_img_idx, in_img_idy, w_id, h_id; int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda; T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; align_type_value);
T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
PreCalculatorForLinearInterpInputIndex( PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w); &in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w);
...@@ -489,14 +498,13 @@ __global__ void KeBicubicInterpBw(T* in, ...@@ -489,14 +498,13 @@ __global__ void KeBicubicInterpBw(T* in,
? static_cast<T>(ratio_h * out_img_idy) ? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5); : static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy); int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y; using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const T y_t = static_cast<T>(static_cast<MT>(in_img_idy) - input_y);
T in_img_idx = align_corners T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx) ? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5); : static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx); int input_x = floorf(in_img_idx);
const T x_t = static_cast<T>(static_cast<MT>(in_img_idx) - input_x);
const T x_t = in_img_idx - input_x;
T x_coeffs[4]; T x_coeffs[4];
T y_coeffs[4]; T y_coeffs[4];
...@@ -543,9 +551,9 @@ __global__ void KeTrilinearInterpBw(T* in, ...@@ -543,9 +551,9 @@ __global__ void KeTrilinearInterpBw(T* in,
const size_t output_h, const size_t output_h,
const size_t output_w, const size_t output_w,
const size_t num_channels, const size_t num_channels,
const T ratio_d, const float ratio_d,
const T ratio_h, const float ratio_h,
const T ratio_w, const float ratio_w,
const bool align_corners, const bool align_corners,
const int align_mode, const int align_mode,
const DataLayout data_layout) { const DataLayout data_layout) {
...@@ -578,33 +586,37 @@ __global__ void KeTrilinearInterpBw(T* in, ...@@ -578,33 +586,37 @@ __global__ void KeTrilinearInterpBw(T* in,
: static_cast<int>(ratio_d * out_img_idt); : static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0; in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0; int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5; T src_d = static_cast<T>(ratio_d * (out_img_idt + 0.5) - 0.5);
src_d = (src_d > 0) ? src_d : 0; src_d = (src_d > static_cast<T>(0)) ? src_d : static_cast<T>(0);
T d1lambda = using MT = typename phi::dtype::MPTypeTrait<T>::Type;
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt; T d1lambda = align_flag
T d2lambda = 1.f - d1lambda; ? static_cast<T>(static_cast<MT>(src_d) - in_img_idt)
: static_cast<T>(ratio_d * out_img_idt - in_img_idt);
T d2lambda = static_cast<T>(1.0) - d1lambda;
int in_img_idy = align_flag int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5) ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy); : static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5; T src_h = static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
src_h = (src_h > 0) ? src_h : 0; src_h = (src_h > static_cast<T>(0)) ? src_h : static_cast<T>(0);
T h1lambda = T h1lambda = align_flag
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy; ? static_cast<T>(static_cast<MT>(src_h) - in_img_idy)
T h2lambda = 1.f - h1lambda; : static_cast<T>(ratio_h * out_img_idy - in_img_idy);
T h2lambda = static_cast<T>(1.0) - h1lambda;
int in_img_idx = align_flag int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5) ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx); : static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; T src_w = static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
src_w = (src_w > 0) ? src_w : 0; src_w = (src_w > static_cast<T>(0)) ? src_w : static_cast<T>(0);
T w1lambda = T w1lambda = align_flag
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; ? static_cast<T>(static_cast<MT>(src_w) - in_img_idx)
T w2lambda = 1.f - w1lambda; : static_cast<T>(ratio_w * out_img_idx - in_img_idx);
T w2lambda = static_cast<T>(1.0) - w1lambda;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size + int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
...@@ -1031,7 +1043,8 @@ static void Interpolate2DCUDABwd( ...@@ -1031,7 +1043,8 @@ static void Interpolate2DCUDABwd(
interp_divmods); interp_divmods);
} }
} else if ("bilinear" == interp_method) { } else if ("bilinear" == interp_method) {
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0; const float align_type_value =
(align_mode == 0 && !align_corners) ? 0.5f : 0.f;
bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false; bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false;
bool optimize_flag = false; bool optimize_flag = false;
#ifndef __HIPCC__ #ifndef __HIPCC__
...@@ -1148,7 +1161,7 @@ static void Interpolate3DCUDABwd( ...@@ -1148,7 +1161,7 @@ static void Interpolate3DCUDABwd(
if (scale_tensor) { if (scale_tensor) {
auto scale_data = auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr()); funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_data.size() > 1) { if (scale_data.size() > 2) {
scale_d = scale_data[0]; scale_d = scale_data[0];
scale_h = scale_data[1]; scale_h = scale_data[1];
scale_w = scale_data[2]; scale_w = scale_data[2];
...@@ -1179,7 +1192,7 @@ static void Interpolate3DCUDABwd( ...@@ -1179,7 +1192,7 @@ static void Interpolate3DCUDABwd(
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_d)); scale_d));
} else { } else {
if (scale.size() > 1) { if (scale.size() > 2) {
scale_d = scale[0]; scale_d = scale[0];
scale_h = scale[1]; scale_h = scale[1];
scale_w = scale[2]; scale_w = scale[2];
...@@ -1574,7 +1587,8 @@ PD_REGISTER_KERNEL(bilinear_interp_grad, ...@@ -1574,7 +1587,8 @@ PD_REGISTER_KERNEL(bilinear_interp_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::BilinearInterpGradKernel, phi::BilinearInterpGradKernel,
float, float,
double) { double,
phi::dtype::float16) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -1583,7 +1597,8 @@ PD_REGISTER_KERNEL(nearest_interp_grad, ...@@ -1583,7 +1597,8 @@ PD_REGISTER_KERNEL(nearest_interp_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::NearestInterpGradKernel, phi::NearestInterpGradKernel,
float, float,
double) { double,
phi::dtype::float16) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -1592,7 +1607,8 @@ PD_REGISTER_KERNEL(trilinear_interp_grad, ...@@ -1592,7 +1607,8 @@ PD_REGISTER_KERNEL(trilinear_interp_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::TrilinearInterpGradKernel, phi::TrilinearInterpGradKernel,
float, float,
double) { double,
phi::dtype::float16) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -1601,7 +1617,8 @@ PD_REGISTER_KERNEL(linear_interp_grad, ...@@ -1601,7 +1617,8 @@ PD_REGISTER_KERNEL(linear_interp_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::LinearInterpGradKernel, phi::LinearInterpGradKernel,
float, float,
double) { double,
phi::dtype::float16) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -1610,7 +1627,8 @@ PD_REGISTER_KERNEL(bicubic_interp_grad, ...@@ -1610,7 +1627,8 @@ PD_REGISTER_KERNEL(bicubic_interp_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::BicubicInterpGradKernel, phi::BicubicInterpGradKernel,
float, float,
double) { double,
phi::dtype::float16) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h" #include "paddle/phi/kernels/funcs/interpolate_function.h"
...@@ -34,11 +36,12 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex( ...@@ -34,11 +36,12 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
T* lambda2, T* lambda2,
T src_x, T src_x,
const int in_img_x) { const int in_img_x) {
src_x = (src_x > 0) ? src_x : 0.f; src_x = (src_x > static_cast<T>(0)) ? src_x : static_cast<T>(0);
*in_img_idx = static_cast<int>(src_x); *in_img_idx = static_cast<int>(src_x);
*x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0; *x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
*lambda1 = src_x - *in_img_idx; using MT = typename phi::dtype::MPTypeTrait<T>::Type;
*lambda2 = 1.f - *lambda1; *lambda1 = static_cast<T>(static_cast<MT>(src_x) - *in_img_idx);
*lambda2 = static_cast<T>(1.0) - *lambda1;
} }
template <typename T> template <typename T>
...@@ -78,12 +81,13 @@ __global__ void KeLinearInterpFw(const T* in, ...@@ -78,12 +81,13 @@ __global__ void KeLinearInterpFw(const T* in,
: static_cast<int>(ratio_w * out_img_idx); : static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; T src_w = static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
src_w = (src_w > 0) ? src_w : 0; src_w = (src_w > static_cast<T>(0)) ? src_w : static_cast<T>(0);
T w1lambda = T w1lambda = align_flag
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; ? static_cast<T>(static_cast<MT>(src_w) - in_img_idx)
T w2lambda = 1.f - w1lambda; : static_cast<T>(ratio_w * out_img_idx - in_img_idx);
T w2lambda = static_cast<T>(1.0) - w1lambda;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
const T* in_pos = const T* in_pos =
...@@ -203,7 +207,7 @@ __global__ void KeBilinearInterpFw(const T* in, ...@@ -203,7 +207,7 @@ __global__ void KeBilinearInterpFw(const T* in,
const size_t num_channels, const size_t num_channels,
const float ratio_h, const float ratio_h,
const float ratio_w, const float ratio_w,
const T align_type_value, const float align_type_value,
funcs::FastDivModForInterpolate divmods) { funcs::FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w; int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -222,8 +226,10 @@ __global__ void KeBilinearInterpFw(const T* in, ...@@ -222,8 +226,10 @@ __global__ void KeBilinearInterpFw(const T* in,
int in_img_idx, in_img_idy, h_id, w_id; int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda; T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; align_type_value);
T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
PreCalculatorForLinearInterpInputIndex( PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w); &in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
...@@ -254,7 +260,7 @@ __global__ void KeBilinearInterpNCHWFw(const T* in, ...@@ -254,7 +260,7 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
const size_t nc, const size_t nc,
const float ratio_h, const float ratio_h,
const float ratio_w, const float ratio_w,
const T align_type_value) { const float align_type_value) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x; int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y; int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z; int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
...@@ -262,8 +268,10 @@ __global__ void KeBilinearInterpNCHWFw(const T* in, ...@@ -262,8 +268,10 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
int in_img_idx, in_img_idy, h_id, w_id; int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda; T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; align_type_value);
T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
PreCalculatorForLinearInterpInputIndex( PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w); &in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
...@@ -296,13 +304,13 @@ template <typename T> ...@@ -296,13 +304,13 @@ template <typename T>
__device__ __forceinline__ static T Kecubic_interp( __device__ __forceinline__ static T Kecubic_interp(
const T x0, const T x1, const T x2, const T x3, T t) { const T x0, const T x1, const T x2, const T x3, T t) {
T coeffs[4]; T coeffs[4];
T a = -0.75; T a = static_cast<T>(-0.75);
T x_1 = t; T x_1 = t;
T x_2 = 1.0 - t; T x_2 = static_cast<T>(1.0) - t;
coeffs[0] = funcs::CubicConvolution2<T>(x_1 + 1.0, a); coeffs[0] = funcs::CubicConvolution2<T>(x_1 + static_cast<T>(1.0), a);
coeffs[1] = funcs::CubicConvolution1<T>(x_1, a); coeffs[1] = funcs::CubicConvolution1<T>(x_1, a);
coeffs[2] = funcs::CubicConvolution1<T>(x_2, a); coeffs[2] = funcs::CubicConvolution1<T>(x_2, a);
coeffs[3] = funcs::CubicConvolution2<T>(x_2 + 1.0, a); coeffs[3] = funcs::CubicConvolution2<T>(x_2 + static_cast<T>(1.0), a);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
} }
...@@ -348,13 +356,14 @@ __global__ void KeBicubicInterpFw(const T* in, ...@@ -348,13 +356,14 @@ __global__ void KeBicubicInterpFw(const T* in,
? static_cast<T>(ratio_h * out_img_idy) ? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5); : static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy); int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y; using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const T y_t = static_cast<T>(static_cast<MT>(in_img_idy) - input_y);
T in_img_idx = align_corners T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx) ? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5); : static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx); int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x; const T x_t = static_cast<T>(static_cast<MT>(in_img_idx) - input_x);
T coefficients[4]; T coefficients[4];
const T* in_pos_0; const T* in_pos_0;
...@@ -419,16 +428,15 @@ __global__ void KeBicubicInterpFw(const T* in, ...@@ -419,16 +428,15 @@ __global__ void KeBicubicInterpFw(const T* in,
&in[out_id_h * input_w + access_y * in_img_w * num_channels + &in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_3 * num_channels + channel_id]; access_x_3 * num_channels + channel_id];
coefficients[k] = Kecubic_interp( coefficients[k] = Kecubic_interp<T>(
in_pos_0[0], in_pos_1[0], in_pos_2[0], in_pos_3[0], x_t); in_pos_0[0], in_pos_1[0], in_pos_2[0], in_pos_3[0], x_t);
} }
out[out_id_h * output_w + out_id_w] = out[out_id_h * output_w + out_id_w] = Kecubic_interp<T>(coefficients[0],
static_cast<T>(Kecubic_interp(coefficients[0], coefficients[1],
coefficients[1], coefficients[2],
coefficients[2], coefficients[3],
coefficients[3], y_t);
y_t));
} }
} }
} }
...@@ -482,33 +490,37 @@ __global__ void KeTrilinearInterpFw(const T* in, ...@@ -482,33 +490,37 @@ __global__ void KeTrilinearInterpFw(const T* in,
: static_cast<int>(ratio_d * out_img_idt); : static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0; in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0; int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5; using MT = typename phi::dtype::MPTypeTrait<T>::Type;
src_d = (src_d > 0) ? src_d : 0; T src_d = static_cast<T>(ratio_d * (out_img_idt + 0.5) - 0.5);
T d1lambda = src_d = (src_d > static_cast<T>(0)) ? src_d : static_cast<T>(0);
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt; T d1lambda = align_flag
T d2lambda = 1.f - d1lambda; ? static_cast<T>(static_cast<MT>(src_d) - in_img_idt)
: static_cast<T>(ratio_d * out_img_idt - in_img_idt);
T d2lambda = static_cast<T>(1.0) - d1lambda;
int in_img_idy = align_flag int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5) ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy); : static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5; T src_h = static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
src_h = (src_h > 0) ? src_h : 0; src_h = (src_h > static_cast<T>(0)) ? src_h : static_cast<T>(0);
T h1lambda = T h1lambda = align_flag
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy; ? static_cast<T>(static_cast<MT>(src_h) - in_img_idy)
T h2lambda = 1.f - h1lambda; : static_cast<T>(ratio_h * out_img_idy - in_img_idy);
T h2lambda = static_cast<T>(1.0) - h1lambda;
int in_img_idx = align_flag int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5) ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx); : static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; T src_w = static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
src_w = (src_w > 0) ? src_w : 0; src_w = (src_w > static_cast<T>(0)) ? src_w : static_cast<T>(0);
T w1lambda = T w1lambda = align_flag
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; ? static_cast<T>(static_cast<MT>(src_w) - in_img_idx)
T w2lambda = 1.f - w1lambda; : static_cast<T>(ratio_w * out_img_idx - in_img_idx);
T w2lambda = static_cast<T>(1.0) - w1lambda;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size + int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
...@@ -926,7 +938,8 @@ static void Interpolate2DCUDAFwd( ...@@ -926,7 +938,8 @@ static void Interpolate2DCUDAFwd(
thread_num = 512; thread_num = 512;
} }
#endif #endif
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0; const float align_type_value =
(align_mode == 0 && !align_corners) ? 0.5f : 0.f;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
// get launch 3D config // get launch 3D config
int nc = n * c; int nc = n * c;
...@@ -1028,7 +1041,7 @@ static void Interpolate3DCUDAFwd( ...@@ -1028,7 +1041,7 @@ static void Interpolate3DCUDAFwd(
if (scale_tensor) { if (scale_tensor) {
auto scale_data = auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr()); funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_data.size() > 1) { if (scale_data.size() > 2) {
scale_d = scale_data[0]; scale_d = scale_data[0];
scale_h = scale_data[1]; scale_h = scale_data[1];
scale_w = scale_data[2]; scale_w = scale_data[2];
...@@ -1060,7 +1073,7 @@ static void Interpolate3DCUDAFwd( ...@@ -1060,7 +1073,7 @@ static void Interpolate3DCUDAFwd(
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_d)); scale_d));
} else { } else {
if (scale.size() > 1) { if (scale.size() > 2) {
scale_d = scale[0]; scale_d = scale[0];
scale_h = scale[1]; scale_h = scale[1];
scale_w = scale[2]; scale_w = scale[2];
...@@ -1446,6 +1459,7 @@ PD_REGISTER_KERNEL(bilinear_interp, ...@@ -1446,6 +1459,7 @@ PD_REGISTER_KERNEL(bilinear_interp,
phi::BilinearInterpKernel, phi::BilinearInterpKernel,
float, float,
double, double,
phi::dtype::float16,
int) { int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
...@@ -1456,6 +1470,7 @@ PD_REGISTER_KERNEL(nearest_interp, ...@@ -1456,6 +1470,7 @@ PD_REGISTER_KERNEL(nearest_interp,
phi::NearestInterpKernel, phi::NearestInterpKernel,
float, float,
double, double,
phi::dtype::float16,
int, int,
int64_t) { int64_t) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
...@@ -1467,6 +1482,7 @@ PD_REGISTER_KERNEL(trilinear_interp, ...@@ -1467,6 +1482,7 @@ PD_REGISTER_KERNEL(trilinear_interp,
phi::TrilinearInterpKernel, phi::TrilinearInterpKernel,
float, float,
double, double,
phi::dtype::float16,
int) { int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
...@@ -1477,6 +1493,7 @@ PD_REGISTER_KERNEL(linear_interp, ...@@ -1477,6 +1493,7 @@ PD_REGISTER_KERNEL(linear_interp,
phi::LinearInterpKernel, phi::LinearInterpKernel,
float, float,
double, double,
phi::dtype::float16,
int) { int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
...@@ -1487,6 +1504,7 @@ PD_REGISTER_KERNEL(bicubic_interp, ...@@ -1487,6 +1504,7 @@ PD_REGISTER_KERNEL(bicubic_interp,
phi::BicubicInterpKernel, phi::BicubicInterpKernel,
float, float,
double, double,
phi::dtype::float16,
int) { int) {
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
......
...@@ -622,6 +622,44 @@ class TestBicubicOpError(unittest.TestCase): ...@@ -622,6 +622,44 @@ class TestBicubicOpError(unittest.TestCase):
self.test_imperative_errors() self.test_imperative_errors()
@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestBicubicInterpOpForFloat16(unittest.TestCase):
    """Compares bicubic interpolation run in float16 against float32.

    Both the forward output and the gradient w.r.t. the input are compared.
    The whole case is skipped when Paddle is not compiled with CUDA.
    """

    def init_test_case(self):
        """Set the interpolation configuration used by check_main."""
        self.interp_method = 'bicubic'
        self.input_shape = [2, 3, 5, 5]
        self.out_size = np.array([3, 3]).astype("int32")
        self.align_corners = True
        self.data_layout = 'NCHW'

    def check_main(self, x_np, dtype):
        """Run interpolate and its backward pass in dygraph mode.

        Args:
            x_np: numpy input array; it is cast to `dtype` before use.
            dtype: numpy dtype name the op is evaluated in.

        Returns:
            A tuple (y_np, x_g_np) of float32 numpy arrays: `y[0]` of the
            op output and element 0 of the result of `paddle.grad(y, x)`.
        """
        paddle.disable_static()
        try:
            x_np = x_np.astype(dtype)
            x = paddle.to_tensor(x_np)
            x.stop_gradient = False
            y = interpolate(x,
                            size=self.out_size.tolist(),
                            mode=self.interp_method,
                            align_corners=self.align_corners,
                            data_format=self.data_layout)
            x_g = paddle.grad(y, x)
            # NOTE(review): y[0] keeps only the first entry along dim 0 of
            # the output -- confirm comparing a single sample is intended.
            y_np = y[0].numpy().astype('float32')
            x_g_np = x_g[0].numpy().astype('float32')
        finally:
            # Always switch back to static mode, even when the op raises,
            # so a failure here cannot leak dygraph mode into other tests.
            paddle.enable_static()
        return y_np, x_g_np

    def test_main(self):
        """float16 and float32 runs must agree on output and gradient."""
        self.init_test_case()
        # Start from float16 data so both runs see identical input values.
        x_np = np.random.random(self.input_shape).astype("float16")

        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')

        # Compared with assert_allclose's default tolerances.
        np.testing.assert_allclose(y_np_1, y_np_2)
        np.testing.assert_allclose(x_g_np_1, x_g_np_2)
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -766,5 +766,45 @@ class TestBilinearInterpOpAPI_dy4(unittest.TestCase): ...@@ -766,5 +766,45 @@ class TestBilinearInterpOpAPI_dy4(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), expect_res, rtol=1e-05) np.testing.assert_allclose(out.numpy(), expect_res, rtol=1e-05)
@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestBilinearInterpOpForFloat16(unittest.TestCase):
    """Compares bilinear interpolation run in float16 against float32.

    Both the forward output and the gradient w.r.t. the input are compared.
    The whole case is skipped when Paddle is not compiled with CUDA.
    """

    def init_test_case(self):
        """Set the interpolation configuration used by check_main."""
        self.interp_method = 'bilinear'
        self.input_shape = [2, 3, 5, 5]
        self.out_size = np.array([3, 3]).astype("int32")
        self.align_corners = True
        self.align_mode = 1
        self.data_layout = 'NCHW'

    def check_main(self, x_np, dtype):
        """Run interpolate and its backward pass in dygraph mode.

        Args:
            x_np: numpy input array; it is cast to `dtype` before use.
            dtype: numpy dtype name the op is evaluated in.

        Returns:
            A tuple (y_np, x_g_np) of float32 numpy arrays: `y[0]` of the
            op output and element 0 of the result of `paddle.grad(y, x)`.
        """
        paddle.disable_static()
        try:
            x_np = x_np.astype(dtype)
            x = paddle.to_tensor(x_np)
            x.stop_gradient = False
            y = interpolate(x,
                            size=self.out_size.tolist(),
                            mode=self.interp_method,
                            align_mode=self.align_mode,
                            align_corners=self.align_corners,
                            data_format=self.data_layout)
            x_g = paddle.grad(y, x)
            # NOTE(review): y[0] keeps only the first entry along dim 0 of
            # the output -- confirm comparing a single sample is intended.
            y_np = y[0].numpy().astype('float32')
            x_g_np = x_g[0].numpy().astype('float32')
        finally:
            # Always switch back to static mode, even when the op raises,
            # so a failure here cannot leak dygraph mode into other tests.
            paddle.enable_static()
        return y_np, x_g_np

    def test_main(self):
        """float16 and float32 runs must agree on output and gradient."""
        self.init_test_case()
        # Start from float16 data so both runs see identical input values.
        x_np = np.random.random(self.input_shape).astype("float16")

        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')

        # Compared with assert_allclose's default tolerances.
        np.testing.assert_allclose(y_np_1, y_np_2)
        np.testing.assert_allclose(x_g_np_1, x_g_np_2)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -376,9 +376,7 @@ class TestLinearInterpOpAPI2_0(unittest.TestCase): ...@@ -376,9 +376,7 @@ class TestLinearInterpOpAPI2_0(unittest.TestCase):
# dygraph # dygraph
x_data = np.random.random((1, 3, 128)).astype("float32") x_data = np.random.random((1, 3, 128)).astype("float32")
us_1 = paddle.nn.Upsample(size=[ us_1 = paddle.nn.Upsample(size=[64],
64,
],
mode='linear', mode='linear',
align_mode=1, align_mode=1,
align_corners=False, align_corners=False,
...@@ -493,28 +491,21 @@ class TestLinearInterpOpError(unittest.TestCase): ...@@ -493,28 +491,21 @@ class TestLinearInterpOpError(unittest.TestCase):
def input_shape_error(): def input_shape_error():
x1 = fluid.data(name="x1", shape=[1], dtype="float32") x1 = fluid.data(name="x1", shape=[1], dtype="float32")
out1 = paddle.nn.Upsample(size=[ out1 = paddle.nn.Upsample(size=[256],
256,
],
data_format='NCW', data_format='NCW',
mode='linear') mode='linear')
out1_res = out1(x1) out1_res = out1(x1)
def data_format_error(): def data_format_error():
x2 = fluid.data(name="x2", shape=[1, 3, 128], dtype="float32") x2 = fluid.data(name="x2", shape=[1, 3, 128], dtype="float32")
out2 = paddle.nn.Upsample(size=[ out2 = paddle.nn.Upsample(size=[256],
256,
],
data_format='NHWCD', data_format='NHWCD',
mode='linear') mode='linear')
out2_res = out2(x2) out2_res = out2(x2)
def out_shape_error(): def out_shape_error():
x3 = fluid.data(name="x3", shape=[1, 3, 128], dtype="float32") x3 = fluid.data(name="x3", shape=[1, 3, 128], dtype="float32")
out3 = paddle.nn.Upsample(size=[ out3 = paddle.nn.Upsample(size=[256, 256],
256,
256,
],
data_format='NHWC', data_format='NHWC',
mode='linear') mode='linear')
out3_res = out3(x3) out3_res = out3(x3)
...@@ -524,5 +515,46 @@ class TestLinearInterpOpError(unittest.TestCase): ...@@ -524,5 +515,46 @@ class TestLinearInterpOpError(unittest.TestCase):
self.assertRaises(ValueError, out_shape_error) self.assertRaises(ValueError, out_shape_error)
@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestLinearInterpOpForFloat16(unittest.TestCase):
    """Compares linear interpolation run in float16 against float32.

    Both the forward output and the gradient w.r.t. the input are compared.
    The whole case is skipped when Paddle is not compiled with CUDA.
    """

    def init_test_case(self):
        """Set the interpolation configuration used by check_main."""
        self.interp_method = 'linear'
        self.input_shape = [1, 3, 64]
        self.scale = 2
        self.align_corners = False
        self.align_mode = 1
        self.data_layout = 'NCW'

    def check_main(self, x_np, dtype):
        """Run interpolate and its backward pass in dygraph mode.

        Args:
            x_np: numpy input array; it is cast to `dtype` before use.
            dtype: numpy dtype name the op is evaluated in.

        Returns:
            A tuple (y_np, x_g_np) of float32 numpy arrays: `y[0]` of the
            op output and element 0 of the result of `paddle.grad(y, x)`.
        """
        paddle.disable_static()
        try:
            x_np = x_np.astype(dtype)
            x = paddle.to_tensor(x_np)
            x.stop_gradient = False
            y = interpolate(x,
                            scale_factor=self.scale,
                            mode=self.interp_method,
                            align_mode=self.align_mode,
                            align_corners=self.align_corners,
                            data_format=self.data_layout)
            x_g = paddle.grad(y, x)
            # NOTE(review): y[0] keeps only the first entry along dim 0 of
            # the output -- confirm comparing a single sample is intended.
            y_np = y[0].numpy().astype('float32')
            x_g_np = x_g[0].numpy().astype('float32')
        finally:
            # Always switch back to static mode, even when the op raises,
            # so a failure here cannot leak dygraph mode into other tests.
            paddle.enable_static()
        return y_np, x_g_np

    def test_main(self):
        """float16 and float32 runs must agree on output and gradient."""
        self.init_test_case()
        # Start from float16 data so both runs see identical input values.
        x_np = np.random.random(self.input_shape).astype("float16")

        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
        # forward: relaxed relative tolerance
        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
        # backward: default tolerances
        np.testing.assert_allclose(x_g_np_1, x_g_np_2)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -802,5 +802,81 @@ class TestNearestInterpException(unittest.TestCase): ...@@ -802,5 +802,81 @@ class TestNearestInterpException(unittest.TestCase):
self.assertRaises(ValueError, mode_error) self.assertRaises(ValueError, mode_error)
@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestNearestInterp3DOpForFloat16(unittest.TestCase):
    """Compares nearest interpolation on 5-D input in float16 vs float32.

    Both the forward output and the gradient w.r.t. the input are compared.
    The whole case is skipped when Paddle is not compiled with CUDA.
    """

    def init_test_case(self):
        """Set the interpolation configuration used by check_main."""
        self.interp_method = 'nearest'
        self.input_shape = [2, 2, 6, 6, 6]
        self.scale = [2, 2, 2]
        self.align_corners = False
        self.data_layout = 'NCDHW'

    def check_main(self, x_np, dtype):
        """Run interpolate and its backward pass in dygraph mode.

        Args:
            x_np: numpy input array; it is cast to `dtype` before use.
            dtype: numpy dtype name the op is evaluated in.

        Returns:
            A tuple (y_np, x_g_np) of float32 numpy arrays: `y[0]` of the
            op output and element 0 of the result of `paddle.grad(y, x)`.
        """
        paddle.disable_static()
        try:
            x_np = x_np.astype(dtype)
            x = paddle.to_tensor(x_np)
            x.stop_gradient = False
            y = interpolate(x,
                            scale_factor=self.scale,
                            mode=self.interp_method,
                            align_corners=self.align_corners,
                            data_format=self.data_layout)
            x_g = paddle.grad(y, x)
            # NOTE(review): y[0] keeps only the first entry along dim 0 of
            # the output -- confirm comparing a single sample is intended.
            y_np = y[0].numpy().astype('float32')
            x_g_np = x_g[0].numpy().astype('float32')
        finally:
            # Always switch back to static mode, even when the op raises,
            # so a failure here cannot leak dygraph mode into other tests.
            paddle.enable_static()
        return y_np, x_g_np

    def test_main(self):
        """float16 and float32 runs must agree on output and gradient."""
        self.init_test_case()
        # Start from float16 data so both runs see identical input values.
        x_np = np.random.random(self.input_shape).astype("float16")

        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
        # forward: relaxed relative tolerance
        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
        # backward: default tolerances
        np.testing.assert_allclose(x_g_np_1, x_g_np_2)
@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestNearestInterpOpForFloat16(unittest.TestCase):
    """Compares nearest interpolation on 4-D input in float16 vs float32.

    Both the forward output and the gradient w.r.t. the input are compared.
    The whole case is skipped when Paddle is not compiled with CUDA.
    """

    def init_test_case(self):
        """Set the interpolation configuration used by check_main."""
        self.interp_method = 'nearest'
        self.input_shape = [2, 2, 6, 6]
        self.scale = [2, 2]
        self.align_corners = False

    def check_main(self, x_np, dtype):
        """Run interpolate and its backward pass in dygraph mode.

        No data_format is passed, so interpolate uses its own default.

        Args:
            x_np: numpy input array; it is cast to `dtype` before use.
            dtype: numpy dtype name the op is evaluated in.

        Returns:
            A tuple (y_np, x_g_np) of float32 numpy arrays: `y[0]` of the
            op output and element 0 of the result of `paddle.grad(y, x)`.
        """
        paddle.disable_static()
        try:
            x_np = x_np.astype(dtype)
            x = paddle.to_tensor(x_np)
            x.stop_gradient = False
            y = interpolate(x,
                            scale_factor=self.scale,
                            mode=self.interp_method,
                            align_corners=self.align_corners)
            x_g = paddle.grad(y, x)
            # NOTE(review): y[0] keeps only the first entry along dim 0 of
            # the output -- confirm comparing a single sample is intended.
            y_np = y[0].numpy().astype('float32')
            x_g_np = x_g[0].numpy().astype('float32')
        finally:
            # Always switch back to static mode, even when the op raises,
            # so a failure here cannot leak dygraph mode into other tests.
            paddle.enable_static()
        return y_np, x_g_np

    def test_main(self):
        """float16 and float32 runs must agree on output and gradient."""
        self.init_test_case()
        # Start from float16 data so both runs see identical input values.
        x_np = np.random.random(self.input_shape).astype("float16")

        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
        # forward: default tolerances
        np.testing.assert_allclose(y_np_1, y_np_2)
        # backward: default tolerances
        np.testing.assert_allclose(x_g_np_1, x_g_np_2)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -154,15 +154,15 @@ def trilinear_interp_np(input, ...@@ -154,15 +154,15 @@ def trilinear_interp_np(input,
out[:, :, i, j, k] = \ out[:, :, i, j, k] = \
d2lambda * \ d2lambda * \
(h2lambda * (w2lambda * input[:, :, d, h, w] + \ (h2lambda * (w2lambda * input[:, :, d, h, w] +
w1lambda * input[:, :, d, h, w+wid]) + \ w1lambda * input[:, :, d, h, w+wid]) +
h1lambda * (w2lambda * input[:, :, d, h+hid, w] + \ h1lambda * (w2lambda * input[:, :, d, h+hid, w] +
w1lambda * input[:, :, d, h+hid, w+wid])) + \ w1lambda * input[:, :, d, h+hid, w+wid])) + \
d1lambda * \ d1lambda * \
(h2lambda * (w2lambda * input[:, :, d+did, h, w] + \ (h2lambda * (w2lambda * input[:, :, d+did, h, w] +
w1lambda * input[:, :, d+did, h, w+wid]) + \ w1lambda * input[:, :, d+did, h, w+wid]) +
h1lambda * (w2lambda * input[:, :, d+did, h+hid, w] + \ h1lambda * (w2lambda * input[:, :, d+did, h+hid, w] +
w1lambda * input[:, :, d+did, h+hid, w+wid])) w1lambda * input[:, :, d+did, h+hid, w+wid]))
if data_layout == "NDHWC": if data_layout == "NDHWC":
out = np.transpose(out, (0, 2, 3, 4, 1)) # NCDHW => NDHWC out = np.transpose(out, (0, 2, 3, 4, 1)) # NCDHW => NDHWC
...@@ -809,5 +809,59 @@ class TestTrilinearInterpOpException(unittest.TestCase): ...@@ -809,5 +809,59 @@ class TestTrilinearInterpOpException(unittest.TestCase):
self.assertRaises(ValueError, attr_data_format) self.assertRaises(ValueError, attr_data_format)
@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestTrilinearInterpOpForFloat16(unittest.TestCase):
    """Compares trilinear interpolation run in float16 against float32.

    Both the forward output and the gradient w.r.t. the input are compared.
    The whole case is skipped when Paddle is not compiled with CUDA.
    Subclasses may override init_test_case to vary the configuration.
    """

    def init_test_case(self):
        """Set the interpolation configuration used by check_main."""
        self.interp_method = 'trilinear'
        self.input_shape = [2, 3, 4, 4, 4]
        self.out_size = np.array([3, 3, 3]).astype("int32")
        self.align_corners = True
        self.align_mode = 1
        self.data_layout = 'NCDHW'

    def check_main(self, x_np, dtype):
        """Run interpolate and its backward pass in dygraph mode.

        Args:
            x_np: numpy input array; it is cast to `dtype` before use.
            dtype: numpy dtype name the op is evaluated in.

        Returns:
            A tuple (y_np, x_g_np) of float32 numpy arrays: `y[0]` of the
            op output and element 0 of the result of `paddle.grad(y, x)`.
        """
        paddle.disable_static()
        try:
            x_np = x_np.astype(dtype)
            x = paddle.to_tensor(x_np)
            x.stop_gradient = False
            y = interpolate(x,
                            size=self.out_size.tolist(),
                            mode=self.interp_method,
                            align_corners=self.align_corners,
                            align_mode=self.align_mode,
                            data_format=self.data_layout)
            x_g = paddle.grad(y, x)
            # NOTE(review): y[0] keeps only the first entry along dim 0 of
            # the output -- confirm comparing a single sample is intended.
            y_np = y[0].numpy().astype('float32')
            x_g_np = x_g[0].numpy().astype('float32')
        finally:
            # Always switch back to static mode, even when the op raises,
            # so a failure here cannot leak dygraph mode into other tests.
            paddle.enable_static()
        return y_np, x_g_np

    def test_main(self):
        """float16 and float32 runs must agree on output and gradient."""
        self.init_test_case()
        # Start from float16 data so both runs see identical input values.
        x_np = np.random.random(self.input_shape).astype("float16")

        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
        # forward
        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
        # backward
        np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=1e-05)
@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestTrilinearInterpDatalayoutForFloat16(TestTrilinearInterpOpForFloat16):
    """Re-runs the inherited float16 trilinear test with the NDHWC layout."""

    def init_test_case(self):
        """Replace the inherited configuration with the NDHWC variant."""
        # Layout and the input shape that goes with it.
        self.data_layout = "NDHWC"
        self.input_shape = [2, 4, 4, 4, 3]
        # Requested output size, stored as an int32 array.
        self.out_size = np.array([3, 3, 3], dtype="int32")
        # Interpolation mode and alignment settings.
        self.interp_method = 'trilinear'
        self.align_mode = 1
        self.align_corners = True
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册