Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
fcfaa104
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fcfaa104
编写于
7月 22, 2022
作者:
M
ming1753
提交者:
GitHub
7月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
(modified) fc support fp16 (#44540)
上级
3b0aa75e
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
18 addition
and
43 deletion
+18
-43
paddle/phi/kernels/funcs/fc_functor.cu
paddle/phi/kernels/funcs/fc_functor.cu
+18
-43
未找到文件。
paddle/phi/kernels/funcs/fc_functor.cu
浏览文件 @
fcfaa104
...
...
@@ -36,6 +36,24 @@ struct FcTypeTraits<double> {
typedef
double4
Type
;
};
#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
template
<
>
struct
FcTypeTraits
<
float16
>
{
typedef
half2
Type
;
};
#else
struct
float16_4
{
float16
x
,
y
,
z
,
w
;
};
template
<
>
struct
FcTypeTraits
<
float16
>
{
typedef
float16_4
Type
;
};
#endif
template
<
typename
T
,
bool
DoRelu
>
__global__
void
bias_relu_v4
(
const
int
num
,
const
T
*
bias
,
T
*
data
,
int
K
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
...
@@ -109,14 +127,6 @@ void AddReluKernel(
}
#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
template
<
>
struct
FcTypeTraits
<
float16
>
{
typedef
half2
Type
;
};
template
<
bool
DoRelu
>
__global__
void
bias_relu_v2
(
const
int
num
,
const
half2
*
bias
,
...
...
@@ -200,46 +210,11 @@ void AddReluKernel(cudaStream_t stream,
}
#else
struct
float16_4
{
float16
x
,
y
,
z
,
w
;
};
template
<
>
struct
FcTypeTraits
<
float16
>
{
typedef
float16_4
Type
;
};
template
<
bool
DoRelu
>
__global__
void
bias_relu_v4
(
const
int
num
,
const
float16_4
*
bias
,
float16_4
*
data
,
int
K
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
num
)
{
int
bias_idx
=
tid
%
K
;
const
float16_4
bias_ptr
=
bias
[
bias_idx
];
const
float16_4
in_ptr
=
data
[
tid
];
float16_4
packed_val
;
packed_val
.
x
=
in_ptr
.
x
+
bias_ptr
.
x
;
packed_val
.
y
=
in_ptr
.
y
+
bias_ptr
.
y
;
packed_val
.
z
=
in_ptr
.
z
+
bias_ptr
.
z
;
packed_val
.
w
=
in_ptr
.
w
+
bias_ptr
.
w
;
if
(
DoRelu
)
{
packed_val
.
x
=
fmaxf
(
0.
f
,
packed_val
.
x
);
packed_val
.
y
=
fmaxf
(
0.
f
,
packed_val
.
y
);
packed_val
.
z
=
fmaxf
(
0.
f
,
packed_val
.
z
);
packed_val
.
w
=
fmaxf
(
0.
f
,
packed_val
.
w
);
}
data
[
tid
]
=
packed_val
;
}
}
template
<
bool
DoRelu
,
int
BlockDim
>
__global__
void
InplaceAddReluKernel
(
const
int
N
,
const
float16
*
bias
,
float16
*
data
)
{
int
offset
=
blockIdx
.
x
*
N
;
for
(
int
i
=
threadIdx
.
x
;
i
<
N
;
i
+=
BlockDim
)
{
float16
temp
;
temp
=
data
[
offset
+
i
]
+
bias
[
i
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录