Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
eab44e1f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eab44e1f
编写于
12月 16, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
12月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine (#29622)
上级
d0b789d2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+12
-6
未找到文件。
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
eab44e1f
...
@@ -148,6 +148,8 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
...
@@ -148,6 +148,8 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
size_t
width_stride
=
gridDim
.
x
*
blockDim
.
x
;
size_t
width_stride
=
gridDim
.
x
*
blockDim
.
x
;
size_t
full_width
=
(
width
&
(
~
((
uint64_t
)(
BLOCK_W
-
1
))))
+
size_t
full_width
=
(
width
&
(
~
((
uint64_t
)(
BLOCK_W
-
1
))))
+
((
width
&
(
BLOCK_W
-
1
))
?
BLOCK_W
:
0
);
((
width
&
(
BLOCK_W
-
1
))
?
BLOCK_W
:
0
);
size_t
full_height
=
(
height
&
(
~
((
uint64_t
)(
BLOCK_H
-
1
))))
+
((
height
&
(
BLOCK_H
-
1
))
?
BLOCK_H
:
0
);
#pragma unroll
#pragma unroll
for
(
size_t
w
=
idx
;
w
<
full_width
;
w
+=
width_stride
)
{
for
(
size_t
w
=
idx
;
w
<
full_width
;
w
+=
width_stride
)
{
...
@@ -155,10 +157,10 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
...
@@ -155,10 +157,10 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
__syncthreads
();
__syncthreads
();
size_t
offset
=
w
+
threadIdx
.
y
*
width
;
size_t
offset
=
w
+
threadIdx
.
y
*
width
;
#pragma unroll
#pragma unroll
for
(
size_t
h
=
threadIdx
.
y
;
h
<
height
;
for
(
size_t
h
=
threadIdx
.
y
;
h
<
full_height
;
h
+=
BLOCK_H
)
{
// block-stride loop across matrix height
h
+=
BLOCK_H
)
{
// block-stride loop across matrix height
sdata
[
threadIdx
.
y
][
threadIdx
.
x
]
+=
sdata
[
threadIdx
.
y
][
threadIdx
.
x
]
+=
(
w
<
width
)
?
in
[
offset
]
:
(
static_cast
<
T
>
(
0
));
(
w
<
width
&&
h
<
height
)
?
in
[
offset
]
:
(
static_cast
<
T
>
(
0
));
offset
+=
width
*
BLOCK_H
;
offset
+=
width
*
BLOCK_H
;
}
}
__syncthreads
();
__syncthreads
();
...
@@ -184,21 +186,24 @@ __global__ void FP16MatrixColReduce(
...
@@ -184,21 +186,24 @@ __global__ void FP16MatrixColReduce(
size_t
width_stride
=
gridDim
.
x
*
blockDim
.
x
;
size_t
width_stride
=
gridDim
.
x
*
blockDim
.
x
;
size_t
full_width
=
(
width
&
(
~
((
uint64_t
)(
BLOCK_W
-
1
))))
+
size_t
full_width
=
(
width
&
(
~
((
uint64_t
)(
BLOCK_W
-
1
))))
+
((
width
&
(
BLOCK_W
-
1
))
?
BLOCK_W
:
0
);
((
width
&
(
BLOCK_W
-
1
))
?
BLOCK_W
:
0
);
size_t
full_height
=
(
height
&
(
~
((
uint64_t
)(
BLOCK_H
-
1
))))
+
((
height
&
(
BLOCK_H
-
1
))
?
BLOCK_H
:
0
);
#pragma unroll
#pragma unroll
for
(
size_t
w
=
idx
;
w
<
full_width
;
w
+=
width_stride
)
{
for
(
size_t
w
=
idx
;
w
<
full_width
;
w
+=
width_stride
)
{
for
(
int
r
=
0
;
r
<
repeats
;
r
++
)
{
for
(
int
r
=
0
;
r
<
repeats
;
r
++
)
{
sdata
[
threadIdx
.
y
+
r
*
BLOCK_W
][
threadIdx
.
x
]
=
0
;
sdata
[
threadIdx
.
y
+
r
*
BLOCK_W
][
threadIdx
.
x
]
=
0
;
}
}
__syncthreads
();
__syncthreads
();
#pragma unroll
for
(
int
r
=
0
;
r
<
repeats
;
r
++
)
{
for
(
int
r
=
0
;
r
<
repeats
;
r
++
)
{
size_t
offset
=
w
+
(
r
*
BLOCK_W
+
threadIdx
.
y
)
*
width
;
size_t
offset
=
w
+
(
r
*
BLOCK_W
+
threadIdx
.
y
)
*
width
;
#pragma unroll
#pragma unroll
for
(
size_t
h
=
r
*
BLOCK_H
+
threadIdx
.
y
;
h
<
height
;
for
(
size_t
h
=
threadIdx
.
y
+
r
*
BLOCK_W
;
h
<
full_height
;
h
+=
BLOCK_H
)
{
// block-stride loop across matrix height
h
+=
BLOCK_H
)
{
// block-stride loop across matrix height
sdata
[
r
*
BLOCK_W
+
threadIdx
.
y
][
threadIdx
.
x
]
+=
sdata
[
r
*
BLOCK_W
+
threadIdx
.
y
][
threadIdx
.
x
]
+=
(
w
<
width
)
?
in
[
offset
+
r
*
BLOCK_W
*
width
]
(
w
<
width
&&
h
<
height
)
:
(
static_cast
<
paddle
::
platform
::
float16
>
(
0
));
?
in
[
offset
]
:
(
static_cast
<
paddle
::
platform
::
float16
>
(
0
));
offset
+=
width
*
BLOCK_H
;
offset
+=
width
*
BLOCK_H
;
}
}
}
}
...
@@ -373,6 +378,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
...
@@ -373,6 +378,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
err
=
cub
::
DeviceReduce
::
Sum
(
temp_storage
,
temp_storage_bytes
,
err
=
cub
::
DeviceReduce
::
Sum
(
temp_storage
,
temp_storage_bytes
,
dout_data
,
out_data
,
nums
,
stream
);
dout_data
,
out_data
,
nums
,
stream
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
err
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
err
);
return
;
}
}
constexpr
int
block_x
=
32
;
constexpr
int
block_x
=
32
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录