Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
b1224da8
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1224da8
编写于
4月 10, 2018
作者:
C
chengduo
提交者:
Yi Wang
4月 09, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move reduceSum to elementwise_op_function.h (#9773)
* add cuda_device_functions.h * move reduceSum to elementwise_op_function.h
上级
b44b6a4f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
75 addition
and
73 deletion
+75
-73
paddle/fluid/operators/elementwise_op_function.h
paddle/fluid/operators/elementwise_op_function.h
+75
-25
paddle/fluid/platform/cuda_helper.h
paddle/fluid/platform/cuda_helper.h
+0
-48
未找到文件。
paddle/fluid/operators/elementwise_op_function.h
浏览文件 @
b1224da8
...
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
...
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/fluid/platform/transform.h"
#ifdef __NVCC__
#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_helper.h"
constexpr
int
ELEMWISE_MAX_BLOCK_DIM
=
1024
;
constexpr
int
ELEMWISE_MAX_BLOCK_DIM
=
1024
;
#endif
#endif
...
@@ -43,35 +44,35 @@ namespace operators {
...
@@ -43,35 +44,35 @@ namespace operators {
*/
*/
inline
void
get_mid_dims
(
const
framework
::
DDim
&
x_dims
,
inline
void
get_mid_dims
(
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
,
int
&
pre
,
int
&
n
,
int
&
post
)
{
int
*
pre
,
int
*
n
,
int
*
post
)
{
pre
=
1
;
*
pre
=
1
;
n
=
1
;
*
n
=
1
;
post
=
1
;
*
post
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
pre
*=
x_dims
[
i
];
(
*
pre
)
*=
x_dims
[
i
];
}
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
i
+
axis
],
y_dims
[
i
],
PADDLE_ENFORCE_EQ
(
x_dims
[
i
+
axis
],
y_dims
[
i
],
"Broadcast dimension mismatch."
);
"Broadcast dimension mismatch."
);
n
*=
y_dims
[
i
];
(
*
n
)
*=
y_dims
[
i
];
}
}
for
(
int
i
=
axis
+
y_dims
.
size
();
i
<
x_dims
.
size
();
++
i
)
{
for
(
int
i
=
axis
+
y_dims
.
size
();
i
<
x_dims
.
size
();
++
i
)
{
post
*=
x_dims
[
i
];
(
*
post
)
*=
x_dims
[
i
];
}
}
}
}
inline
void
trim_trailing_singular_dims
(
framework
::
DDim
&
dims
)
{
inline
void
trim_trailing_singular_dims
(
framework
::
DDim
*
dims
)
{
// Remove trailing dimensions of size 1 for y
// Remove trailing dimensions of size 1 for y
auto
actual_dims_size
=
dims
.
size
();
auto
actual_dims_size
=
dims
->
size
();
for
(;
actual_dims_size
!=
0
;
--
actual_dims_size
)
{
for
(;
actual_dims_size
!=
0
;
--
actual_dims_size
)
{
if
(
dims
[
actual_dims_size
-
1
]
!=
1
)
break
;
if
(
(
*
dims
)
[
actual_dims_size
-
1
]
!=
1
)
break
;
}
}
if
(
actual_dims_size
!=
dims
.
size
())
{
if
(
actual_dims_size
!=
dims
->
size
())
{
auto
actual_dims
=
framework
::
vectorize
(
dims
);
auto
actual_dims
=
framework
::
vectorize
(
*
dims
);
actual_dims
.
resize
(
actual_dims_size
);
actual_dims
.
resize
(
actual_dims_size
);
dims
=
framework
::
make_ddim
(
actual_dims
);
*
dims
=
framework
::
make_ddim
(
actual_dims
);
}
}
}
}
...
@@ -159,7 +160,7 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
...
@@ -159,7 +160,7 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
RowwiseTransformIterator
<
T
,
platform
::
CUDADeviceContext
>
,
const
T
*>
RowwiseTransformIterator
<
T
,
platform
::
CUDADeviceContext
>
,
const
T
*>
super_t
;
super_t
;
HOSTDEVICE
RowwiseTransformIterator
(
const
T
*
x
,
int
n
)
HOSTDEVICE
RowwiseTransformIterator
(
const
T
*
x
,
int
n
)
:
super_t
(
x
),
begin_
(
x
),
n_
(
n
)
{};
:
super_t
(
x
),
begin_
(
x
),
n_
(
n
)
{}
friend
class
thrust
::
iterator_core_access
;
friend
class
thrust
::
iterator_core_access
;
private:
private:
...
@@ -179,7 +180,7 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
...
@@ -179,7 +180,7 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
MidWiseTransformIterator
<
T
,
platform
::
CUDADeviceContext
>
,
const
T
*>
MidWiseTransformIterator
<
T
,
platform
::
CUDADeviceContext
>
,
const
T
*>
super_t
;
super_t
;
HOSTDEVICE
MidWiseTransformIterator
(
const
T
*
x
,
int
n
,
int
post
)
HOSTDEVICE
MidWiseTransformIterator
(
const
T
*
x
,
int
n
,
int
post
)
:
super_t
(
x
),
begin_
(
x
),
n_
(
n
),
post_
(
post
)
{};
:
super_t
(
x
),
begin_
(
x
),
n_
(
n
),
post_
(
post
)
{}
friend
class
thrust
::
iterator_core_access
;
friend
class
thrust
::
iterator_core_access
;
private:
private:
...
@@ -333,6 +334,55 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
...
@@ -333,6 +334,55 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
}
}
}
}
#ifdef __NVCC__
#ifdef __NVCC__
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// TODO(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
__shared__
T
shm
[
32
];
const
int
warpSize
=
32
;
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
if
(
tid
<
warpSize
)
shm
[
tid
]
=
0
;
__syncthreads
();
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
}
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
if
(
tid
<
warpSize
)
{
val
=
shm
[
tid
];
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
...
@@ -355,7 +405,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
...
@@ -355,7 +405,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
if
(
dy
)
{
if
(
dy
)
{
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
platform
::
reduceSum
(
val
,
tid
,
h
);
val
=
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
dy
[
j
]
=
val
;
}
}
...
@@ -432,7 +482,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
...
@@ -432,7 +482,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
if
(
dy
)
{
if
(
dy
)
{
int
h
=
pre
*
post
;
int
h
=
pre
*
post
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
platform
::
reduceSum
(
val
,
tid
,
h
);
val
=
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
dy
[
j
]
=
val
;
}
}
...
@@ -472,11 +522,11 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
...
@@ -472,11 +522,11 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
auto
y_dim
=
y
.
dims
();
auto
y_dim
=
y
.
dims
();
axis
=
(
axis
==
-
1
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
);
axis
=
(
axis
==
-
1
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
);
trim_trailing_singular_dims
(
y_dim
);
trim_trailing_singular_dims
(
&
y_dim
);
axis
=
(
y_dim
.
size
()
==
0
)
?
x_dim
.
size
()
:
axis
;
axis
=
(
y_dim
.
size
()
==
0
)
?
x_dim
.
size
()
:
axis
;
int
pre
,
n
,
post
;
int
pre
,
n
,
post
;
get_mid_dims
(
x_dim
,
y_dim
,
axis
,
pre
,
n
,
post
);
get_mid_dims
(
x_dim
,
y_dim
,
axis
,
&
pre
,
&
n
,
&
post
);
if
(
post
==
1
)
{
if
(
post
==
1
)
{
int
h
=
pre
;
int
h
=
pre
;
int
w
=
n
;
int
w
=
n
;
...
@@ -514,7 +564,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
...
@@ -514,7 +564,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
}
}
}
}
}
}
}
;
}
template
<
typename
DeviceContext
,
typename
T
,
typename
functor
,
template
<
typename
DeviceContext
,
typename
T
,
typename
functor
,
typename
broadcastfunctor
,
typename
broadcast2functor
>
typename
broadcastfunctor
,
typename
broadcast2functor
>
...
@@ -543,11 +593,11 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
...
@@ -543,11 +593,11 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
}
}
axis
=
(
axis
==
-
1
?
x_dims
.
size
()
-
y_dims
.
size
()
:
axis
);
axis
=
(
axis
==
-
1
?
x_dims
.
size
()
-
y_dims
.
size
()
:
axis
);
trim_trailing_singular_dims
(
y_dims
);
trim_trailing_singular_dims
(
&
y_dims
);
axis
=
(
y_dims
.
size
()
==
0
)
?
x_dims
.
size
()
:
axis
;
axis
=
(
y_dims
.
size
()
==
0
)
?
x_dims
.
size
()
:
axis
;
int
pre
,
n
,
post
;
int
pre
,
n
,
post
;
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
pre
,
n
,
post
);
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
if
(
post
==
1
)
{
if
(
post
==
1
)
{
broadcastfunctor
f
;
broadcastfunctor
f
;
...
@@ -582,11 +632,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
...
@@ -582,11 +632,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
axis
=
(
axis
==
-
1
?
x_dims
.
size
()
-
y_dims
.
size
()
:
axis
);
axis
=
(
axis
==
-
1
?
x_dims
.
size
()
-
y_dims
.
size
()
:
axis
);
PADDLE_ENFORCE
(
axis
>=
0
&&
axis
<
x_dims
.
size
(),
PADDLE_ENFORCE
(
axis
>=
0
&&
axis
<
x_dims
.
size
(),
"Axis should be in range [0, x_dims)"
);
"Axis should be in range [0, x_dims)"
);
trim_trailing_singular_dims
(
y_dims
);
trim_trailing_singular_dims
(
&
y_dims
);
axis
=
(
y_dims
.
size
()
==
0
)
?
x_dims
.
size
()
:
axis
;
axis
=
(
y_dims
.
size
()
==
0
)
?
x_dims
.
size
()
:
axis
;
int
pre
,
n
,
post
;
int
pre
,
n
,
post
;
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
pre
,
n
,
post
);
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
if
(
post
==
1
)
{
if
(
post
==
1
)
{
functor
.
RunRowWise
(
n
,
pre
);
functor
.
RunRowWise
(
n
,
pre
);
return
;
return
;
...
...
paddle/fluid/platform/cuda_helper.h
浏览文件 @
b1224da8
...
@@ -62,53 +62,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
...
@@ -62,53 +62,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
}
}
#endif
#endif
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// TODO(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
__shared__
T
shm
[
32
];
const
int
warpSize
=
32
;
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
if
(
tid
<
warpSize
)
shm
[
tid
]
=
0
;
__syncthreads
();
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
}
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
if
(
tid
<
warpSize
)
{
val
=
shm
[
tid
];
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
}
return
val
;
}
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录