Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a1174973
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a1174973
编写于
2月 11, 2022
作者:
L
Lijunhui
提交者:
GitHub
2月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize bilinear interpolation foward (#39243)
* bilinear_fw init * optimize code * pre-compute linear_interp input index
上级
c86765ed
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
118 addition
and
90 deletion
+118
-90
paddle/fluid/operators/interpolate_v2_op.cu
paddle/fluid/operators/interpolate_v2_op.cu
+118
-90
未找到文件。
paddle/fluid/operators/interpolate_v2_op.cu
浏览文件 @
a1174973
...
...
@@ -59,6 +59,17 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
return
config
;
}
template
<
typename
T
>
__forceinline__
__device__
void
PreCalculatorForLinearInterpInputIndex
(
int
*
in_img_idx
,
int
*
w_id
,
T
*
w1lambda
,
T
*
w2lambda
,
T
src_w
,
const
int
in_img_w
)
{
src_w
=
(
src_w
>
0
)
?
src_w
:
0.
f
;
*
in_img_idx
=
static_cast
<
int
>
(
src_w
);
*
w_id
=
(
*
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
*
w1lambda
=
src_w
-
*
in_img_idx
;
*
w2lambda
=
1.
f
-
*
w1lambda
;
}
struct
FastDivModForInterpolate
{
public:
FastDivMod
channels_div
;
...
...
@@ -417,96 +428,93 @@ __global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
}
template
<
typename
T
>
__global__
void
KeBilinearInterpFw
(
const
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_w
,
const
bool
align_corners
,
const
int
align_mode
,
const
DataLayout
data_layout
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
bool
align_flag
=
(
align_mode
==
0
&&
!
align_corners
);
for
(;
tid
<
nthreads
;
tid
+=
stride
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
int
out_img_size
=
output_w
/
num_channels
;
__global__
void
KeBilinearInterpNCHWFw
(
const
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
nc
,
const
float
ratio_h
,
const
float
ratio_w
,
const
T
align_type_value
)
{
int
out_img_idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
out_img_idy
=
threadIdx
.
y
+
blockIdx
.
y
*
blockDim
.
y
;
int
nc_id
=
threadIdx
.
z
+
blockIdx
.
z
*
blockDim
.
z
;
int
nc_stride
=
blockDim
.
z
*
gridDim
.
z
;
int
channel_id
,
out_img_idy
,
out_img_idx
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
channel_id
=
out_id_w
/
out_img_size
;
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
out_img_idx
=
tid
%
out_img_w
;
}
else
{
out_img_idy
=
out_id_w
/
(
out_img_w
*
num_channels
);
out_img_idx
=
out_id_w
%
(
out_img_w
*
num_channels
)
/
num_channels
;
channel_id
=
tid
%
num_channels
;
}
int
in_img_idx
,
in_img_idy
,
h_id
,
w_id
;
T
h1lambda
,
w1lambda
,
h2lambda
,
w2lambda
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
int
in_img_idy
=
align_flag
?
static_cast
<
int
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_h
*
out_img_idy
);
in_img_idy
=
(
in_img_idy
>
0
)
?
in_img_idy
:
0
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
src_h
=
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
;
src_h
=
(
src_h
>
0
)
?
src_h
:
0
;
T
h1lambda
=
align_flag
?
src_h
-
in_img_idy
:
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_img_w
);
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idy
,
&
h_id
,
&
h1lambda
,
&
h2lambda
,
src_h
,
in_img_h
);
int
in_img_idx
=
align_flag
?
static_cast
<
int
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_w
*
out_img_idx
);
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
;
src_w
=
(
src_w
>
0
)
?
src_w
:
0
;
T
w1lambda
=
align_flag
?
src_w
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
int
in_index
=
(
nc_id
*
in_img_h
+
in_img_idy
)
*
in_img_w
+
in_img_idx
;
int
in_index_stride
=
nc_stride
*
in_img_h
*
in_img_w
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
const
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
int
out_index
=
(
nc_id
*
out_img_h
+
out_img_idy
)
*
out_img_w
+
out_img_idx
;
int
out_index_stride
=
nc_stride
*
out_img_h
*
out_img_w
;
// bilinear interpolation
out
[
out_id_h
*
output_w
+
out_id_w
]
=
// prevent from multiple threads writing
if
(
out_img_idx
<
out_img_w
&&
out_img_idy
<
out_img_h
)
{
while
(
nc_id
<
nc
)
{
const
T
*
in_pos
=
&
in
[
in_index
];
out
[
out_index
]
=
h2lambda
*
(
w2lambda
*
in_pos
[
0
]
+
w1lambda
*
in_pos
[
w_id
])
+
h1lambda
*
(
w2lambda
*
in_pos
[
h_id
*
in_img_w
]
+
w1lambda
*
in_pos
[
h_id
*
in_img_w
+
w_id
]);
}
else
{
const
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
in_img_idy
*
in_img_w
*
num_channels
+
in_img_idx
*
num_channels
+
channel_id
];
// bilinear interpolation
out
[
out_id_h
*
output_w
+
out_id_w
]
=
h2lambda
*
(
w2lambda
*
in_pos
[
0
]
+
w1lambda
*
in_pos
[
w_id
*
num_channels
])
+
h1lambda
*
(
w2lambda
*
in_pos
[
h_id
*
in_img_w
*
num_channels
]
+
w1lambda
*
in_pos
[
h_id
*
in_img_w
*
num_channels
+
w_id
*
num_channels
]);
in_index
+=
in_index_stride
;
out_index
+=
out_index_stride
;
nc_id
+=
nc_stride
;
}
}
}
template
<
typename
T
>
__forceinline__
__device__
void
PreCalculatorForInputIndex
(
int
*
in_img_idx
,
int
*
in_img_idy
,
int
*
w_id
,
int
*
h_id
,
T
*
w1lambda
,
T
*
h1lambda
,
T
*
w2lambda
,
T
*
h2lambda
,
T
src_w
,
T
src_h
,
const
int
in_img_w
,
const
int
in_img_h
)
{
src_w
=
(
src_w
>
0
)
?
src_w
:
0.
f
;
src_h
=
(
src_h
>
0
)
?
src_h
:
0.
f
;
*
in_img_idx
=
static_cast
<
int
>
(
src_w
);
*
in_img_idy
=
static_cast
<
int
>
(
src_h
);
*
w_id
=
(
*
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
*
h_id
=
(
*
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
*
w1lambda
=
src_w
-
*
in_img_idx
;
*
h1lambda
=
src_h
-
*
in_img_idy
;
*
w2lambda
=
1.
f
-
*
w1lambda
;
*
h2lambda
=
1.
f
-
*
h1lambda
;
__global__
void
KeBilinearInterpFw
(
const
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_w
,
const
T
align_type_value
,
FastDivModForInterpolate
divmods
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(;
tid
<
nthreads
;
tid
+=
stride
)
{
auto
out_id_divmod
=
divmods
.
output_w_div
.
Divmod
(
tid
);
int
out_id_h
=
out_id_divmod
.
val
[
0
];
int
out_id_w
=
out_id_divmod
.
val
[
1
];
int
channel_id
=
divmods
.
channels_div
.
Divmod
(
tid
).
val
[
1
];
auto
outimg_id_divmod
=
divmods
.
output_wc_div
.
Divmod
(
out_id_w
);
int
out_img_idy
=
outimg_id_divmod
.
val
[
0
];
int
out_img_idx
=
divmods
.
channels_div
.
Divmod
(
outimg_id_divmod
.
val
[
1
]).
val
[
0
];
int
in_img_idx
,
in_img_idy
,
h_id
,
w_id
;
T
h1lambda
,
w1lambda
,
h2lambda
,
w2lambda
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_img_w
);
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idy
,
&
h_id
,
&
h1lambda
,
&
h2lambda
,
src_h
,
in_img_h
);
// bilinear interpolation
const
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
in_img_idy
*
in_img_w
*
num_channels
+
in_img_idx
*
num_channels
+
channel_id
];
out
[
tid
]
=
h2lambda
*
(
w2lambda
*
in_pos
[
0
]
+
w1lambda
*
in_pos
[
w_id
*
num_channels
])
+
h1lambda
*
(
w2lambda
*
in_pos
[
h_id
*
in_img_w
*
num_channels
]
+
w1lambda
*
in_pos
[
h_id
*
in_img_w
*
num_channels
+
w_id
*
num_channels
]);
}
}
/* Calculate the minimum of partial elements in a block */
...
...
@@ -574,9 +582,11 @@ __global__ void KeBilinearInterpBwShareMemory(
T
w1lambda
,
h1lambda
,
w2lambda
,
h2lambda
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
PreCalculatorForInputIndex
(
&
in_img_idx
,
&
in_img_idy
,
&
w_id
,
&
h_id
,
&
w1lambda
,
&
h1lambda
,
&
w2lambda
,
&
h2lambda
,
src_w
,
src_h
,
in_w
,
in_h
);
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_w
);
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idy
,
&
h_id
,
&
h1lambda
,
&
h2lambda
,
src_h
,
in_h
);
// top_left_index is just input_index.
int
input_index
=
out_id_h
*
in_chw
+
channel_id
*
in_img_size
+
...
...
@@ -661,9 +671,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
PreCalculatorForInputIndex
(
&
in_img_idx
,
&
in_img_idy
,
&
w_id
,
&
h_id
,
&
w1lambda
,
&
h1lambda
,
&
w2lambda
,
&
h2lambda
,
src_w
,
src_h
,
in_w
,
in_h
);
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_w
);
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idy
,
&
h_id
,
&
h1lambda
,
&
h2lambda
,
src_h
,
in_h
);
T
*
in_pos
=
&
in
[
out_id_h
*
in_chw
+
channel_id
*
in_img_size
+
in_img_idy
*
in_w
+
in_img_idx
];
...
...
@@ -690,9 +702,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
T
w1lambda
,
h1lambda
,
w2lambda
,
h2lambda
;
T
src_w
=
ratio_w
*
(
out_img_idx
+
align_type_value
)
-
align_type_value
;
T
src_h
=
ratio_h
*
(
out_img_idy
+
align_type_value
)
-
align_type_value
;
PreCalculatorForInputIndex
(
&
in_img_idx
,
&
in_img_idy
,
&
w_id
,
&
h_id
,
&
w1lambda
,
&
h1lambda
,
&
w2lambda
,
&
h2lambda
,
src_w
,
src_h
,
in_w
,
in_h
);
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idx
,
&
w_id
,
&
w1lambda
,
&
w2lambda
,
src_w
,
in_w
);
PreCalculatorForLinearInterpInputIndex
(
&
in_img_idy
,
&
h_id
,
&
h1lambda
,
&
h2lambda
,
src_h
,
in_h
);
T
*
in_pos
=
&
in
[
out_id_h
*
in_chw
+
in_img_idy
*
in_w
*
num_channels
+
in_img_idx
*
num_channels
+
channel_id
];
...
...
@@ -1398,11 +1412,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
thread_num
=
512
;
}
#endif
KeBilinearInterpFw
<
T
><<<
config
.
block_per_grid
,
thread_num
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
,
align_corners
,
align_mode
,
data_layout
);
const
T
align_type_value
=
(
align_mode
==
0
&&
!
align_corners
)
?
0.5
f
:
0
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
// get launch 3D config
int
nc
=
n
*
c
;
platform
::
GpuLaunchConfig
config_3d
=
GetGpuLaunchConfig3D
(
ctx
.
cuda_device_context
(),
nc
,
out_h
,
out_w
);
KeBilinearInterpNCHWFw
<
T
><<<
config_3d
.
block_per_grid
,
config_3d
.
thread_per_block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
in_h
,
in_w
,
output_data
,
out_h
,
out_w
,
nc
,
ratio_h
,
ratio_w
,
align_type_value
);
}
else
{
int64_t
cw
=
c
*
out_w
;
auto
interp_divmods
=
FastDivModForInterpolate
(
c
,
out_chw
,
cw
);
KeBilinearInterpFw
<
T
><<<
config
.
block_per_grid
,
thread_num
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
,
align_type_value
,
interp_divmods
);
}
}
else
if
(
"bicubic"
==
interp_method
)
{
#ifdef __HIPCC__
constexpr
int
thread_per_block
=
256
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录