Commit d8afe407 (unverified)
Authored on Apr 12, 2021 by limingshu; committed by GitHub on Apr 12, 2021
Optimization of bilinear backward OP CUDA kernel. (#30950)
Parent: af374ae6
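
This change replaces the single bilinear backward kernel with four pieces: an index/weight pre-computation helper (PreCalculatorForInputIndex), a partial-block minimum reduction (PartialBlockMin), a shared-memory kernel (KeBilinearInterpBwShareMemory), and a simplified general kernel (KeBilinearInterpBw). The launcher picks the shared-memory kernel only for NCHW layout when the input gradient is much smaller than the output gradient (in_h < out_h >> 6 and in_w < out_w >> 6, or a 1x1 input), i.e. when many output pixels scatter-add into the same few input pixels and contended global atomics dominate the cost.

The sketch below is not the Paddle kernel; it is a minimal, hypothetical illustration of the staging pattern the new shared-memory kernel relies on: accumulate each block's scatter-adds into a shared-memory tile first, then flush the tile to global memory with one atomicAdd per touched element. The names (scatter_add_staged, block_min, TILE) are made up for the example, and it assumes blockDim.x is a power of two no larger than 256, n is a multiple of blockDim.x, and the destination indices touched by any one block fit inside a TILE-sized window.

// staged_scatter_add.cu -- illustrative only, not part of the commit.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int TILE = 1024;  // per-block shared staging buffer

// Block-wide minimum via a simple shared-memory tree reduction.
// Assumes blockDim.x is a power of two and <= 256.
__device__ int block_min(int val) {
  __shared__ int s_min[256];
  s_min[threadIdx.x] = val;
  __syncthreads();
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (threadIdx.x < offset) {
      s_min[threadIdx.x] = min(s_min[threadIdx.x], s_min[threadIdx.x + offset]);
    }
    __syncthreads();
  }
  return s_min[0];
}

// Each thread adds src[tid] into dst[dst_idx[tid]]. Contributions are first
// accumulated in shared memory, so only one global atomicAdd is issued per
// touched destination element instead of one per thread.
__global__ void scatter_add_staged(float* dst, const float* src,
                                   const int* dst_idx, int n) {
  __shared__ float s_tile[TILE];
  for (int i = threadIdx.x; i < TILE; i += blockDim.x) s_tile[i] = 0.f;
  __syncthreads();

  int tid = blockIdx.x * blockDim.x + threadIdx.x;  // assumes n % blockDim.x == 0
  int idx = dst_idx[tid];
  int window_base = block_min(idx);  // start of this block's destination window

  // Shared-memory atomics are much cheaper than contended global ones.
  atomicAdd(&s_tile[idx - window_base], src[tid]);
  __syncthreads();

  // Flush the staged partial sums with one global atomicAdd per element.
  for (int i = threadIdx.x; i < TILE; i += blockDim.x) {
    if (s_tile[i] != 0.f) atomicAdd(&dst[window_base + i], s_tile[i]);
  }
}

int main() {
  const int n = 4096, block = 256, m = 64;  // 4096 sources fold into 64 slots
  float *src, *dst;
  int* dst_idx;
  cudaMallocManaged(&src, n * sizeof(float));
  cudaMallocManaged(&dst, m * sizeof(float));
  cudaMallocManaged(&dst_idx, n * sizeof(int));
  for (int i = 0; i < n; ++i) { src[i] = 1.f; dst_idx[i] = i % m; }
  for (int i = 0; i < m; ++i) dst[i] = 0.f;

  scatter_add_staged<<<n / block, block>>>(dst, src, dst_idx, n);
  cudaDeviceSynchronize();
  printf("dst[0] = %.0f (expected %d)\n", dst[0], n / m);

  cudaFree(src); cudaFree(dst); cudaFree(dst_idx);
  return 0;
}

The real kernel in the diff applies the same idea with two staging rows (s_data[0] for the top input row, s_data[1] for the bottom row) and uses warp/block reductions to locate each block's index window.
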
Showing 1 changed file with 218 additions and 70 deletions

paddle/fluid/operators/interpolate_v2_op.cu  +218  -70
@@ -12,6 +12,8 @@
 #include <algorithm>
 #include <string>
 #include "paddle/fluid/operators/interpolate_v2_op.h"
+#include "paddle/fluid/operators/math/math_cuda_utils.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/gpu_launch_config.h"
@@ -302,81 +304,214 @@ __global__ void KeBilinearInterpFw(
   }
 }
 
 template <typename T>
-__global__ void KeBilinearInterpBw(
-    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
-    const size_t input_w, const T* out, const size_t out_img_h,
-    const size_t out_img_w, const size_t output_h, const size_t output_w,
-    const size_t num_channels, const T ratio_h, const T ratio_w,
-    const bool align_corners, const int align_mode,
-    const DataLayout data_layout) {
-  int nthreads = output_h * output_w;
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int stride = blockDim.x * gridDim.x;
-  bool align_flag = (align_mode == 0 && !align_corners);
-  for (; tid < nthreads; tid += stride) {
-    int out_id_h = tid / output_w;
-    int out_id_w = tid % output_w;
-    int in_img_size = input_w / num_channels;
-    int out_img_size = output_w / num_channels;
-
-    int channel_id, out_img_idy, out_img_idx;
-    if (data_layout == DataLayout::kNCHW) {
-      channel_id = out_id_w / out_img_size;
-      out_img_idy = (out_id_w % out_img_size) / out_img_w;
-      out_img_idx = tid % out_img_w;
-    } else {
-      out_img_idy = out_id_w / (out_img_w * num_channels);
-      out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
-      channel_id = tid % num_channels;
-    }
-
-    int in_img_idy = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5
-                                : ratio_h * out_img_idy;
-    in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
-    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
-    T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
-    src_h = (src_h > 0) ? src_h : 0;
-    T h1lambda =
-        align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
-    T h2lambda = 1.f - h1lambda;
-
-    int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5
-                                : ratio_w * out_img_idx;
-    in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
-    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
-    T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
-    src_w = (src_w > 0) ? src_w : 0;
-    T w1lambda =
-        align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
-    T w2lambda = 1.f - w1lambda;
-
-    T* in_pos;
-    if (data_layout == DataLayout::kNCHW) {
-      in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
-                   in_img_idy * in_img_w + in_img_idx];
-    } else {
-      in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
-                   in_img_idx * num_channels + channel_id];
-    }
-
-    const T* out_pos = &out[out_id_h * output_w + out_id_w];
-
-    if (data_layout == DataLayout::kNCHW) {
-      platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
-      platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_img_w],
-                              h1lambda * w2lambda * out_pos[0]);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id],
-                              h1lambda * w1lambda * out_pos[0]);
-    } else {
-      platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
-      platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
-                              h2lambda * w1lambda * out_pos[0]);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_img_w * num_channels],
-                              h1lambda * w2lambda * out_pos[0]);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_img_w * num_channels +
-                                      w_id * num_channels],
-                              h1lambda * w1lambda * out_pos[0]);
-    }
-  }
-}
+__forceinline__ __device__ void PreCalculatorForInputIndex(
+    int* in_img_idx, int* in_img_idy, int* w_id, int* h_id, T* w1lambda,
+    T* h1lambda, T* w2lambda, T* h2lambda, T src_w, T src_h,
+    const int in_img_w, const int in_img_h) {
+  src_w = (src_w > 0) ? src_w : 0.f;
+  src_h = (src_h > 0) ? src_h : 0.f;
+  *in_img_idx = static_cast<int>(src_w);
+  *in_img_idy = static_cast<int>(src_h);
+  *w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
+  *h_id = (*in_img_idy < in_img_h - 1) ? 1 : 0;
+  *w1lambda = src_w - *in_img_idx;
+  *h1lambda = src_h - *in_img_idy;
+  *w2lambda = 1.f - *w1lambda;
+  *h2lambda = 1.f - *h1lambda;
+}
+
+/* Calculate the minimum of partial elements in a block */
+template <typename T>
+__inline__ __device__ T PartialBlockMin(T val, size_t threads_num_in_block,
+                                        unsigned mask) {
+  __shared__ T shared[WARP_SIZE];
+  __shared__ T shared_last_val;
+  __shared__ int shared_last_idx;
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+  int threshold = (threads_num_in_block & (-WARP_SIZE));
+
+  if (threadIdx.x < threshold) {
+    shared_last_idx = (threshold >> 5) - 1;
+    val = math::warpReduceMin(val, mask);
+    if (lane == 0) {
+      shared[wid] = val;
+    }
+  } else {
+    shared_last_val = std::numeric_limits<T>::max();
+    platform::CudaAtomicMin(&shared_last_val, val);
+    shared[wid] = shared_last_val;
+    shared_last_idx = wid;
+  }
+  __syncthreads();
+
+  if (threadIdx.x < threshold) {
+    val = (lane <= shared_last_idx) ? shared[lane]
+                                    : std::numeric_limits<T>::max();
+    val = math::warpReduceMin(val, mask);
+    shared_last_val = val;
+  }
+  __syncthreads();
+  if (threadIdx.x >= threshold) {
+    val = shared_last_val;
+  }
+  return val;
+}
+
+template <typename T>
+__global__ void KeBilinearInterpBwShareMemory(
+    T* in, const int in_h, const int in_w, const T* __restrict__ out,
+    const int out_h, const int out_w, const int n, const int num_channels,
+    float ratio_h, float ratio_w, const T align_type_value, bool is_nchw) {
+  __shared__ T s_data[2][1024];
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int in_chw = in_h * in_w * num_channels;
+  int out_chw = num_channels * out_h * out_w;
+  int nthreads = n * out_chw;
+
+  for (; tid < nthreads; tid += stride) {
+    int out_id_h = tid / out_chw;
+    int out_id_w = tid % out_chw;
+    const int in_img_size = in_h * in_w;
+    const int out_img_size = out_h * out_w;
+    T value = out[out_id_h * out_chw + out_id_w];
+
+    int channel_id = out_id_w / out_img_size;
+    int out_img_idy = (out_id_w % out_img_size) / out_w;
+    int out_img_idx = tid % out_w;
+
+    int in_img_idx, in_img_idy, w_id, h_id;
+    T w1lambda, h1lambda, w2lambda, h2lambda;
+    T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
+    T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
+    PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
+                               &w1lambda, &h1lambda, &w2lambda, &h2lambda,
+                               src_w, src_h, in_w, in_h);
+
+    // top_left_index is just input_index.
+    int input_index = out_id_h * in_chw + channel_id * in_img_size +
+                      in_img_idy * in_w + in_img_idx;
+    int top_right_index = input_index + w_id;
+    int bot_left_index = input_index + h_id * in_w;
+    int bot_right_index = input_index + h_id * in_w + w_id;
+    int in_top_min_index, in_bot_min_index;
+
+    s_data[0][threadIdx.x] = 0.f;
+    s_data[1][threadIdx.x] = 0.f;
+    int remain = nthreads - (tid & (-blockDim.x));
+    int in_top_max_index = math::blockReduceMax(top_right_index, FINAL_MASK);
+    int in_bot_max_index = math::blockReduceMax(bot_right_index, FINAL_MASK);
+
+    if (remain > blockDim.x) {
+      in_top_min_index = math::blockReduceMin(input_index, FINAL_MASK);
+      in_bot_min_index = math::blockReduceMin(bot_left_index, FINAL_MASK);
+    } else {
+      in_top_min_index = PartialBlockMin(input_index, remain, FINAL_MASK);
+      in_bot_min_index = PartialBlockMin(bot_left_index, remain, FINAL_MASK);
+    }
+    int upper_limit_share_idx = (in_top_max_index - in_top_min_index) >
+                                        (in_bot_max_index - in_bot_min_index)
+                                    ? (in_top_max_index - in_top_min_index)
+                                    : (in_bot_max_index - in_bot_min_index);
+    if (h_id != 0) {
+      platform::CudaAtomicAdd(&s_data[0][input_index - in_top_min_index],
+                              h2lambda * w2lambda * value);
+      platform::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index],
+                              h2lambda * w1lambda * value);
+      platform::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index],
+                              h1lambda * w2lambda * value);
+      platform::CudaAtomicAdd(&s_data[1][bot_right_index - in_bot_min_index],
+                              h1lambda * w1lambda * value);
+    } else {
+      platform::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index],
+                              (h2lambda + h1lambda) * w1lambda * value);
+      platform::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index],
+                              (h1lambda + h2lambda) * w2lambda * value);
+    }
+    __syncthreads();
+
+    if (threadIdx.x <= upper_limit_share_idx) {
+      platform::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x],
+                              s_data[0][threadIdx.x]);
+      platform::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x],
+                              s_data[1][threadIdx.x]);
+    }
+  }
+}
+
+template <typename T>
+__global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
+                                   const T* __restrict__ out, const int out_h,
+                                   const int out_w, const int n,
+                                   const int num_channels, float ratio_h,
+                                   float ratio_w, const T align_type_value,
+                                   bool is_nchw) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int in_chw = in_h * in_w * num_channels;
+  int out_chw = num_channels * out_h * out_w;
+  int nthreads = n * out_chw;
+
+  if (is_nchw) {
+    for (; tid < nthreads; tid += stride) {
+      int out_id_h = tid / out_chw;
+      int out_id_w = tid % out_chw;
+      const int in_img_size = in_h * in_w;
+      const int out_img_size = out_h * out_w;
+      T value = out[out_id_h * out_chw + out_id_w];
+
+      int channel_id = out_id_w / out_img_size;
+      int out_img_idy = (out_id_w % out_img_size) / out_w;
+      int out_img_idx = tid % out_w;
+
+      int in_img_idx, in_img_idy, w_id, h_id;
+      T w1lambda, h1lambda, w2lambda, h2lambda;
+      T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
+      T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
+      PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
+                                 &w1lambda, &h1lambda, &w2lambda, &h2lambda,
+                                 src_w, src_h, in_w, in_h);
+
+      T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
+                      in_img_idy * in_w + in_img_idx];
+      platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
+      platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * value);
+      platform::CudaAtomicAdd(&in_pos[h_id * in_w],
+                              h1lambda * w2lambda * value);
+      platform::CudaAtomicAdd(&in_pos[h_id * in_w + w_id],
+                              h1lambda * w1lambda * value);
+    }
+  } else {
+    for (; tid < nthreads; tid += stride) {
+      int out_id_h = tid / out_chw;
+      int out_id_w = tid % out_chw;
+      const int in_img_size = in_h * in_w;
+      const int out_img_size = out_h * out_w;
+      T value = out[out_id_h * out_chw + out_id_w];
+
+      int out_img_idy = out_id_w / (out_w * num_channels);
+      int out_img_idx = out_id_w % (out_w * num_channels) / num_channels;
+      int channel_id = tid % num_channels;
+
+      int in_img_idx, in_img_idy, w_id, h_id;
+      T w1lambda, h1lambda, w2lambda, h2lambda;
+      T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
+      T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
+      PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
+                                 &w1lambda, &h1lambda, &w2lambda, &h2lambda,
+                                 src_w, src_h, in_w, in_h);
+
+      T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
+                      in_img_idx * num_channels + channel_id];
+      platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
+      platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
+                              h2lambda * w1lambda * value);
+      platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
+                              h1lambda * w2lambda * value);
+      platform::CudaAtomicAdd(
+          &in_pos[h_id * in_w * num_channels + w_id * num_channels],
+          h1lambda * w1lambda * value);
+    }
+  }
+}
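
In the shared-memory path above, each thread handles one element of the output gradient and owes four weighted contributions (h1/h2lambda times w1/w2lambda) to the input gradient. Rather than issuing four global atomic adds per thread, the block first computes the span of input indices it touches on the top and bottom interpolation rows, using math::blockReduceMin/Max for full blocks and PartialBlockMin for the trailing partial block, accumulates the contributions into the s_data staging buffer with shared-memory atomics, and after a __syncthreads flushes each staged element back to global memory with a single CudaAtomicAdd. The h_id == 0 branch covers samples on the last input row, where the top and bottom contributions land on the same elements and are folded together.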
@@ -1373,7 +1508,6 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
   int out_hw = out_h * out_w;
   int in_chw = c * in_hw;
   int out_chw = c * out_hw;
   int pixelNum = n * out_chw;
   platform::GpuLaunchConfig config =
@@ -1386,11 +1520,25 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
         input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
         n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
   } else if ("bilinear" == interp_method) {
-    KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
-                            ctx.cuda_device_context().stream()>>>(
-        input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
-        n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
-        data_layout);
+    const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
+    bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false;
+    bool optimize_flag = false;
+    optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6))
+                        ? true
+                        : ((in_h == 1 && in_w == 1) ? true : false);
+
+    if (optimize_flag & is_nchw) {
+      KeBilinearInterpBwShareMemory<
+          T><<<config.block_per_grid, config.thread_per_block, 0,
+               ctx.cuda_device_context().stream()>>>(
+          input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
+          ratio_h, ratio_w, align_type_value, is_nchw);
+    } else {
+      KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block,
+                              0, ctx.cuda_device_context().stream()>>>(
+          input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
+          ratio_h, ratio_w, align_type_value, is_nchw);
+    }
   } else if ("bicubic" == interp_method) {
     KeBicubicInterpBw<T><<<config.block_per_grid, 512, 0,
                            ctx.cuda_device_context().stream()>>>(
...