Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e77d1cac
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e77d1cac
编写于
12月 16, 2022
作者:
MarDino
提交者:
GitHub
12月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize bias_add reluv2 in half2 (#49048)
* optimize bias_add reluv2 in half2 * Add annotation * refine code format
上级
a5ce60b8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
88 addition
and
35 deletion
+88
-35
paddle/phi/kernels/funcs/fc_functor.cu
paddle/phi/kernels/funcs/fc_functor.cu
+88
-35
未找到文件。
paddle/phi/kernels/funcs/fc_functor.cu
浏览文件 @
e77d1cac
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <algorithm>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
...
@@ -127,37 +128,54 @@ void AddReluKernel(
...
@@ -127,37 +128,54 @@ void AddReluKernel(
}
}
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA)
template
<
bool
DoRelu
>
template
<
bool
DoRelu
,
int
Half2VecSize
>
__global__
void
bias_relu_v2
(
const
int
num
,
__global__
void
bias_relu_v4_half2
(
const
int
num
,
const
half2
*
bias
,
const
half2
*
bias
,
half2
*
data
,
half2
*
data
,
int
K
)
{
int
K
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
using
LoadT
=
phi
::
AlignedVector
<
half2
,
Half2VecSize
>
;
LoadT
data_vec
;
LoadT
bias_vec
;
const
int32_t
global_thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int32_t
grid_stride
=
gridDim
.
x
*
blockDim
.
x
;
if
(
tid
<
num
)
{
for
(
int32_t
linear_idx
=
global_thread_idx
*
Half2VecSize
;
linear_idx
<
num
;
int
bias_idx
=
tid
%
K
;
linear_idx
+=
grid_stride
*
Half2VecSize
)
{
const
half2
bias_ptr
=
bias
[
bias_idx
];
phi
::
Load
<
half2
,
Half2VecSize
>
(
&
data
[
linear_idx
],
&
data_vec
);
const
half2
in_ptr
=
data
[
tid
];
const
int
bias_idx
=
linear_idx
%
K
;
half2
packed_val
;
phi
::
Load
<
half2
,
Half2VecSize
>
(
&
bias
[
bias_idx
],
&
bias_vec
);
#pragma unroll
for
(
int
unroll_idx
=
0
;
unroll_idx
<
Half2VecSize
;
unroll_idx
++
)
{
// Do biasAdd
#if __CUDA_ARCH__ >= 530
#if __CUDA_ARCH__ >= 530
packed_val
=
__hadd2
(
bias_ptr
,
in_ptr
);
data_vec
[
unroll_idx
]
=
__hadd2
(
data_vec
[
unroll_idx
],
bias_vec
[
unroll_idx
]);
#else
#else
packed_val
.
x
=
__hadd
(
bias_ptr
.
x
,
in_ptr
.
x
);
data_vec
[
unroll_idx
].
x
=
packed_val
.
y
=
__hadd
(
bias_ptr
.
y
,
in_ptr
.
y
);
__hadd
(
data_vec
[
unroll_idx
].
x
,
bias_vec
[
unroll_idx
].
x
);
data_vec
[
unroll_idx
].
y
=
__hadd
(
data_vec
[
unroll_idx
].
y
,
bias_vec
[
unroll_idx
].
y
);
#endif
#endif
if
(
DoRelu
)
{
// Do relu
if
(
DoRelu
)
{
#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800
packed_val
=
__hmax2
(
__half2
(
0
,
0
),
packed_val
);
data_vec
[
unroll_idx
]
=
__hmax2
(
__half2
(
0
,
0
),
data_vec
[
unroll_idx
]
);
#elif __CUDA_ARCH__ >= 530
#elif __CUDA_ARCH__ >= 530
packed_val
=
__hmul2
(
__hgt2
(
__half2
(
0
,
0
),
packed_val
),
packed_val
);
data_vec
[
unroll_idx
]
=
__hmul2
(
__hgt2
(
__half2
(
0
,
0
),
data_vec
[
unroll_idx
]),
data_vec
[
unroll_idx
]);
#else
#else
packed_val
.
x
=
static_cast
<
int
>
(
static_cast
<
float
>
(
packed_val
.
x
)
>
0
)
*
data_vec
[
unroll_idx
].
x
=
static_cast
<
float
>
(
packed_val
.
x
);
static_cast
<
int
>
(
static_cast
<
float
>
(
data_vec
[
unroll_idx
].
x
)
>
0
)
*
packed_val
.
y
=
static_cast
<
int
>
(
static_cast
<
float
>
(
packed_val
.
y
)
>
0
)
*
static_cast
<
float
>
(
data_vec
[
unroll_idx
].
x
);
static_cast
<
float
>
(
packed_val
.
y
);
data_vec
[
unroll_idx
].
y
=
static_cast
<
int
>
(
static_cast
<
float
>
(
data_vec
[
unroll_idx
].
y
)
>
0
)
*
static_cast
<
float
>
(
data_vec
[
unroll_idx
].
y
);
#endif
#endif
}
}
}
data
[
tid
]
=
packed_val
;
phi
::
Store
<
half2
,
Half2VecSize
>
(
data_vec
,
&
data
[
linear_idx
])
;
}
}
}
}
...
@@ -188,6 +206,53 @@ __global__ void InplaceAddReluKernel(const int N,
...
@@ -188,6 +206,53 @@ __global__ void InplaceAddReluKernel(const int N,
}
}
}
}
/**
* brief: Launch BiasAddReluKernel with relu or not.
**/
template
<
int
Half2VecSize
>
void
LaunchBiasAddReluHalf2Kernel
(
cudaStream_t
stream
,
const
int32_t
rows
,
const
int32_t
cols
,
float16
*
Y
,
const
float16
*
B
,
bool
relu
)
{
const
int
threads
=
256
;
const
int
vec_num
=
rows
*
cols
/
(
Half2VecSize
*
2
);
const
int
half2_num
=
rows
*
cols
/
2
;
const
int
blocks
=
(
vec_num
+
threads
-
1
)
/
threads
;
// Here reinterpret_cast to half2 type.
typedef
typename
FcTypeTraits
<
float16
>::
Type
trans_type
;
auto
*
bias_half2_ptr
=
reinterpret_cast
<
const
trans_type
*>
(
B
);
auto
*
data_half2_ptr
=
reinterpret_cast
<
trans_type
*>
(
Y
);
if
(
relu
)
{
bias_relu_v4_half2
<
true
,
Half2VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
half2_num
,
bias_half2_ptr
,
data_half2_ptr
,
cols
/
2
);
}
else
{
bias_relu_v4_half2
<
false
,
Half2VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
half2_num
,
bias_half2_ptr
,
data_half2_ptr
,
cols
/
2
);
}
}
/**
* brief: Dispatch BiasAddReluKernel half2 type with 8 / 4 / 2 vecsize.
**/
void
DispatchBiasAddReluKernelHalf2VecSize
(
cudaStream_t
stream
,
const
int32_t
rows
,
const
int32_t
cols
,
float16
*
Y
,
const
float16
*
B
,
bool
relu
)
{
// Half Max Vecsize is 128 / 16 = 8, since we use half2 type, here
// Half2VecSize need divide 2.
if
(
cols
%
8
==
0
)
{
LaunchBiasAddReluHalf2Kernel
<
4
>
(
stream
,
rows
,
cols
,
Y
,
B
,
relu
);
}
else
if
(
cols
%
4
==
0
)
{
LaunchBiasAddReluHalf2Kernel
<
2
>
(
stream
,
rows
,
cols
,
Y
,
B
,
relu
);
}
else
{
LaunchBiasAddReluHalf2Kernel
<
1
>
(
stream
,
rows
,
cols
,
Y
,
B
,
relu
);
}
}
template
<
>
template
<
>
void
AddReluKernel
(
cudaStream_t
stream
,
void
AddReluKernel
(
cudaStream_t
stream
,
const
int
M
,
const
int
M
,
...
@@ -196,19 +261,7 @@ void AddReluKernel(cudaStream_t stream,
...
@@ -196,19 +261,7 @@ void AddReluKernel(cudaStream_t stream,
const
float16
*
B
,
const
float16
*
B
,
bool
relu
)
{
bool
relu
)
{
if
(
N
%
2
==
0
)
{
if
(
N
%
2
==
0
)
{
const
int
threads
=
256
;
DispatchBiasAddReluKernelHalf2VecSize
(
stream
,
M
,
N
,
Y
,
B
,
relu
);
const
int
num
=
M
*
N
/
2
;
const
int
blocks
=
(
num
+
threads
-
1
)
/
threads
;
typedef
typename
FcTypeTraits
<
float16
>::
Type
trans_type
;
auto
*
bias_ptr_v2
=
reinterpret_cast
<
const
trans_type
*>
(
B
);
auto
*
data_ptr_v2
=
reinterpret_cast
<
trans_type
*>
(
Y
);
if
(
relu
)
{
bias_relu_v2
<
true
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v2
,
data_ptr_v2
,
N
/
2
);
}
else
{
bias_relu_v2
<
false
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v2
,
data_ptr_v2
,
N
/
2
);
}
}
else
{
}
else
{
const
int
threads
=
256
;
const
int
threads
=
256
;
const
int
blocks
=
M
;
const
int
blocks
=
M
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录