Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3e1280ea
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3e1280ea
编写于
7月 21, 2022
作者:
M
ming1753
提交者:
GitHub
7月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fc fp16 (#44505)
* fc support fp16 * add a ‘,’ on paddle_pass_builder.cc * fc support fp16 on non-cuda.
上级
185a900f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
221 addition
and
27 deletion
+221
-27
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+5
-2
paddle/fluid/operators/fc_op.cu.cc
paddle/fluid/operators/fc_op.cu.cc
+1
-0
paddle/phi/kernels/funcs/fc_functor.cu
paddle/phi/kernels/funcs/fc_functor.cu
+215
-25
未找到文件。
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
3e1280ea
...
@@ -136,7 +136,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
...
@@ -136,7 +136,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
const
std
::
vector
<
std
::
string
>
kDlnneSubgraphPasses
({
const
std
::
vector
<
std
::
string
>
kDlnneSubgraphPasses
({
"is_test_pass"
,
//
"is_test_pass"
,
//
"delete_dropout_op_pass"
//
"delete_dropout_op_pass"
,
//
"simplify_with_basic_ops_pass"
,
//
"simplify_with_basic_ops_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"depthwise_conv_bn_fuse_pass"
,
//
"depthwise_conv_bn_fuse_pass"
,
//
...
@@ -158,7 +158,10 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
...
@@ -158,7 +158,10 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"conv_eltwiseadd_bn_fuse_pass"
,
"conv_eltwiseadd_bn_fuse_pass"
,
"conv_elementwise_add_act_fuse_pass"
,
"conv_elementwise_add_act_fuse_pass"
,
"conv_elementwise_add2_act_fuse_pass"
,
"conv_elementwise_add2_act_fuse_pass"
,
"conv_elementwise_add_fuse_pass"
};
"conv_elementwise_add_fuse_pass"
,
"gpu_cpu_map_matmul_v2_to_mul_pass"
,
//
"gpu_cpu_map_matmul_v2_to_matmul_pass"
,
//
"fc_fuse_pass"
};
const
std
::
vector
<
std
::
string
>
kTrtLowerPrecisionPasses
{
const
std
::
vector
<
std
::
string
>
kTrtLowerPrecisionPasses
{
// "conv_bn_fuse_pass",
// "conv_bn_fuse_pass",
...
...
paddle/fluid/operators/fc_op.cu.cc
浏览文件 @
3e1280ea
...
@@ -17,5 +17,6 @@ limitations under the License. */
...
@@ -17,5 +17,6 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
fc
,
fc
,
ops
::
FCOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
phi
::
dtype
::
float16
>
,
ops
::
FCOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
FCOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
FCOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
ops
::
FCOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/phi/kernels/funcs/fc_functor.cu
浏览文件 @
3e1280ea
...
@@ -21,6 +21,8 @@ limitations under the License. */
...
@@ -21,6 +21,8 @@ limitations under the License. */
namespace
phi
{
namespace
phi
{
namespace
funcs
{
namespace
funcs
{
using
float16
=
phi
::
dtype
::
float16
;
template
<
typename
T
>
template
<
typename
T
>
struct
FcTypeTraits
;
struct
FcTypeTraits
;
...
@@ -75,6 +77,216 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
...
@@ -75,6 +77,216 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
}
}
}
}
template
<
typename
T
>
void
AddReluKernel
(
gpuStream_t
stream
,
const
int
M
,
const
int
N
,
T
*
Y
,
const
T
*
B
,
bool
relu
)
{
if
(
N
%
4
==
0
)
{
const
int
threads
=
256
;
const
int
num
=
M
*
N
/
4
;
const
int
blocks
=
(
num
+
threads
-
1
)
/
threads
;
typedef
typename
FcTypeTraits
<
T
>::
Type
trans_type
;
auto
*
bias_ptr_v4
=
reinterpret_cast
<
const
trans_type
*>
(
B
);
auto
*
data_ptr_v4
=
reinterpret_cast
<
trans_type
*>
(
Y
);
if
(
relu
)
{
bias_relu_v4
<
trans_type
,
true
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v4
,
data_ptr_v4
,
N
/
4
);
}
else
{
bias_relu_v4
<
trans_type
,
false
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v4
,
data_ptr_v4
,
N
/
4
);
}
}
else
{
const
int
threads
=
256
;
const
int
blocks
=
M
;
if
(
relu
)
{
InplaceAddReluKernel
<
T
,
true
,
threads
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
N
,
B
,
Y
);
}
else
{
InplaceAddReluKernel
<
T
,
false
,
threads
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
N
,
B
,
Y
);
}
}
}
#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
template
<
>
struct
FcTypeTraits
<
float16
>
{
typedef
half2
Type
;
};
template
<
bool
DoRelu
>
__global__
void
bias_relu_v2
(
const
int
num
,
const
half2
*
bias
,
half2
*
data
,
int
K
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
num
)
{
int
bias_idx
=
tid
%
K
;
const
half2
bias_ptr
=
bias
[
bias_idx
];
const
half2
in_ptr
=
data
[
tid
];
half2
packed_val
=
__hadd2
(
bias_ptr
,
in_ptr
);
if
(
DoRelu
)
{
#if __CUDA_ARCH__ >= 800
packed_val
=
__hmax2
(
__half2
(
0
,
0
),
packed_val
);
#else
packed_val
=
__hmul2
(
__hgt2
(
__half2
(
0
,
0
),
packed_val
),
packed_val
);
#endif
}
data
[
tid
]
=
packed_val
;
}
}
template
<
bool
DoRelu
,
int
BlockDim
>
__global__
void
InplaceAddReluKernel
(
const
int
N
,
const
half
*
bias
,
half
*
data
)
{
int
offset
=
blockIdx
.
x
*
N
;
for
(
int
i
=
threadIdx
.
x
;
i
<
N
;
i
+=
BlockDim
)
{
half
temp
;
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
temp
=
__ldg
(
data
+
offset
+
i
)
+
__ldg
(
bias
+
i
);
#else
temp
=
data
[
offset
+
i
]
+
bias
[
i
];
#endif
if
(
DoRelu
)
{
#if __CUDA_ARCH__ >= 800
data
[
offset
+
i
]
=
__hmax
(
0
,
temp
);
#else
data
[
offset
+
i
]
=
__hmul
(
__hgt
(
temp
,
0
),
temp
);
#endif
}
else
{
data
[
offset
+
i
]
=
temp
;
}
}
}
template
<
>
void
AddReluKernel
(
cudaStream_t
stream
,
const
int
M
,
const
int
N
,
float16
*
Y
,
const
float16
*
B
,
bool
relu
)
{
if
(
N
%
2
==
0
)
{
const
int
threads
=
256
;
const
int
num
=
M
*
N
/
2
;
const
int
blocks
=
(
num
+
threads
-
1
)
/
threads
;
typedef
typename
FcTypeTraits
<
float16
>::
Type
trans_type
;
auto
*
bias_ptr_v2
=
reinterpret_cast
<
const
trans_type
*>
(
B
);
auto
*
data_ptr_v2
=
reinterpret_cast
<
trans_type
*>
(
Y
);
if
(
relu
)
{
bias_relu_v2
<
true
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v2
,
data_ptr_v2
,
N
/
2
);
}
else
{
bias_relu_v2
<
false
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v2
,
data_ptr_v2
,
N
/
2
);
}
}
else
{
const
int
threads
=
256
;
const
int
blocks
=
M
;
auto
*
halfB
=
reinterpret_cast
<
const
half
*>
(
B
);
auto
*
halfY
=
reinterpret_cast
<
half
*>
(
Y
);
if
(
relu
)
{
InplaceAddReluKernel
<
true
,
threads
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
N
,
halfB
,
halfY
);
}
else
{
InplaceAddReluKernel
<
false
,
threads
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
N
,
halfB
,
halfY
);
}
}
}
#else
struct
float16_4
{
float16
x
,
y
,
z
,
w
;
};
template
<
>
struct
FcTypeTraits
<
float16
>
{
typedef
float16_4
Type
;
};
template
<
bool
DoRelu
>
__global__
void
bias_relu_v4
(
const
int
num
,
const
float16_4
*
bias
,
float16_4
*
data
,
int
K
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
num
)
{
int
bias_idx
=
tid
%
K
;
const
float16_4
bias_ptr
=
bias
[
bias_idx
];
const
float16_4
in_ptr
=
data
[
tid
];
float16_4
packed_val
;
packed_val
.
x
=
in_ptr
.
x
+
bias_ptr
.
x
;
packed_val
.
y
=
in_ptr
.
y
+
bias_ptr
.
y
;
packed_val
.
z
=
in_ptr
.
z
+
bias_ptr
.
z
;
packed_val
.
w
=
in_ptr
.
w
+
bias_ptr
.
w
;
if
(
DoRelu
)
{
packed_val
.
x
=
fmaxf
(
0.
f
,
packed_val
.
x
);
packed_val
.
y
=
fmaxf
(
0.
f
,
packed_val
.
y
);
packed_val
.
z
=
fmaxf
(
0.
f
,
packed_val
.
z
);
packed_val
.
w
=
fmaxf
(
0.
f
,
packed_val
.
w
);
}
data
[
tid
]
=
packed_val
;
}
}
template
<
bool
DoRelu
,
int
BlockDim
>
__global__
void
InplaceAddReluKernel
(
const
int
N
,
const
float16
*
bias
,
float16
*
data
)
{
int
offset
=
blockIdx
.
x
*
N
;
for
(
int
i
=
threadIdx
.
x
;
i
<
N
;
i
+=
BlockDim
)
{
float16
temp
;
temp
=
data
[
offset
+
i
]
+
bias
[
i
];
if
(
DoRelu
)
{
data
[
offset
+
i
]
=
fmaxf
(
0.
f
,
temp
);
}
else
{
data
[
offset
+
i
]
=
temp
;
}
}
}
template
<
>
void
AddReluKernel
(
gpuStream_t
stream
,
const
int
M
,
const
int
N
,
float16
*
Y
,
const
float16
*
B
,
bool
relu
)
{
if
(
N
%
4
==
0
)
{
const
int
threads
=
256
;
const
int
num
=
M
*
N
/
4
;
const
int
blocks
=
(
num
+
threads
-
1
)
/
threads
;
typedef
typename
FcTypeTraits
<
float16
>::
Type
trans_type
;
auto
*
bias_ptr_v4
=
reinterpret_cast
<
const
trans_type
*>
(
B
);
auto
*
data_ptr_v4
=
reinterpret_cast
<
trans_type
*>
(
Y
);
if
(
relu
)
{
bias_relu_v4
<
trans_type
,
true
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v4
,
data_ptr_v4
,
N
/
4
);
}
else
{
bias_relu_v4
<
trans_type
,
false
><<<
blocks
,
threads
,
0
,
stream
>>>
(
num
,
bias_ptr_v4
,
data_ptr_v4
,
N
/
4
);
}
}
else
{
const
int
threads
=
256
;
const
int
blocks
=
M
;
if
(
relu
)
{
InplaceAddReluKernel
<
true
,
threads
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
N
,
B
,
Y
);
}
else
{
InplaceAddReluKernel
<
false
,
threads
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
N
,
B
,
Y
);
}
}
}
#endif
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
void
FCFunctor
<
DeviceContext
,
T
>::
operator
()(
const
DeviceContext
&
context
,
void
FCFunctor
<
DeviceContext
,
T
>::
operator
()(
const
DeviceContext
&
context
,
const
int
M
,
const
int
M
,
...
@@ -109,36 +321,14 @@ void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
...
@@ -109,36 +321,14 @@ void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
}
}
// M * N
// M * N
if
(
N
%
4
==
0
)
{
AddReluKernel
(
context
.
stream
(),
M
,
N
,
Y
,
B
,
relu
);
const
int
threads
=
256
;
const
int
num
=
M
*
N
/
4
;
const
int
blocks
=
(
num
+
threads
-
1
)
/
threads
;
typedef
typename
FcTypeTraits
<
T
>::
Type
trans_type
;
auto
*
bias_ptr_v4
=
reinterpret_cast
<
const
trans_type
*>
(
B
);
auto
*
data_ptr_v4
=
reinterpret_cast
<
trans_type
*>
(
Y
);
if
(
relu
)
{
bias_relu_v4
<
trans_type
,
true
><<<
blocks
,
threads
,
0
,
context
.
stream
()
>>>
(
num
,
bias_ptr_v4
,
data_ptr_v4
,
N
/
4
);
}
else
{
bias_relu_v4
<
trans_type
,
false
><<<
blocks
,
threads
,
0
,
context
.
stream
()
>>>
(
num
,
bias_ptr_v4
,
data_ptr_v4
,
N
/
4
);
}
}
else
{
const
int
threads
=
256
;
const
int
blocks
=
M
;
if
(
relu
)
{
InplaceAddReluKernel
<
T
,
true
,
threads
>
<<<
blocks
,
threads
,
0
,
context
.
stream
()
>>>
(
N
,
B
,
Y
);
}
else
{
InplaceAddReluKernel
<
T
,
false
,
threads
>
<<<
blocks
,
threads
,
0
,
context
.
stream
()
>>>
(
N
,
B
,
Y
);
}
}
}
}
template
class
FCFunctor
<
paddle
::
platform
::
CUDADeviceContext
,
float16
>;
template
class
FCFunctor
<
paddle
::
platform
::
CUDADeviceContext
,
float
>;
template
class
FCFunctor
<
paddle
::
platform
::
CUDADeviceContext
,
float
>;
template
class
FCFunctor
<
paddle
::
platform
::
CUDADeviceContext
,
double
>;
template
class
FCFunctor
<
paddle
::
platform
::
CUDADeviceContext
,
double
>;
template
class
FCFunctor
<
GPUContext
,
float16
>;
template
class
FCFunctor
<
GPUContext
,
float
>;
template
class
FCFunctor
<
GPUContext
,
float
>;
template
class
FCFunctor
<
GPUContext
,
double
>;
template
class
FCFunctor
<
GPUContext
,
double
>;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录