Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b12c27eb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b12c27eb
编写于
9月 02, 2022
作者:
Y
Yuanle Liu
提交者:
GitHub
9月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
interpolate (forward grad) op support fp16 on gpu (#45061)
上级
cbf26bb1
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
403 addition
and
125 deletion
+403
-125
paddle/phi/kernels/funcs/interpolate_function.h
paddle/phi/kernels/funcs/interpolate_function.h
+8
-6
paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
+69
-51
paddle/phi/kernels/gpu/interpolate_kernel.cu
paddle/phi/kernels/gpu/interpolate_kernel.cu
+64
-46
python/paddle/fluid/tests/unittests/test_bicubic_interp_v2_op.py
...paddle/fluid/tests/unittests/test_bicubic_interp_v2_op.py
+38
-0
python/paddle/fluid/tests/unittests/test_bilinear_interp_v2_op.py
...addle/fluid/tests/unittests/test_bilinear_interp_v2_op.py
+40
-0
python/paddle/fluid/tests/unittests/test_linear_interp_v2_op.py
.../paddle/fluid/tests/unittests/test_linear_interp_v2_op.py
+45
-13
python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py
...paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py
+77
-1
python/paddle/fluid/tests/unittests/test_trilinear_interp_v2_op.py
...ddle/fluid/tests/unittests/test_trilinear_interp_v2_op.py
+62
-8
未找到文件。
paddle/phi/kernels/funcs/interpolate_function.h
浏览文件 @
b12c27eb
...
@@ -28,26 +28,28 @@ namespace funcs {
...
@@ -28,26 +28,28 @@ namespace funcs {
template
<
typename
T
>
template
<
typename
T
>
HOSTDEVICE
inline
T
CubicConvolution1
(
T
x
,
T
A
)
{
HOSTDEVICE
inline
T
CubicConvolution1
(
T
x
,
T
A
)
{
return
((
A
+
2
)
*
x
-
(
A
+
3
))
*
x
*
x
+
1
;
return
((
A
+
static_cast
<
T
>
(
2
))
*
x
-
(
A
+
static_cast
<
T
>
(
3
)))
*
x
*
x
+
static_cast
<
T
>
(
1
);
}
}
template
<
typename
T
>
template
<
typename
T
>
HOSTDEVICE
inline
T
CubicConvolution2
(
T
x
,
T
A
)
{
HOSTDEVICE
inline
T
CubicConvolution2
(
T
x
,
T
A
)
{
return
((
A
*
x
-
5
*
A
)
*
x
+
8
*
A
)
*
x
-
4
*
A
;
return
((
A
*
x
-
static_cast
<
T
>
(
5
)
*
A
)
*
x
+
static_cast
<
T
>
(
8
)
*
A
)
*
x
-
static_cast
<
T
>
(
4
)
*
A
;
}
}
template
<
typename
T
>
template
<
typename
T
>
HOSTDEVICE
inline
void
get_cubic_upsample_coefficients
(
T
coeffs
[
4
],
T
t
)
{
HOSTDEVICE
inline
void
get_cubic_upsample_coefficients
(
T
coeffs
[
4
],
T
t
)
{
T
A
=
-
0.75
;
T
A
=
static_cast
<
T
>
(
-
0.75
)
;
T
x1
=
t
;
T
x1
=
t
;
coeffs
[
0
]
=
CubicConvolution2
<
T
>
(
x1
+
1.0
,
A
);
coeffs
[
0
]
=
CubicConvolution2
<
T
>
(
x1
+
static_cast
<
T
>
(
1.0
)
,
A
);
coeffs
[
1
]
=
CubicConvolution1
<
T
>
(
x1
,
A
);
coeffs
[
1
]
=
CubicConvolution1
<
T
>
(
x1
,
A
);
// opposite coefficients
// opposite coefficients
T
x2
=
1.0
-
t
;
T
x2
=
static_cast
<
T
>
(
1.0
)
-
t
;
coeffs
[
2
]
=
CubicConvolution1
<
T
>
(
x2
,
A
);
coeffs
[
2
]
=
CubicConvolution1
<
T
>
(
x2
,
A
);
coeffs
[
3
]
=
CubicConvolution2
<
T
>
(
x2
+
1.0
,
A
);
coeffs
[
3
]
=
CubicConvolution2
<
T
>
(
x2
+
static_cast
<
T
>
(
1.0
)
,
A
);
}
}
inline
void
ExtractNCDWH
(
const
DDim
&
dims
,
inline
void
ExtractNCDWH
(
const
DDim
&
dims
,
...
...
paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
浏览文件 @
b12c27eb
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
...
@@ -34,11 +35,12 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
...
@@ -34,11 +35,12 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
T
*
lambda2
,
T
*
lambda2
,
T
src_x
,
T
src_x
,
const
int
in_img_x
)
{
const
int
in_img_x
)
{
src_x
=
(
src_x
>
0
)
?
src_x
:
0.
f
;
src_x
=
(
src_x
>
static_cast
<
T
>
(
0
))
?
src_x
:
static_cast
<
T
>
(
0
)
;
*
in_img_idx
=
static_cast
<
int
>
(
src_x
);
*
in_img_idx
=
static_cast
<
int
>
(
src_x
);
*
x_id
=
(
*
in_img_idx
<
in_img_x
-
1
)
?
1
:
0
;
*
x_id
=
(
*
in_img_idx
<
in_img_x
-
1
)
?
1
:
0
;
*
lambda1
=
src_x
-
*
in_img_idx
;
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
*
lambda2
=
1.
f
-
*
lambda1
;
*
lambda1
=
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_x
)
-
*
in_img_idx
);
*
lambda2
=
static_cast
<
T
>
(
1.0
)
-
*
lambda1
;
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -50,7 +52,7 @@ __global__ void KeLinearInterpBw(T* in,
...
@@ -50,7 +52,7 @@ __global__ void KeLinearInterpBw(T* in,
const
size_t
output_h
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
size_t
num_channels
,
const
T
ratio_w
,
const
float
ratio_w
,
const
bool
align_corners
,
const
bool
align_corners
,
const
int
align_mode
,
const
int
align_mode
,
const
DataLayout
data_layout
)
{
const
DataLayout
data_layout
)
{
...
@@ -77,12 +79,13 @@ __global__ void KeLinearInterpBw(T* in,
...
@@ -77,12 +79,13 @@ __global__ void KeLinearInterpBw(T* in,
:
ratio_w
*
out_img_idx
;
:
ratio_w
*
out_img_idx
;
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
// w
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
// w
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
// w_id
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
// w_id
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
;
T
src_w
=
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
);
src_w
=
(
src_w
>
0
)
?
src_w
:
0
;
src_w
=
(
src_w
>
static_cast
<
T
>
(
0
))
?
src_w
:
static_cast
<
T
>
(
0
);
T
w1lambda
=
T
w1lambda
=
align_flag
align_flag
?
src_w
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
?
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_w
)
-
in_img_idx
)
T
w2lambda
=
1.
f
-
w1lambda
;
:
static_cast
<
T
>
(
ratio_w
*
out_img_idx
-
in_img_idx
);
T
w2lambda
=
static_cast
<
T
>
(
1.0
)
-
w1lambda
;
T
*
in_pos
;
T
*
in_pos
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
...
@@ -245,7 +248,7 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
...
@@ -245,7 +248,7 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
const
int
num_channels
,
const
int
num_channels
,
float
ratio_h
,
float
ratio_h
,
float
ratio_w
,
float
ratio_w
,
const
T
align_type_value
,
const
float
align_type_value
,
bool
is_nchw
)
{
bool
is_nchw
)
{
__shared__
T
s_data
[
2
][
1024
];
__shared__
T
s_data
[
2
][
1024
];
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
@@ -267,8 +270,10 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
...
@@ -267,8 +270,10 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
int
in_img_idx
,
in_img_idy
,
w_id
,
h_id
;
int
in_img_idx
,
in_img_idy
,
w_id
,
h_id
;
T
w1lambda
,
h1lambda
,
w2lambda
,
h2lambda
;
T
w1lambda
,
h1lambda
,
w2lambda
,
h2lambda
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_w
=
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
align_type_value
);
T
src_h
=
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
);
PreCalculatorForLinearInterpInputIndex
(
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_w
);
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_w
);
...
@@ -283,8 +288,8 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
...
@@ -283,8 +288,8 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
int
bot_right_index
=
input_index
+
h_id
*
in_w
+
w_id
;
int
bot_right_index
=
input_index
+
h_id
*
in_w
+
w_id
;
int
in_top_min_index
,
in_bot_min_index
;
int
in_top_min_index
,
in_bot_min_index
;
s_data
[
0
][
threadIdx
.
x
]
=
0.
f
;
s_data
[
0
][
threadIdx
.
x
]
=
static_cast
<
T
>
(
0
)
;
s_data
[
1
][
threadIdx
.
x
]
=
0.
f
;
s_data
[
1
][
threadIdx
.
x
]
=
static_cast
<
T
>
(
0
)
;
int
remain
=
nthreads
-
(
tid
&
(
-
blockDim
.
x
));
int
remain
=
nthreads
-
(
tid
&
(
-
blockDim
.
x
));
int
in_top_max_index
=
int
in_top_max_index
=
phi
::
funcs
::
blockReduceMax
(
top_right_index
,
FINAL_MASK
);
phi
::
funcs
::
blockReduceMax
(
top_right_index
,
FINAL_MASK
);
...
@@ -353,7 +358,7 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
...
@@ -353,7 +358,7 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
float
ratio_h
,
float
ratio_h
,
float
ratio_w
,
float
ratio_w
,
const
T
*
__restrict__
out
,
const
T
*
__restrict__
out
,
const
T
align_type_value
)
{
const
float
align_type_value
)
{
int
index
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
int
index
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
num_out
=
n
*
num_channels
*
out_h
*
out_w
;
int
num_out
=
n
*
num_channels
*
out_h
*
out_w
;
...
@@ -368,13 +373,15 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
...
@@ -368,13 +373,15 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
int
h1
,
y_id
;
int
h1
,
y_id
;
T
h1lambda
,
h0lambda
;
T
h1lambda
,
h0lambda
;
T
src_y
=
ratio_h
*
(
h2
+
align_type_value
)
-
align_type_value
;
T
src_y
=
static_cast
<
T
>
(
ratio_h
*
(
h2
+
align_type_value
)
-
align_type_value
);
PreCalculatorForLinearInterpInputIndex
(
PreCalculatorForLinearInterpInputIndex
(
&
h1
,
&
y_id
,
&
h1lambda
,
&
h0lambda
,
src_y
,
in_h
);
&
h1
,
&
y_id
,
&
h1lambda
,
&
h0lambda
,
src_y
,
in_h
);
int
w1
,
x_id
;
int
w1
,
x_id
;
T
w1lambda
,
w0lambda
;
T
w1lambda
,
w0lambda
;
T
src_x
=
ratio_w
*
(
w2
+
align_type_value
)
-
align_type_value
;
T
src_x
=
static_cast
<
T
>
(
ratio_w
*
(
w2
+
align_type_value
)
-
align_type_value
);
PreCalculatorForLinearInterpInputIndex
(
PreCalculatorForLinearInterpInputIndex
(
&
w1
,
&
x_id
,
&
w1lambda
,
&
w0lambda
,
src_x
,
in_w
);
&
w1
,
&
x_id
,
&
w1lambda
,
&
w0lambda
,
src_x
,
in_w
);
...
@@ -406,7 +413,7 @@ __global__ void KeBilinearInterpBw(T* in,
...
@@ -406,7 +413,7 @@ __global__ void KeBilinearInterpBw(T* in,
const
int
num_channels
,
const
int
num_channels
,
float
ratio_h
,
float
ratio_h
,
float
ratio_w
,
float
ratio_w
,
const
T
align_type_value
,
const
float
align_type_value
,
funcs
::
FastDivModForInterpolate
divmods
)
{
funcs
::
FastDivModForInterpolate
divmods
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
...
@@ -426,8 +433,10 @@ __global__ void KeBilinearInterpBw(T* in,
...
@@ -426,8 +433,10 @@ __global__ void KeBilinearInterpBw(T* in,
int
in_img_idx
,
in_img_idy
,
w_id
,
h_id
;
int
in_img_idx
,
in_img_idy
,
w_id
,
h_id
;
T
w1lambda
,
h1lambda
,
w2lambda
,
h2lambda
;
T
w1lambda
,
h1lambda
,
w2lambda
,
h2lambda
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_w
=
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
align_type_value
);
T
src_h
=
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
);
PreCalculatorForLinearInterpInputIndex
(
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_w
);
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_w
);
...
@@ -489,14 +498,13 @@ __global__ void KeBicubicInterpBw(T* in,
...
@@ -489,14 +498,13 @@ __global__ void KeBicubicInterpBw(T* in,
?
static_cast
<
T
>
(
ratio_h
*
out_img_idy
)
?
static_cast
<
T
>
(
ratio_h
*
out_img_idy
)
:
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
);
:
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
);
int
input_y
=
floorf
(
in_img_idy
);
int
input_y
=
floorf
(
in_img_idy
);
const
T
y_t
=
in_img_idy
-
input_y
;
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
const
T
y_t
=
static_cast
<
T
>
(
static_cast
<
MT
>
(
in_img_idy
)
-
input_y
);
T
in_img_idx
=
align_corners
T
in_img_idx
=
align_corners
?
static_cast
<
T
>
(
ratio_w
*
out_img_idx
)
?
static_cast
<
T
>
(
ratio_w
*
out_img_idx
)
:
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
);
:
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
);
int
input_x
=
floorf
(
in_img_idx
);
int
input_x
=
floorf
(
in_img_idx
);
const
T
x_t
=
static_cast
<
T
>
(
static_cast
<
MT
>
(
in_img_idx
)
-
input_x
);
const
T
x_t
=
in_img_idx
-
input_x
;
T
x_coeffs
[
4
];
T
x_coeffs
[
4
];
T
y_coeffs
[
4
];
T
y_coeffs
[
4
];
...
@@ -543,9 +551,9 @@ __global__ void KeTrilinearInterpBw(T* in,
...
@@ -543,9 +551,9 @@ __global__ void KeTrilinearInterpBw(T* in,
const
size_t
output_h
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
size_t
num_channels
,
const
T
ratio_d
,
const
float
ratio_d
,
const
T
ratio_h
,
const
float
ratio_h
,
const
T
ratio_w
,
const
float
ratio_w
,
const
bool
align_corners
,
const
bool
align_corners
,
const
int
align_mode
,
const
int
align_mode
,
const
DataLayout
data_layout
)
{
const
DataLayout
data_layout
)
{
...
@@ -578,33 +586,37 @@ __global__ void KeTrilinearInterpBw(T* in,
...
@@ -578,33 +586,37 @@ __global__ void KeTrilinearInterpBw(T* in,
:
static_cast
<
int
>
(
ratio_d
*
out_img_idt
);
:
static_cast
<
int
>
(
ratio_d
*
out_img_idt
);
in_img_idt
=
(
in_img_idt
>
0
)
?
in_img_idt
:
0
;
in_img_idt
=
(
in_img_idt
>
0
)
?
in_img_idt
:
0
;
int
d_id
=
(
in_img_idt
<
in_img_d
-
1
)
?
1
:
0
;
int
d_id
=
(
in_img_idt
<
in_img_d
-
1
)
?
1
:
0
;
T
src_d
=
ratio_d
*
(
out_img_idt
+
0.5
)
-
0.5
;
T
src_d
=
static_cast
<
T
>
(
ratio_d
*
(
out_img_idt
+
0.5
)
-
0.5
);
src_d
=
(
src_d
>
0
)
?
src_d
:
0
;
src_d
=
(
src_d
>
static_cast
<
T
>
(
0
))
?
src_d
:
static_cast
<
T
>
(
0
);
T
d1lambda
=
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
align_flag
?
src_d
-
in_img_idt
:
ratio_d
*
out_img_idt
-
in_img_idt
;
T
d1lambda
=
align_flag
T
d2lambda
=
1.
f
-
d1lambda
;
?
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_d
)
-
in_img_idt
)
:
static_cast
<
T
>
(
ratio_d
*
out_img_idt
-
in_img_idt
);
T
d2lambda
=
static_cast
<
T
>
(
1.0
)
-
d1lambda
;
int
in_img_idy
=
align_flag
int
in_img_idy
=
align_flag
?
static_cast
<
int
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
)
?
static_cast
<
int
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_h
*
out_img_idy
);
:
static_cast
<
int
>
(
ratio_h
*
out_img_idy
);
in_img_idy
=
(
in_img_idy
>
0
)
?
in_img_idy
:
0
;
in_img_idy
=
(
in_img_idy
>
0
)
?
in_img_idy
:
0
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
src_h
=
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
;
T
src_h
=
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
);
src_h
=
(
src_h
>
0
)
?
src_h
:
0
;
src_h
=
(
src_h
>
static_cast
<
T
>
(
0
))
?
src_h
:
static_cast
<
T
>
(
0
);
T
h1lambda
=
T
h1lambda
=
align_flag
align_flag
?
src_h
-
in_img_idy
:
ratio_h
*
out_img_idy
-
in_img_idy
;
?
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_h
)
-
in_img_idy
)
T
h2lambda
=
1.
f
-
h1lambda
;
:
static_cast
<
T
>
(
ratio_h
*
out_img_idy
-
in_img_idy
);
T
h2lambda
=
static_cast
<
T
>
(
1.0
)
-
h1lambda
;
int
in_img_idx
=
align_flag
int
in_img_idx
=
align_flag
?
static_cast
<
int
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
)
?
static_cast
<
int
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_w
*
out_img_idx
);
:
static_cast
<
int
>
(
ratio_w
*
out_img_idx
);
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
;
T
src_w
=
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
);
src_w
=
(
src_w
>
0
)
?
src_w
:
0
;
src_w
=
(
src_w
>
static_cast
<
T
>
(
0
))
?
src_w
:
static_cast
<
T
>
(
0
);
T
w1lambda
=
T
w1lambda
=
align_flag
align_flag
?
src_w
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
?
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_w
)
-
in_img_idx
)
T
w2lambda
=
1.
f
-
w1lambda
;
:
static_cast
<
T
>
(
ratio_w
*
out_img_idx
-
in_img_idx
);
T
w2lambda
=
static_cast
<
T
>
(
1.0
)
-
w1lambda
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
int
in_pos1_idx
=
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
int
in_pos1_idx
=
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
...
@@ -1031,7 +1043,8 @@ static void Interpolate2DCUDABwd(
...
@@ -1031,7 +1043,8 @@ static void Interpolate2DCUDABwd(
interp_divmods
);
interp_divmods
);
}
}
}
else
if
(
"bilinear"
==
interp_method
)
{
}
else
if
(
"bilinear"
==
interp_method
)
{
const
T
align_type_value
=
(
align_mode
==
0
&&
!
align_corners
)
?
0.5
f
:
0
;
const
float
align_type_value
=
(
align_mode
==
0
&&
!
align_corners
)
?
0.5
f
:
0.
f
;
bool
is_nchw
=
(
data_layout
==
DataLayout
::
kNCHW
)
?
true
:
false
;
bool
is_nchw
=
(
data_layout
==
DataLayout
::
kNCHW
)
?
true
:
false
;
bool
optimize_flag
=
false
;
bool
optimize_flag
=
false
;
#ifndef __HIPCC__
#ifndef __HIPCC__
...
@@ -1148,7 +1161,7 @@ static void Interpolate3DCUDABwd(
...
@@ -1148,7 +1161,7 @@ static void Interpolate3DCUDABwd(
if
(
scale_tensor
)
{
if
(
scale_tensor
)
{
auto
scale_data
=
auto
scale_data
=
funcs
::
get_new_data_from_tensor
<
float
>
(
scale_tensor
.
get_ptr
());
funcs
::
get_new_data_from_tensor
<
float
>
(
scale_tensor
.
get_ptr
());
if
(
scale_data
.
size
()
>
1
)
{
if
(
scale_data
.
size
()
>
2
)
{
scale_d
=
scale_data
[
0
];
scale_d
=
scale_data
[
0
];
scale_h
=
scale_data
[
1
];
scale_h
=
scale_data
[
1
];
scale_w
=
scale_data
[
2
];
scale_w
=
scale_data
[
2
];
...
@@ -1179,7 +1192,7 @@ static void Interpolate3DCUDABwd(
...
@@ -1179,7 +1192,7 @@ static void Interpolate3DCUDABwd(
"should be greater than 0, but received value is %d."
,
"should be greater than 0, but received value is %d."
,
scale_d
));
scale_d
));
}
else
{
}
else
{
if
(
scale
.
size
()
>
1
)
{
if
(
scale
.
size
()
>
2
)
{
scale_d
=
scale
[
0
];
scale_d
=
scale
[
0
];
scale_h
=
scale
[
1
];
scale_h
=
scale
[
1
];
scale_w
=
scale
[
2
];
scale_w
=
scale
[
2
];
...
@@ -1574,7 +1587,8 @@ PD_REGISTER_KERNEL(bilinear_interp_grad,
...
@@ -1574,7 +1587,8 @@ PD_REGISTER_KERNEL(bilinear_interp_grad,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
BilinearInterpGradKernel
,
phi
::
BilinearInterpGradKernel
,
float
,
float
,
double
)
{
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
}
...
@@ -1583,7 +1597,8 @@ PD_REGISTER_KERNEL(nearest_interp_grad,
...
@@ -1583,7 +1597,8 @@ PD_REGISTER_KERNEL(nearest_interp_grad,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
NearestInterpGradKernel
,
phi
::
NearestInterpGradKernel
,
float
,
float
,
double
)
{
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
}
...
@@ -1592,7 +1607,8 @@ PD_REGISTER_KERNEL(trilinear_interp_grad,
...
@@ -1592,7 +1607,8 @@ PD_REGISTER_KERNEL(trilinear_interp_grad,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
TrilinearInterpGradKernel
,
phi
::
TrilinearInterpGradKernel
,
float
,
float
,
double
)
{
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
}
...
@@ -1601,7 +1617,8 @@ PD_REGISTER_KERNEL(linear_interp_grad,
...
@@ -1601,7 +1617,8 @@ PD_REGISTER_KERNEL(linear_interp_grad,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
LinearInterpGradKernel
,
phi
::
LinearInterpGradKernel
,
float
,
float
,
double
)
{
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
}
...
@@ -1610,7 +1627,8 @@ PD_REGISTER_KERNEL(bicubic_interp_grad,
...
@@ -1610,7 +1627,8 @@ PD_REGISTER_KERNEL(bicubic_interp_grad,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
BicubicInterpGradKernel
,
phi
::
BicubicInterpGradKernel
,
float
,
float
,
double
)
{
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
}
paddle/phi/kernels/gpu/interpolate_kernel.cu
浏览文件 @
b12c27eb
...
@@ -19,6 +19,8 @@
...
@@ -19,6 +19,8 @@
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
...
@@ -34,11 +36,12 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
...
@@ -34,11 +36,12 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
T
*
lambda2
,
T
*
lambda2
,
T
src_x
,
T
src_x
,
const
int
in_img_x
)
{
const
int
in_img_x
)
{
src_x
=
(
src_x
>
0
)
?
src_x
:
0.
f
;
src_x
=
(
src_x
>
static_cast
<
T
>
(
0
))
?
src_x
:
static_cast
<
T
>
(
0
)
;
*
in_img_idx
=
static_cast
<
int
>
(
src_x
);
*
in_img_idx
=
static_cast
<
int
>
(
src_x
);
*
x_id
=
(
*
in_img_idx
<
in_img_x
-
1
)
?
1
:
0
;
*
x_id
=
(
*
in_img_idx
<
in_img_x
-
1
)
?
1
:
0
;
*
lambda1
=
src_x
-
*
in_img_idx
;
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
*
lambda2
=
1.
f
-
*
lambda1
;
*
lambda1
=
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_x
)
-
*
in_img_idx
);
*
lambda2
=
static_cast
<
T
>
(
1.0
)
-
*
lambda1
;
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -78,12 +81,13 @@ __global__ void KeLinearInterpFw(const T* in,
...
@@ -78,12 +81,13 @@ __global__ void KeLinearInterpFw(const T* in,
:
static_cast
<
int
>
(
ratio_w
*
out_img_idx
);
:
static_cast
<
int
>
(
ratio_w
*
out_img_idx
);
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
// w
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
// w
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
// w_id
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
// w_id
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
;
T
src_w
=
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
);
src_w
=
(
src_w
>
0
)
?
src_w
:
0
;
src_w
=
(
src_w
>
static_cast
<
T
>
(
0
))
?
src_w
:
static_cast
<
T
>
(
0
);
T
w1lambda
=
T
w1lambda
=
align_flag
align_flag
?
src_w
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
?
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_w
)
-
in_img_idx
)
T
w2lambda
=
1.
f
-
w1lambda
;
:
static_cast
<
T
>
(
ratio_w
*
out_img_idx
-
in_img_idx
);
T
w2lambda
=
static_cast
<
T
>
(
1.0
)
-
w1lambda
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
const
T
*
in_pos
=
const
T
*
in_pos
=
...
@@ -203,7 +207,7 @@ __global__ void KeBilinearInterpFw(const T* in,
...
@@ -203,7 +207,7 @@ __global__ void KeBilinearInterpFw(const T* in,
const
size_t
num_channels
,
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_h
,
const
float
ratio_w
,
const
float
ratio_w
,
const
T
align_type_value
,
const
float
align_type_value
,
funcs
::
FastDivModForInterpolate
divmods
)
{
funcs
::
FastDivModForInterpolate
divmods
)
{
int
nthreads
=
output_h
*
output_w
;
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
@@ -222,8 +226,10 @@ __global__ void KeBilinearInterpFw(const T* in,
...
@@ -222,8 +226,10 @@ __global__ void KeBilinearInterpFw(const T* in,
int
in_img_idx
,
in_img_idy
,
h_id
,
w_id
;
int
in_img_idx
,
in_img_idy
,
h_id
,
w_id
;
T
h1lambda
,
w1lambda
,
h2lambda
,
w2lambda
;
T
h1lambda
,
w1lambda
,
h2lambda
,
w2lambda
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_w
=
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
align_type_value
);
T
src_h
=
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
);
PreCalculatorForLinearInterpInputIndex
(
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_img_w
);
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_img_w
);
...
@@ -254,7 +260,7 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
...
@@ -254,7 +260,7 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
const
size_t
nc
,
const
size_t
nc
,
const
float
ratio_h
,
const
float
ratio_h
,
const
float
ratio_w
,
const
float
ratio_w
,
const
T
align_type_value
)
{
const
float
align_type_value
)
{
int
out_img_idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
out_img_idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
out_img_idy
=
threadIdx
.
y
+
blockIdx
.
y
*
blockDim
.
y
;
int
out_img_idy
=
threadIdx
.
y
+
blockIdx
.
y
*
blockDim
.
y
;
int
nc_id
=
threadIdx
.
z
+
blockIdx
.
z
*
blockDim
.
z
;
int
nc_id
=
threadIdx
.
z
+
blockIdx
.
z
*
blockDim
.
z
;
...
@@ -262,8 +268,10 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
...
@@ -262,8 +268,10 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
int
in_img_idx
,
in_img_idy
,
h_id
,
w_id
;
int
in_img_idx
,
in_img_idy
,
h_id
,
w_id
;
T
h1lambda
,
w1lambda
,
h2lambda
,
w2lambda
;
T
h1lambda
,
w1lambda
,
h2lambda
,
w2lambda
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_w
=
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
align_type_value
);
T
src_h
=
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
);
PreCalculatorForLinearInterpInputIndex
(
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_img_w
);
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_img_w
);
...
@@ -296,13 +304,13 @@ template <typename T>
...
@@ -296,13 +304,13 @@ template <typename T>
__device__
__forceinline__
static
T
Kecubic_interp
(
__device__
__forceinline__
static
T
Kecubic_interp
(
const
T
x0
,
const
T
x1
,
const
T
x2
,
const
T
x3
,
T
t
)
{
const
T
x0
,
const
T
x1
,
const
T
x2
,
const
T
x3
,
T
t
)
{
T
coeffs
[
4
];
T
coeffs
[
4
];
T
a
=
-
0.75
;
T
a
=
static_cast
<
T
>
(
-
0.75
)
;
T
x_1
=
t
;
T
x_1
=
t
;
T
x_2
=
1.0
-
t
;
T
x_2
=
static_cast
<
T
>
(
1.0
)
-
t
;
coeffs
[
0
]
=
funcs
::
CubicConvolution2
<
T
>
(
x_1
+
1.0
,
a
);
coeffs
[
0
]
=
funcs
::
CubicConvolution2
<
T
>
(
x_1
+
static_cast
<
T
>
(
1.0
)
,
a
);
coeffs
[
1
]
=
funcs
::
CubicConvolution1
<
T
>
(
x_1
,
a
);
coeffs
[
1
]
=
funcs
::
CubicConvolution1
<
T
>
(
x_1
,
a
);
coeffs
[
2
]
=
funcs
::
CubicConvolution1
<
T
>
(
x_2
,
a
);
coeffs
[
2
]
=
funcs
::
CubicConvolution1
<
T
>
(
x_2
,
a
);
coeffs
[
3
]
=
funcs
::
CubicConvolution2
<
T
>
(
x_2
+
1.0
,
a
);
coeffs
[
3
]
=
funcs
::
CubicConvolution2
<
T
>
(
x_2
+
static_cast
<
T
>
(
1.0
)
,
a
);
return
x0
*
coeffs
[
0
]
+
x1
*
coeffs
[
1
]
+
x2
*
coeffs
[
2
]
+
x3
*
coeffs
[
3
];
return
x0
*
coeffs
[
0
]
+
x1
*
coeffs
[
1
]
+
x2
*
coeffs
[
2
]
+
x3
*
coeffs
[
3
];
}
}
...
@@ -348,13 +356,14 @@ __global__ void KeBicubicInterpFw(const T* in,
...
@@ -348,13 +356,14 @@ __global__ void KeBicubicInterpFw(const T* in,
?
static_cast
<
T
>
(
ratio_h
*
out_img_idy
)
?
static_cast
<
T
>
(
ratio_h
*
out_img_idy
)
:
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
);
:
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
);
int
input_y
=
floorf
(
in_img_idy
);
int
input_y
=
floorf
(
in_img_idy
);
const
T
y_t
=
in_img_idy
-
input_y
;
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
const
T
y_t
=
static_cast
<
T
>
(
static_cast
<
MT
>
(
in_img_idy
)
-
input_y
);
T
in_img_idx
=
align_corners
T
in_img_idx
=
align_corners
?
static_cast
<
T
>
(
ratio_w
*
out_img_idx
)
?
static_cast
<
T
>
(
ratio_w
*
out_img_idx
)
:
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
);
:
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
);
int
input_x
=
floorf
(
in_img_idx
);
int
input_x
=
floorf
(
in_img_idx
);
const
T
x_t
=
in_img_idx
-
input_x
;
const
T
x_t
=
static_cast
<
T
>
(
static_cast
<
MT
>
(
in_img_idx
)
-
input_x
)
;
T
coefficients
[
4
];
T
coefficients
[
4
];
const
T
*
in_pos_0
;
const
T
*
in_pos_0
;
...
@@ -419,16 +428,15 @@ __global__ void KeBicubicInterpFw(const T* in,
...
@@ -419,16 +428,15 @@ __global__ void KeBicubicInterpFw(const T* in,
&
in
[
out_id_h
*
input_w
+
access_y
*
in_img_w
*
num_channels
+
&
in
[
out_id_h
*
input_w
+
access_y
*
in_img_w
*
num_channels
+
access_x_3
*
num_channels
+
channel_id
];
access_x_3
*
num_channels
+
channel_id
];
coefficients
[
k
]
=
Kecubic_interp
(
coefficients
[
k
]
=
Kecubic_interp
<
T
>
(
in_pos_0
[
0
],
in_pos_1
[
0
],
in_pos_2
[
0
],
in_pos_3
[
0
],
x_t
);
in_pos_0
[
0
],
in_pos_1
[
0
],
in_pos_2
[
0
],
in_pos_3
[
0
],
x_t
);
}
}
out
[
out_id_h
*
output_w
+
out_id_w
]
=
out
[
out_id_h
*
output_w
+
out_id_w
]
=
Kecubic_interp
<
T
>
(
coefficients
[
0
],
static_cast
<
T
>
(
Kecubic_interp
(
coefficients
[
0
],
coefficients
[
1
],
coefficients
[
1
],
coefficients
[
2
],
coefficients
[
2
],
coefficients
[
3
],
coefficients
[
3
],
y_t
);
y_t
));
}
}
}
}
}
}
...
@@ -482,33 +490,37 @@ __global__ void KeTrilinearInterpFw(const T* in,
...
@@ -482,33 +490,37 @@ __global__ void KeTrilinearInterpFw(const T* in,
:
static_cast
<
int
>
(
ratio_d
*
out_img_idt
);
:
static_cast
<
int
>
(
ratio_d
*
out_img_idt
);
in_img_idt
=
(
in_img_idt
>
0
)
?
in_img_idt
:
0
;
in_img_idt
=
(
in_img_idt
>
0
)
?
in_img_idt
:
0
;
int
d_id
=
(
in_img_idt
<
in_img_d
-
1
)
?
1
:
0
;
int
d_id
=
(
in_img_idt
<
in_img_d
-
1
)
?
1
:
0
;
T
src_d
=
ratio_d
*
(
out_img_idt
+
0.5
)
-
0.5
;
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
src_d
=
(
src_d
>
0
)
?
src_d
:
0
;
T
src_d
=
static_cast
<
T
>
(
ratio_d
*
(
out_img_idt
+
0.5
)
-
0.5
);
T
d1lambda
=
src_d
=
(
src_d
>
static_cast
<
T
>
(
0
))
?
src_d
:
static_cast
<
T
>
(
0
);
align_flag
?
src_d
-
in_img_idt
:
ratio_d
*
out_img_idt
-
in_img_idt
;
T
d1lambda
=
align_flag
T
d2lambda
=
1.
f
-
d1lambda
;
?
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_d
)
-
in_img_idt
)
:
static_cast
<
T
>
(
ratio_d
*
out_img_idt
-
in_img_idt
);
T
d2lambda
=
static_cast
<
T
>
(
1.0
)
-
d1lambda
;
int
in_img_idy
=
align_flag
int
in_img_idy
=
align_flag
?
static_cast
<
int
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
)
?
static_cast
<
int
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_h
*
out_img_idy
);
:
static_cast
<
int
>
(
ratio_h
*
out_img_idy
);
in_img_idy
=
(
in_img_idy
>
0
)
?
in_img_idy
:
0
;
in_img_idy
=
(
in_img_idy
>
0
)
?
in_img_idy
:
0
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
src_h
=
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
;
T
src_h
=
static_cast
<
T
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
);
src_h
=
(
src_h
>
0
)
?
src_h
:
0
;
src_h
=
(
src_h
>
static_cast
<
T
>
(
0
))
?
src_h
:
static_cast
<
T
>
(
0
);
T
h1lambda
=
T
h1lambda
=
align_flag
align_flag
?
src_h
-
in_img_idy
:
ratio_h
*
out_img_idy
-
in_img_idy
;
?
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_h
)
-
in_img_idy
)
T
h2lambda
=
1.
f
-
h1lambda
;
:
static_cast
<
T
>
(
ratio_h
*
out_img_idy
-
in_img_idy
);
T
h2lambda
=
static_cast
<
T
>
(
1.0
)
-
h1lambda
;
int
in_img_idx
=
align_flag
int
in_img_idx
=
align_flag
?
static_cast
<
int
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
)
?
static_cast
<
int
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_w
*
out_img_idx
);
:
static_cast
<
int
>
(
ratio_w
*
out_img_idx
);
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
;
T
src_w
=
static_cast
<
T
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
);
src_w
=
(
src_w
>
0
)
?
src_w
:
0
;
src_w
=
(
src_w
>
static_cast
<
T
>
(
0
))
?
src_w
:
static_cast
<
T
>
(
0
);
T
w1lambda
=
T
w1lambda
=
align_flag
align_flag
?
src_w
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
?
static_cast
<
T
>
(
static_cast
<
MT
>
(
src_w
)
-
in_img_idx
)
T
w2lambda
=
1.
f
-
w1lambda
;
:
static_cast
<
T
>
(
ratio_w
*
out_img_idx
-
in_img_idx
);
T
w2lambda
=
static_cast
<
T
>
(
1.0
)
-
w1lambda
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
int
in_pos1_idx
=
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
int
in_pos1_idx
=
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
...
@@ -926,7 +938,8 @@ static void Interpolate2DCUDAFwd(
...
@@ -926,7 +938,8 @@ static void Interpolate2DCUDAFwd(
thread_num
=
512
;
thread_num
=
512
;
}
}
#endif
#endif
const
T
align_type_value
=
(
align_mode
==
0
&&
!
align_corners
)
?
0.5
f
:
0
;
const
float
align_type_value
=
(
align_mode
==
0
&&
!
align_corners
)
?
0.5
f
:
0.
f
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
// get launch 3D config
// get launch 3D config
int
nc
=
n
*
c
;
int
nc
=
n
*
c
;
...
@@ -1028,7 +1041,7 @@ static void Interpolate3DCUDAFwd(
...
@@ -1028,7 +1041,7 @@ static void Interpolate3DCUDAFwd(
if
(
scale_tensor
)
{
if
(
scale_tensor
)
{
auto
scale_data
=
auto
scale_data
=
funcs
::
get_new_data_from_tensor
<
float
>
(
scale_tensor
.
get_ptr
());
funcs
::
get_new_data_from_tensor
<
float
>
(
scale_tensor
.
get_ptr
());
if
(
scale_data
.
size
()
>
1
)
{
if
(
scale_data
.
size
()
>
2
)
{
scale_d
=
scale_data
[
0
];
scale_d
=
scale_data
[
0
];
scale_h
=
scale_data
[
1
];
scale_h
=
scale_data
[
1
];
scale_w
=
scale_data
[
2
];
scale_w
=
scale_data
[
2
];
...
@@ -1060,7 +1073,7 @@ static void Interpolate3DCUDAFwd(
...
@@ -1060,7 +1073,7 @@ static void Interpolate3DCUDAFwd(
"should be greater than 0, but received value is %d."
,
"should be greater than 0, but received value is %d."
,
scale_d
));
scale_d
));
}
else
{
}
else
{
if
(
scale
.
size
()
>
1
)
{
if
(
scale
.
size
()
>
2
)
{
scale_d
=
scale
[
0
];
scale_d
=
scale
[
0
];
scale_h
=
scale
[
1
];
scale_h
=
scale
[
1
];
scale_w
=
scale
[
2
];
scale_w
=
scale
[
2
];
...
@@ -1446,6 +1459,7 @@ PD_REGISTER_KERNEL(bilinear_interp,
...
@@ -1446,6 +1459,7 @@ PD_REGISTER_KERNEL(bilinear_interp,
phi
::
BilinearInterpKernel
,
phi
::
BilinearInterpKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
,
int
)
{
int
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
...
@@ -1456,6 +1470,7 @@ PD_REGISTER_KERNEL(nearest_interp,
...
@@ -1456,6 +1470,7 @@ PD_REGISTER_KERNEL(nearest_interp,
phi
::
NearestInterpKernel
,
phi
::
NearestInterpKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
,
int
,
int
,
int64_t
)
{
int64_t
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
...
@@ -1467,6 +1482,7 @@ PD_REGISTER_KERNEL(trilinear_interp,
...
@@ -1467,6 +1482,7 @@ PD_REGISTER_KERNEL(trilinear_interp,
phi
::
TrilinearInterpKernel
,
phi
::
TrilinearInterpKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
,
int
)
{
int
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
...
@@ -1477,6 +1493,7 @@ PD_REGISTER_KERNEL(linear_interp,
...
@@ -1477,6 +1493,7 @@ PD_REGISTER_KERNEL(linear_interp,
phi
::
LinearInterpKernel
,
phi
::
LinearInterpKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
,
int
)
{
int
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
...
@@ -1487,6 +1504,7 @@ PD_REGISTER_KERNEL(bicubic_interp,
...
@@ -1487,6 +1504,7 @@ PD_REGISTER_KERNEL(bicubic_interp,
phi
::
BicubicInterpKernel
,
phi
::
BicubicInterpKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
,
int
)
{
int
)
{
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
2
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
...
...
python/paddle/fluid/tests/unittests/test_bicubic_interp_v2_op.py
浏览文件 @
b12c27eb
...
@@ -622,6 +622,44 @@ class TestBicubicOpError(unittest.TestCase):
...
@@ -622,6 +622,44 @@ class TestBicubicOpError(unittest.TestCase):
self
.
test_imperative_errors
()
self
.
test_imperative_errors
()
@
unittest
.
skipIf
(
not
fluid
.
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestBicubicInterpOpForFloat16
(
unittest
.
TestCase
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bicubic'
self
.
input_shape
=
[
2
,
3
,
5
,
5
]
self
.
out_size
=
np
.
array
([
3
,
3
]).
astype
(
"int32"
)
self
.
align_corners
=
True
self
.
data_layout
=
'NCHW'
def
check_main
(
self
,
x_np
,
dtype
):
paddle
.
disable_static
()
x_np
=
x_np
.
astype
(
dtype
)
x
=
paddle
.
to_tensor
(
x_np
)
x
.
stop_gradient
=
False
y
=
interpolate
(
x
,
size
=
self
.
out_size
.
tolist
(),
mode
=
self
.
interp_method
,
align_corners
=
self
.
align_corners
,
data_format
=
self
.
data_layout
)
x_g
=
paddle
.
grad
(
y
,
x
)
y_np
=
y
[
0
].
numpy
().
astype
(
'float32'
)
x_g_np
=
x_g
[
0
].
numpy
().
astype
(
'float32'
)
paddle
.
enable_static
()
return
y_np
,
x_g_np
def
test_main
(
self
):
self
.
init_test_case
()
x_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float16"
)
y_np_1
,
x_g_np_1
=
self
.
check_main
(
x_np
,
'float16'
)
y_np_2
,
x_g_np_2
=
self
.
check_main
(
x_np
,
'float32'
)
np
.
testing
.
assert_allclose
(
y_np_1
,
y_np_2
)
np
.
testing
.
assert_allclose
(
x_g_np_1
,
x_g_np_2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
paddle
.
enable_static
()
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_bilinear_interp_v2_op.py
浏览文件 @
b12c27eb
...
@@ -766,5 +766,45 @@ class TestBilinearInterpOpAPI_dy4(unittest.TestCase):
...
@@ -766,5 +766,45 @@ class TestBilinearInterpOpAPI_dy4(unittest.TestCase):
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
expect_res
,
rtol
=
1e-05
)
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
expect_res
,
rtol
=
1e-05
)
@
unittest
.
skipIf
(
not
fluid
.
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestBilinearInterpOpForFloat16
(
unittest
.
TestCase
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
2
,
3
,
5
,
5
]
self
.
out_size
=
np
.
array
([
3
,
3
]).
astype
(
"int32"
)
self
.
align_corners
=
True
self
.
align_mode
=
1
self
.
data_layout
=
'NCHW'
def
check_main
(
self
,
x_np
,
dtype
):
paddle
.
disable_static
()
x_np
=
x_np
.
astype
(
dtype
)
x
=
paddle
.
to_tensor
(
x_np
)
x
.
stop_gradient
=
False
y
=
interpolate
(
x
,
size
=
self
.
out_size
.
tolist
(),
mode
=
self
.
interp_method
,
align_mode
=
self
.
align_mode
,
align_corners
=
self
.
align_corners
,
data_format
=
self
.
data_layout
)
x_g
=
paddle
.
grad
(
y
,
x
)
y_np
=
y
[
0
].
numpy
().
astype
(
'float32'
)
x_g_np
=
x_g
[
0
].
numpy
().
astype
(
'float32'
)
paddle
.
enable_static
()
return
y_np
,
x_g_np
def
test_main
(
self
):
self
.
init_test_case
()
x_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float16"
)
y_np_1
,
x_g_np_1
=
self
.
check_main
(
x_np
,
'float16'
)
y_np_2
,
x_g_np_2
=
self
.
check_main
(
x_np
,
'float32'
)
np
.
testing
.
assert_allclose
(
y_np_1
,
y_np_2
)
np
.
testing
.
assert_allclose
(
x_g_np_1
,
x_g_np_2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_linear_interp_v2_op.py
浏览文件 @
b12c27eb
...
@@ -376,9 +376,7 @@ class TestLinearInterpOpAPI2_0(unittest.TestCase):
...
@@ -376,9 +376,7 @@ class TestLinearInterpOpAPI2_0(unittest.TestCase):
# dygraph
# dygraph
x_data
=
np
.
random
.
random
((
1
,
3
,
128
)).
astype
(
"float32"
)
x_data
=
np
.
random
.
random
((
1
,
3
,
128
)).
astype
(
"float32"
)
us_1
=
paddle
.
nn
.
Upsample
(
size
=
[
us_1
=
paddle
.
nn
.
Upsample
(
size
=
[
64
],
64
,
],
mode
=
'linear'
,
mode
=
'linear'
,
align_mode
=
1
,
align_mode
=
1
,
align_corners
=
False
,
align_corners
=
False
,
...
@@ -493,28 +491,21 @@ class TestLinearInterpOpError(unittest.TestCase):
...
@@ -493,28 +491,21 @@ class TestLinearInterpOpError(unittest.TestCase):
def
input_shape_error
():
def
input_shape_error
():
x1
=
fluid
.
data
(
name
=
"x1"
,
shape
=
[
1
],
dtype
=
"float32"
)
x1
=
fluid
.
data
(
name
=
"x1"
,
shape
=
[
1
],
dtype
=
"float32"
)
out1
=
paddle
.
nn
.
Upsample
(
size
=
[
out1
=
paddle
.
nn
.
Upsample
(
size
=
[
256
],
256
,
],
data_format
=
'NCW'
,
data_format
=
'NCW'
,
mode
=
'linear'
)
mode
=
'linear'
)
out1_res
=
out1
(
x1
)
out1_res
=
out1
(
x1
)
def
data_format_error
():
def
data_format_error
():
x2
=
fluid
.
data
(
name
=
"x2"
,
shape
=
[
1
,
3
,
128
],
dtype
=
"float32"
)
x2
=
fluid
.
data
(
name
=
"x2"
,
shape
=
[
1
,
3
,
128
],
dtype
=
"float32"
)
out2
=
paddle
.
nn
.
Upsample
(
size
=
[
out2
=
paddle
.
nn
.
Upsample
(
size
=
[
256
],
256
,
],
data_format
=
'NHWCD'
,
data_format
=
'NHWCD'
,
mode
=
'linear'
)
mode
=
'linear'
)
out2_res
=
out2
(
x2
)
out2_res
=
out2
(
x2
)
def
out_shape_error
():
def
out_shape_error
():
x3
=
fluid
.
data
(
name
=
"x3"
,
shape
=
[
1
,
3
,
128
],
dtype
=
"float32"
)
x3
=
fluid
.
data
(
name
=
"x3"
,
shape
=
[
1
,
3
,
128
],
dtype
=
"float32"
)
out3
=
paddle
.
nn
.
Upsample
(
size
=
[
out3
=
paddle
.
nn
.
Upsample
(
size
=
[
256
,
256
],
256
,
256
,
],
data_format
=
'NHWC'
,
data_format
=
'NHWC'
,
mode
=
'linear'
)
mode
=
'linear'
)
out3_res
=
out3
(
x3
)
out3_res
=
out3
(
x3
)
...
@@ -524,5 +515,46 @@ class TestLinearInterpOpError(unittest.TestCase):
...
@@ -524,5 +515,46 @@ class TestLinearInterpOpError(unittest.TestCase):
self
.
assertRaises
(
ValueError
,
out_shape_error
)
self
.
assertRaises
(
ValueError
,
out_shape_error
)
@
unittest
.
skipIf
(
not
fluid
.
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestLinearInterpOpForFloat16
(
unittest
.
TestCase
):
def
init_test_case
(
self
):
self
.
interp_method
=
'linear'
self
.
input_shape
=
[
1
,
3
,
64
]
self
.
scale
=
2
self
.
align_corners
=
False
self
.
align_mode
=
1
self
.
data_layout
=
'NCW'
def
check_main
(
self
,
x_np
,
dtype
):
paddle
.
disable_static
()
x_np
=
x_np
.
astype
(
dtype
)
x
=
paddle
.
to_tensor
(
x_np
)
x
.
stop_gradient
=
False
y
=
interpolate
(
x
,
scale_factor
=
self
.
scale
,
mode
=
self
.
interp_method
,
align_mode
=
self
.
align_mode
,
align_corners
=
self
.
align_corners
,
data_format
=
self
.
data_layout
)
x_g
=
paddle
.
grad
(
y
,
x
)
y_np
=
y
[
0
].
numpy
().
astype
(
'float32'
)
x_g_np
=
x_g
[
0
].
numpy
().
astype
(
'float32'
)
paddle
.
enable_static
()
return
y_np
,
x_g_np
def
test_main
(
self
):
self
.
init_test_case
()
x_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float16"
)
y_np_1
,
x_g_np_1
=
self
.
check_main
(
x_np
,
'float16'
)
y_np_2
,
x_g_np_2
=
self
.
check_main
(
x_np
,
'float32'
)
# forward
np
.
testing
.
assert_allclose
(
y_np_1
,
y_np_2
,
rtol
=
1e-03
)
# backward
np
.
testing
.
assert_allclose
(
x_g_np_1
,
x_g_np_2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py
浏览文件 @
b12c27eb
#
Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -802,5 +802,81 @@ class TestNearestInterpException(unittest.TestCase):
...
@@ -802,5 +802,81 @@ class TestNearestInterpException(unittest.TestCase):
self
.
assertRaises
(
ValueError
,
mode_error
)
self
.
assertRaises
(
ValueError
,
mode_error
)
@
unittest
.
skipIf
(
not
fluid
.
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestNearestInterp3DOpForFloat16
(
unittest
.
TestCase
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
2
,
2
,
6
,
6
,
6
]
self
.
scale
=
[
2
,
2
,
2
]
self
.
align_corners
=
False
self
.
data_layout
=
'NCDHW'
def
check_main
(
self
,
x_np
,
dtype
):
paddle
.
disable_static
()
x_np
=
x_np
.
astype
(
dtype
)
x
=
paddle
.
to_tensor
(
x_np
)
x
.
stop_gradient
=
False
y
=
interpolate
(
x
,
scale_factor
=
self
.
scale
,
mode
=
self
.
interp_method
,
align_corners
=
self
.
align_corners
,
data_format
=
self
.
data_layout
)
x_g
=
paddle
.
grad
(
y
,
x
)
y_np
=
y
[
0
].
numpy
().
astype
(
'float32'
)
x_g_np
=
x_g
[
0
].
numpy
().
astype
(
'float32'
)
paddle
.
enable_static
()
return
y_np
,
x_g_np
def
test_main
(
self
):
self
.
init_test_case
()
x_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float16"
)
y_np_1
,
x_g_np_1
=
self
.
check_main
(
x_np
,
'float16'
)
y_np_2
,
x_g_np_2
=
self
.
check_main
(
x_np
,
'float32'
)
# forward
np
.
testing
.
assert_allclose
(
y_np_1
,
y_np_2
,
rtol
=
1e-03
)
# backward
np
.
testing
.
assert_allclose
(
x_g_np_1
,
x_g_np_2
)
@
unittest
.
skipIf
(
not
fluid
.
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestNearestInterpOpForFloat16
(
unittest
.
TestCase
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
2
,
2
,
6
,
6
]
self
.
scale
=
[
2
,
2
]
self
.
align_corners
=
False
def
check_main
(
self
,
x_np
,
dtype
):
paddle
.
disable_static
()
x_np
=
x_np
.
astype
(
dtype
)
x
=
paddle
.
to_tensor
(
x_np
)
x
.
stop_gradient
=
False
y
=
interpolate
(
x
,
scale_factor
=
self
.
scale
,
mode
=
self
.
interp_method
,
align_corners
=
self
.
align_corners
)
x_g
=
paddle
.
grad
(
y
,
x
)
y_np
=
y
[
0
].
numpy
().
astype
(
'float32'
)
x_g_np
=
x_g
[
0
].
numpy
().
astype
(
'float32'
)
paddle
.
enable_static
()
return
y_np
,
x_g_np
def
test_main
(
self
):
self
.
init_test_case
()
x_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float16"
)
y_np_1
,
x_g_np_1
=
self
.
check_main
(
x_np
,
'float16'
)
y_np_2
,
x_g_np_2
=
self
.
check_main
(
x_np
,
'float32'
)
# forward
np
.
testing
.
assert_allclose
(
y_np_1
,
y_np_2
)
# backward
np
.
testing
.
assert_allclose
(
x_g_np_1
,
x_g_np_2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_trilinear_interp_v2_op.py
浏览文件 @
b12c27eb
...
@@ -154,15 +154,15 @@ def trilinear_interp_np(input,
...
@@ -154,15 +154,15 @@ def trilinear_interp_np(input,
out
[:,
:,
i
,
j
,
k
]
=
\
out
[:,
:,
i
,
j
,
k
]
=
\
d2lambda
*
\
d2lambda
*
\
(
h2lambda
*
(
w2lambda
*
input
[:,
:,
d
,
h
,
w
]
+
\
(
h2lambda
*
(
w2lambda
*
input
[:,
:,
d
,
h
,
w
]
+
w1lambda
*
input
[:,
:,
d
,
h
,
w
+
wid
])
+
\
w1lambda
*
input
[:,
:,
d
,
h
,
w
+
wid
])
+
h1lambda
*
(
w2lambda
*
input
[:,
:,
d
,
h
+
hid
,
w
]
+
\
h1lambda
*
(
w2lambda
*
input
[:,
:,
d
,
h
+
hid
,
w
]
+
w1lambda
*
input
[:,
:,
d
,
h
+
hid
,
w
+
wid
]))
+
\
w1lambda
*
input
[:,
:,
d
,
h
+
hid
,
w
+
wid
]))
+
\
d1lambda
*
\
d1lambda
*
\
(
h2lambda
*
(
w2lambda
*
input
[:,
:,
d
+
did
,
h
,
w
]
+
\
(
h2lambda
*
(
w2lambda
*
input
[:,
:,
d
+
did
,
h
,
w
]
+
w1lambda
*
input
[:,
:,
d
+
did
,
h
,
w
+
wid
])
+
\
w1lambda
*
input
[:,
:,
d
+
did
,
h
,
w
+
wid
])
+
h1lambda
*
(
w2lambda
*
input
[:,
:,
d
+
did
,
h
+
hid
,
w
]
+
\
h1lambda
*
(
w2lambda
*
input
[:,
:,
d
+
did
,
h
+
hid
,
w
]
+
w1lambda
*
input
[:,
:,
d
+
did
,
h
+
hid
,
w
+
wid
]))
w1lambda
*
input
[:,
:,
d
+
did
,
h
+
hid
,
w
+
wid
]))
if
data_layout
==
"NDHWC"
:
if
data_layout
==
"NDHWC"
:
out
=
np
.
transpose
(
out
,
(
0
,
2
,
3
,
4
,
1
))
# NCDHW => NDHWC
out
=
np
.
transpose
(
out
,
(
0
,
2
,
3
,
4
,
1
))
# NCDHW => NDHWC
...
@@ -809,5 +809,59 @@ class TestTrilinearInterpOpException(unittest.TestCase):
...
@@ -809,5 +809,59 @@ class TestTrilinearInterpOpException(unittest.TestCase):
self
.
assertRaises
(
ValueError
,
attr_data_format
)
self
.
assertRaises
(
ValueError
,
attr_data_format
)
@
unittest
.
skipIf
(
not
fluid
.
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestTrilinearInterpOpForFloat16
(
unittest
.
TestCase
):
def
init_test_case
(
self
):
self
.
interp_method
=
'trilinear'
self
.
input_shape
=
[
2
,
3
,
4
,
4
,
4
]
self
.
out_size
=
np
.
array
([
3
,
3
,
3
]).
astype
(
"int32"
)
self
.
align_corners
=
True
self
.
align_mode
=
1
self
.
data_layout
=
'NCDHW'
def
check_main
(
self
,
x_np
,
dtype
):
paddle
.
disable_static
()
x_np
=
x_np
.
astype
(
dtype
)
x
=
paddle
.
to_tensor
(
x_np
)
x
.
stop_gradient
=
False
y
=
interpolate
(
x
,
size
=
self
.
out_size
.
tolist
(),
mode
=
self
.
interp_method
,
align_corners
=
self
.
align_corners
,
align_mode
=
self
.
align_mode
,
data_format
=
self
.
data_layout
)
x_g
=
paddle
.
grad
(
y
,
x
)
y_np
=
y
[
0
].
numpy
().
astype
(
'float32'
)
x_g_np
=
x_g
[
0
].
numpy
().
astype
(
'float32'
)
paddle
.
enable_static
()
return
y_np
,
x_g_np
def
test_main
(
self
):
self
.
init_test_case
()
x_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float16"
)
y_np_1
,
x_g_np_1
=
self
.
check_main
(
x_np
,
'float16'
)
y_np_2
,
x_g_np_2
=
self
.
check_main
(
x_np
,
'float32'
)
# forward
np
.
testing
.
assert_allclose
(
y_np_1
,
y_np_2
,
rtol
=
1e-03
)
# backward
np
.
testing
.
assert_allclose
(
x_g_np_1
,
x_g_np_2
,
rtol
=
1e-05
)
@
unittest
.
skipIf
(
not
fluid
.
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestTrilinearInterpDatalayoutForFloat16
(
TestTrilinearInterpOpForFloat16
):
def
init_test_case
(
self
):
self
.
interp_method
=
'trilinear'
self
.
input_shape
=
[
2
,
4
,
4
,
4
,
3
]
self
.
out_size
=
np
.
array
([
3
,
3
,
3
]).
astype
(
"int32"
)
self
.
align_corners
=
True
self
.
align_mode
=
1
self
.
data_layout
=
"NDHWC"
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录