PaddlePaddle / Paddle

Commit e9f20331 (unverified)
Authored Feb 28, 2018 by chengduo; committed via GitHub on Feb 28, 2018

Merge pull request #8539 from chengduoZH/feature/refine_elementwise_op_function.h

Refine Sum in elementwise_op_function

Parent commits: fee90b50, 90dc33b5
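In short, the patch changes how the broadcast gradient kernels sum the per-thread partials for dy: instead of staging every thread's partial in a dynamically sized shared-memory buffer and having thread 0 add the entries serially, each thread now accumulates into a register and the block-wide total is produced by the new warp-shuffle helper `platform::reduceSum`. The snippet below is a minimal, self-contained sketch of that accumulate-then-reduce pattern; the file, kernel, and helper names (`WarpSum`, `BlockSum`) are invented for illustration and are not part of the patch, and it assumes CUDA 9 or newer (the patch itself adds a fallback for older toolkits).

```cpp
// warp_block_sum.cu -- illustrative sketch only; WarpSum/BlockSum are invented
// names, not from the patch. Build with: nvcc warp_block_sum.cu
#include <cstdio>
#include <cuda_runtime.h>

// Sum across one warp: each shuffle step folds the upper half of the active
// lanes onto the lower half, so after log2(32) = 5 steps lane 0 holds the total.
__device__ float WarpSum(float val) {
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xFFFFFFFFu, val, offset);
  return val;
}

// One block sums n elements: every thread first accumulates a strided slice
// into a register (as the patched kernels now do), then the per-warp totals
// are combined through a small fixed-size shared array.
__global__ void BlockSum(const float* x, int n, float* out) {
  float val = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) val += x[i];

  __shared__ float warp_partials[32];  // one slot per warp, sized statically
  int lane = threadIdx.x % 32;
  int warp = threadIdx.x / 32;

  val = WarpSum(val);                        // stage 1: reduce inside each warp
  if (lane == 0) warp_partials[warp] = val;  // warp leaders publish their sums
  __syncthreads();

  if (warp == 0) {                           // stage 2: first warp folds the partials
    int num_warps = (blockDim.x + 31) / 32;
    val = (lane < num_warps) ? warp_partials[lane] : 0.f;
    val = WarpSum(val);
    if (lane == 0) *out = val;
  }
}

int main() {
  const int n = 1 << 20;
  float *x, *out;
  cudaMallocManaged(&x, n * sizeof(float));  // unified memory keeps the demo short
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = 1.f;
  BlockSum<<<1, 1024>>>(x, n, out);
  cudaDeviceSynchronize();
  printf("sum = %.0f (expected %d)\n", *out, n);
  cudaFree(x);
  cudaFree(out);
  return 0;
}
```

Paddle's actual helper, `reduceSum` in cuda_helper.h below, follows the same two-stage shape, but it takes the number of valid lanes as a parameter and builds its shuffle masks with `CREATE_SHFL_MASK` so it also compiles against pre-9.0 CUDA.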
Showing 2 changed files with 62 additions and 34 deletions (+62 −34).
paddle/fluid/operators/elementwise_op_function.h  +14 −34
paddle/fluid/platform/cuda_helper.h               +48 −0
paddle/fluid/operators/elementwise_op_function.h
@@ -20,6 +20,7 @@ limitations under the License. */
 #ifdef __NVCC__
 #include <thrust/iterator/iterator_adaptor.h>
+#include "paddle/fluid/platform/cuda_helper.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 #endif
@@ -361,13 +362,10 @@ template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast1CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int h, int w,
     DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
-  extern __shared__ char shm_buffer[];
-  T* shm = reinterpret_cast<T*>(shm_buffer);
-
   int j = blockIdx.x;
   int i = threadIdx.x;
   int tid = threadIdx.x;
-  shm[tid] = 0;
+  T val = 0;

   do {
     int x_offset = i * w + j;
@@ -375,22 +373,16 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
       dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
     }
     if (dy) {
-      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
     }
     i += ELEMWISE_MAX_BLOCK_DIM;
   } while (i < h);

   if (dy) {
     __syncthreads();
-
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-
-    // Sum, could be optimized
+    val = platform::reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
-      for (int k = 1; k < h; ++k) {
-        shm[0] += shm[k];
-      }
-      dy[j] = shm[0];
+      dy[j] = val;
     }
   }
 }
@@ -402,10 +394,8 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                        T* dx, T* dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
   int gird_size = w;
-  int shared_mem_size = block_size * sizeof(T);
-  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, shared_mem_size,
-                                     stream>>>(x, y, out, dout, h, w, dx_op,
-                                               dy_op, dx, dy);
+  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
 }
 #endif
@@ -436,7 +426,6 @@ static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
 }

 #ifdef __NVCC__
-
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast2CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int pre, int n,
@@ -444,9 +433,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
   int tid = threadIdx.x;
   int j = blockIdx.x;

-  extern __shared__ char shm_buffer[];
-  T* shm = reinterpret_cast<T*>(shm_buffer);
-  shm[tid] = 0;
+  T val = 0;
   int ttid = tid;

   while (true) {
@@ -461,23 +448,18 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
     }

     if (dy != nullptr) {
-      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
+      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
     }

     ttid += ELEMWISE_MAX_BLOCK_DIM;
   }

   if (dy) {
     __syncthreads();
     int h = pre * post;
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-
-    // Sum, could be optimized
-    if (tid == 0) {
-      for (int i = 1; i < h; ++i) {
-        shm[0] += shm[i];
-      }
-      dy[j] = shm[0];
+    val = platform::reduceSum(val, tid, h);
+    if (threadIdx.x == 0) {
+      dy[j] = val;
     }
   }
 }
@@ -489,10 +471,8 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                        DY_OP dy_op, T* dx, T* dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
   int gird_size = n;
-  int shared_mem_size = block_size * sizeof(T);
-  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, shared_mem_size,
-                                     stream>>>(x, y, out, dout, pre, n, post,
-                                               dx_op, dy_op, dx, dy);
+  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
 }
 #endif
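One knock-on effect is visible in both launcher hunks: because the partial sums now live in registers and are combined inside `platform::reduceSum` (which carries its own fixed 32-entry shared array), the kernels no longer declare `extern __shared__` storage, so the dynamic-shared-memory size in the launch configuration drops from `block_size * sizeof(T)` to 0. The sketch below is only a generic illustration of what that third `<<< >>>` argument does; both kernels in it are hypothetical and not from this commit.

```cpp
// launch_config_sketch.cu -- hypothetical illustration of the third launch
// argument (dynamic shared memory bytes); these kernels are not from the patch.
#include <cstdio>
#include <cuda_runtime.h>

// Needs a buffer whose size is only known at launch time, so the byte count
// must be supplied in the <<<grid, block, shared_bytes, stream>>> configuration.
__global__ void WithDynamicShm(float* out) {
  extern __shared__ float buf[];  // sized by the launch configuration
  buf[threadIdx.x] = 1.f;
  __syncthreads();
  if (threadIdx.x == 0) out[0] = buf[blockDim.x - 1];
}

// Keeps its partial result in a register, in the style of the patched kernels,
// so the dynamic-shared-memory argument can simply be 0.
__global__ void WithoutDynamicShm(float* out) {
  float val = 1.f;
  if (threadIdx.x == 0) out[0] = val;
}

int main() {
  const int block_size = 256;
  float* out = nullptr;
  cudaMalloc(&out, sizeof(float));
  // Old style in this file: reserve block_size * sizeof(T) bytes per block.
  WithDynamicShm<<<1, block_size, block_size * sizeof(float)>>>(out);
  // New style after the patch: no dynamic shared memory is requested.
  WithoutDynamicShm<<<1, block_size, 0>>>(out);
  cudaDeviceSynchronize();
  printf("launches finished: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(out);
  return 0;
}
```

Beyond simplifying the launch, not reserving up to `1024 * sizeof(T)` bytes of dynamic shared memory per block can also leave more shared memory available to other resident blocks.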
paddle/fluid/platform/cuda_helper.h
@@ -62,5 +62,53 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
 }
 #endif

+// __shfl_down has been deprecated as of CUDA 9.0.
+#if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
+  return __shfl_down(val, delta);
+}
+#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
+#else
+#define FULL_WARP_MASK 0xFFFFFFFF
+#define CREATE_SHFL_MASK(mask, predicate) \
+  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
+
+template <typename T>
+__device__ T reduceSum(T val, int tid, int len) {
+  // TODO(zcd): The warp size should be taken from the
+  // parameters of the GPU but not specified as 32 simply.
+  // To make the reduceSum more efficiently,
+  // I use Warp-Level Parallelism and assume the Warp size
+  // is 32 which may be different for different GPU,
+  // but most card's warp size is 32.
+  __shared__ T shm[32];
+  const int warpSize = 32;
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, tid < len);
+
+  for (int offset = warpSize / 2; offset > 0; offset /= 2)
+    val += __shfl_down_sync(mask, val, offset);
+
+  if (tid < warpSize) shm[tid] = 0;
+
+  __syncthreads();
+
+  if (tid % warpSize == 0) {
+    shm[tid / warpSize] = val;
+  }
+
+  CREATE_SHFL_MASK(mask, tid < warpSize);
+
+  if (tid < warpSize) {
+    val = shm[tid];
+    for (int offset = warpSize / 2; offset > 0; offset /= 2)
+      val += __shfl_down_sync(mask, val, offset);
+  }
+
+  return val;
+}
+
 }  // namespace platform
 }  // namespace paddle
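To see the new helper in action outside the elementwise kernels, a small check like the one below could be built against the Paddle tree. Everything about the harness (file name, kernel name, the single warp-sized launch) is an assumption made for the example; only `paddle::platform::reduceSum` and the header path come from the patch.

```cpp
// reduce_sum_check.cu -- hypothetical harness, not part of the patch. It assumes
// it is compiled from inside the Paddle source tree so the include resolves,
// e.g.:  nvcc -I/path/to/Paddle reduce_sum_check.cu
#include <cstdio>
#include <cuda_runtime.h>
#include "paddle/fluid/platform/cuda_helper.h"

// One warp-sized block; thread tid contributes the value tid, so the expected
// total over len threads is len * (len - 1) / 2.
__global__ void ReduceSumCheck(int len, float* out) {
  int tid = threadIdx.x;
  float val = static_cast<float>(tid);
  val = paddle::platform::reduceSum(val, tid, len);
  if (tid == 0) out[0] = val;  // after the reduction, thread 0 holds the sum
}

int main() {
  const int len = 32;  // a single full warp keeps the check simple
  float* out = nullptr;
  cudaMalloc(&out, sizeof(float));
  ReduceSumCheck<<<1, len>>>(len, out);
  float host_out = 0.f;
  cudaMemcpy(&host_out, out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("device sum = %.0f, expected = %d\n", host_out, len * (len - 1) / 2);
  cudaFree(out);
  return 0;
}
```

With one full warp the first shuffle pass already produces the block total, so the example stays clear of the multi-warp staging through the 32-entry shared array.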