Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7b2dc4e6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7b2dc4e6
编写于
12月 21, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
12月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimization for fp16 elementwise add (#29744)
上级
27bdbec7
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
38 addition
and
1 deletion
+38
-1
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+38
-1
未找到文件。
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
7b2dc4e6
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__
#include "cub/cub.cuh"
...
...
@@ -176,6 +177,25 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
}
}
template
<
int
SIZE
>
__global__
void
VecFP16MatrixColReduce
(
const
__half2
*
__restrict__
in
,
__half2
*
__restrict__
out
,
size_t
width
,
size_t
height
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
by
=
blockIdx
.
y
;
__half2
zero
=
__half2half2
(
static_cast
<
__half
>
(
0
));
const
int
cols
=
width
/
2
;
for
(;
idx
<
cols
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
__half2
sum
=
zero
;
for
(
int
row
=
0
;
row
<
SIZE
;
row
++
)
{
int
index
=
idx
+
(
row
+
by
*
SIZE
)
*
cols
;
sum
=
__hadd2
(
sum
,
in
[
index
]);
}
atomicAdd
(
&
(
out
[
idx
]),
sum
);
}
}
template
<
typename
T
>
__global__
void
MatrixReduceLongWidth
(
const
T
*
__restrict__
in
,
T
*
out
,
size_t
width
,
size_t
height
)
{
...
...
@@ -198,7 +218,7 @@ __global__ void VecMatrixReduceLongWidth(const T *__restrict__ in, T *out,
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
w
=
idx
*
VEC_SIZE
;
int
width_stride
=
blockDim
.
x
*
gridDim
.
x
*
VEC_SIZE
;
for
(;
w
<
width
;
w
+=
width
)
{
for
(;
w
<
width
;
w
+=
width
_stride
)
{
T
zero
=
static_cast
<
T
>
(
0
);
T
sum
[
VEC_SIZE
]
=
{
zero
};
T
tmp_vec
[
VEC_SIZE
]
=
{
zero
};
...
...
@@ -341,6 +361,23 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
int
max_blocks
=
std
::
max
(
max_physical_threads
/
(
block_x
*
block_y
),
1
);
int
theory_block
=
(
width
+
blocks
.
x
-
1
)
/
blocks
.
x
;
dim3
grids
(
std
::
min
(
theory_block
,
max_blocks
));
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
&&
width
<
2048
&&
width
%
2
==
0
&&
height
%
64
==
0
)
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
functor
;
if
(
dout
->
dims
()
==
dx
->
dims
())
functor
(
dev_ctx
,
dy
,
static_cast
<
T
>
(
0
));
else
functor
(
dev_ctx
,
dx
,
static_cast
<
T
>
(
0
));
const
__half2
*
ptr1
=
reinterpret_cast
<
const
__half2
*>
(
dout_data
);
__half2
*
ptr2
=
reinterpret_cast
<
__half2
*>
(
out_data
);
const
int
threads
=
128
;
dim3
grid
(
1
,
(
height
+
64
-
1
)
/
64
);
VecFP16MatrixColReduce
<
64
><<<
grid
,
threads
,
0
,
stream
>>>
(
ptr1
,
ptr2
,
width
,
height
);
return
;
}
if
(
width
/
height
<
32
)
{
MatrixColReduce
<
T
,
block_x
,
block_y
><<<
grids
,
blocks
,
0
,
stream
>>>
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录