Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4acc87be
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4acc87be
编写于
4月 01, 2021
作者:
Z
Zhang Zheng
提交者:
GitHub
4月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize the perf of SameDimsAdd CUDA Kernel (#31872)
上级
980227f9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
125 addition
and
55 deletion
+125
-55
paddle/fluid/operators/elementwise/elementwise_add_op.cu
paddle/fluid/operators/elementwise/elementwise_add_op.cu
+62
-26
paddle/fluid/operators/elementwise/elementwise_div_op.cu
paddle/fluid/operators/elementwise/elementwise_div_op.cu
+1
-1
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+1
-1
paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
.../fluid/operators/elementwise/elementwise_op_function.cu.h
+60
-26
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+1
-1
未找到文件。
paddle/fluid/operators/elementwise/elementwise_add_op.cu
浏览文件 @
4acc87be
...
...
@@ -24,7 +24,10 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseAdd
<
platform
::
CUDADeviceContext
,
T
>
{
struct
SameDimsElemwiseAdd
<
platform
::
CUDADeviceContext
,
T
,
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
platform
::
float16
>::
value
&&
!
std
::
is_same
<
T
,
float
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
...
...
@@ -36,16 +39,28 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
}
};
template
<
>
struct
SameDimsElemwiseAdd
<
platform
::
CUDADeviceContext
,
platform
::
float16
>
{
template
<
typename
T
>
struct
SameDimsElemwiseAdd
<
platform
::
CUDADeviceContext
,
T
,
typename
std
::
enable_if
<
std
::
is_same
<
T
,
platform
::
float16
>::
value
||
std
::
is_same
<
T
,
float
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
(((
size
+
1
)
/
2
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
int
vec_size
=
sizeof
(
float4
)
/
sizeof
(
T
);
dim3
grid_size
=
dim3
(((
size
+
vec_size
-
1
)
/
vec_size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
SameDimsElemwiseAddCUDAKernel
<<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>()
.
stream
()
>>>
(
x
->
data
<
float
>
(),
y
->
data
<
float
>
(),
z
->
data
<
float
>
(),
size
);
}
else
{
const
half
*
x2
=
reinterpret_cast
<
const
half
*>
(
x
->
data
<
platform
::
float16
>
());
const
half
*
y2
=
...
...
@@ -53,21 +68,39 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, platform::float16> {
half
*
z2
=
reinterpret_cast
<
half
*>
(
z
->
data
<
platform
::
float16
>
());
SameDimsElemwiseAddCUDAKernel
<<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
()
>>>
(
x2
,
y2
,
z2
,
size
);
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>()
.
stream
()
>>>
(
x2
,
y2
,
z2
,
size
);
}
}
};
template
<
typename
T
>
static
__global__
void
SimpleElemwiseAddGradCUDAKernel
(
const
T
*
dout
,
int64_t
size
,
T
*
dx
,
T
*
dy
)
{
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
static
__global__
void
SimpleElemwiseAddGradCUDAKernel
(
const
T
*
__restrict__
dout
,
int
size
,
int
vec_size
,
T
*
dx
,
T
*
dy
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
int
loop
=
size
/
vec_size
;
int
remainder
=
size
%
vec_size
;
const
float4
*
dout_vec
=
reinterpret_cast
<
const
float4
*>
(
dout
);
float4
*
dx_vec
=
reinterpret_cast
<
float4
*>
(
dx
);
float4
*
dy_vec
=
reinterpret_cast
<
float4
*>
(
dy
);
float4
tmp_loop
;
for
(
int
i
=
tid
;
i
<
loop
;
i
+=
stride
)
{
tmp_loop
=
dout_vec
[
i
];
dx_vec
[
i
]
=
tmp_loop
;
dy_vec
[
i
]
=
tmp_loop
;
}
while
(
col
<
size
)
{
dx
[
col
]
=
dout
[
col
];
dy
[
col
]
=
dout
[
col
];
col
+=
blockDim
.
x
*
gridDim
.
x
;
if
(
tid
==
loop
&&
remainder
!=
0
)
{
T
tmp_rem
;
while
(
remainder
)
{
int
idx
=
size
-
remainder
;
remainder
--
;
tmp_rem
=
dout
[
idx
];
dx
[
idx
]
=
tmp_rem
;
dy
[
idx
]
=
tmp_rem
;
}
}
}
...
...
@@ -79,14 +112,17 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
auto
size
=
x
->
numel
();
int
vec_size
=
max
(
static_cast
<
int
>
(
sizeof
(
float4
)
/
sizeof
(
T
)),
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
grid_size
=
dim3
((
size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
(((
size
+
vec_size
-
1
)
/
vec_size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
SimpleElemwiseAddGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
dout
->
data
<
T
>
(),
size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dout
->
data
<
T
>
(),
size
,
vec_size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
...
...
paddle/fluid/operators/elementwise/elementwise_div_op.cu
浏览文件 @
4acc87be
...
...
@@ -43,7 +43,7 @@ struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
(((
size
+
1
)
/
2
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
dim3
grid_size
=
dim3
(((
size
+
7
)
/
8
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
浏览文件 @
4acc87be
...
...
@@ -43,7 +43,7 @@ struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
(((
size
+
1
)
/
2
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
dim3
grid_size
=
dim3
(((
size
+
7
)
/
8
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
...
...
paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
浏览文件 @
4acc87be
...
...
@@ -18,7 +18,11 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#ifdef __HIPCC__
#define PADDLE_CUDA_THREAD_SIZE 256
#else
#define PADDLE_CUDA_THREAD_SIZE 512
#endif
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
...
...
@@ -159,30 +163,60 @@ inline DEVICE half2 half2_div(const half2& a, const half2& b) {
}
#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function) \
template <typename T> \
__global__ void SameDimsElemwise##Func##CUDAKernel(const T* x, const T* y, \
T* z, int64_t size) { \
int col = blockIdx.x * blockDim.x + threadIdx.x; \
while (col < size) { \
z[col] = x[col] expr y[col]; \
col += blockDim.x * gridDim.x; \
inline __global__ void SameDimsElemwise##Func##CUDAKernel( \
const float* __restrict__ x, const float* __restrict__ y, float* z, \
int64_t size) { \
int tid = blockIdx.x * blockDim.x + threadIdx.x; \
int stride = gridDim.x * blockDim.x; \
int loop = size / 4; \
int remainder = size % 4; \
const float4* x_vec = reinterpret_cast<const float4*>(x); \
const float4* y_vec = reinterpret_cast<const float4*>(y); \
float4* z_vec = reinterpret_cast<float4*>(z); \
float4 x_f4, y_f4; \
for (int i = tid; i < loop; i += stride) { \
x_f4 = x_vec[i]; \
y_f4 = y_vec[i]; \
z_vec[i] = make_float4(x_f4.x expr y_f4.x, x_f4.y expr y_f4.y, \
x_f4.z expr y_f4.z, x_f4.w expr y_f4.w); \
} \
if (tid == loop && remainder != 0) { \
while (remainder) { \
int idx = size - remainder; \
remainder--; \
z[idx] = x[idx] expr y[idx]; \
} \
} \
} \
inline __global__ void SameDimsElemwise##Func##CUDAKernel( \
const half* __restrict__ x, const half* __restrict__ y, half* z, \
int64_t size) { \
int tid = blockIdx.x * blockDim.x + threadIdx.x; \
int stride = gridDim.x * blockDim.x; \
int loop = size / 8; \
int remainder = size % 8; \
const float4* x_vec = reinterpret_cast<const float4*>(x); \
const float4* y_vec = reinterpret_cast<const float4*>(y); \
float4* z_vec = reinterpret_cast<float4*>(z); \
float4 x_h8, y_h8, z_h8; \
for (int i = tid; i < loop; i += stride) { \
x_h8 = x_vec[i]; \
y_h8 = y_vec[i]; \
half2* x_h2 = reinterpret_cast<half2*>(&x_h8); \
half2* y_h2 = reinterpret_cast<half2*>(&y_h8); \
half2* z_h2 = reinterpret_cast<half2*>(&z_h8); \
z_h2[0] = FP16Function(x_h2[0], y_h2[0]); \
z_h2[1] = FP16Function(x_h2[1], y_h2[1]); \
z_h2[2] = FP16Function(x_h2[2], y_h2[2]); \
z_h2[3] = FP16Function(x_h2[3], y_h2[3]); \
z_vec[i] = z_h8; \
} \
template <> \
inline __global__ void SameDimsElemwise##Func##CUDAKernel<half>( \
const half* x, const half* y, half* z, int64_t size) { \
int start = threadIdx.x + blockDim.x * blockIdx.x; \
int stride = blockDim.x * gridDim.x; \
int n2 = size / 2; \
const half2* x2 = reinterpret_cast<const half2*>(x); \
const half2* y2 = reinterpret_cast<const half2*>(y); \
half2* z2 = reinterpret_cast<half2*>(z); \
for (int i = start; i < n2; i += stride) { \
z2[i] = FP16Function(x2[i], y2[i]); \
if (tid == loop && remainder != 0) { \
while (remainder) { \
int idx = size - remainder; \
remainder--; \
z[idx] = __float2half(__half2float(x[idx]) expr __half2float(y[idx])); \
} \
if (start == 0 && (size % 2)) { \
z[size - 1] = __float2half(__half2float(x[size - 1]) \
expr __half2float(y[size - 1])); \
} \
}
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Add
,
+
,
half2_add
)
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
浏览文件 @
4acc87be
...
...
@@ -43,7 +43,7 @@ struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
(((
size
+
1
)
/
2
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
dim3
grid_size
=
dim3
(((
size
+
7
)
/
8
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录