Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4fbde42c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4fbde42c
编写于
5月 03, 2018
作者:
C
chengduo
提交者:
dzhwinter
5月 03, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix __shfl_down_sync_ of cross_entropy (#10345)
* fix __shfl_down_sync_ of cross_entropy * use reduceSum * "fix ci"
上级
6d5e582d
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
88 addition
and
108 deletion
+88
-108
paddle/fluid/operators/elementwise_op_function.h
paddle/fluid/operators/elementwise_op_function.h
+3
-39
paddle/fluid/operators/math/cross_entropy.cu
paddle/fluid/operators/math/cross_entropy.cu
+10
-55
paddle/fluid/operators/row_conv_op.cu
paddle/fluid/operators/row_conv_op.cu
+1
-1
paddle/fluid/platform/cuda_device_function.h
paddle/fluid/platform/cuda_device_function.h
+74
-0
paddle/fluid/platform/cuda_primitives.h
paddle/fluid/platform/cuda_primitives.h
+0
-13
未找到文件。
paddle/fluid/operators/elementwise_op_function.h
浏览文件 @
4fbde42c
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#ifdef __NVCC__
#ifdef __NVCC__
#include <cuda.h>
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr
int
ELEMWISE_MAX_BLOCK_DIM
=
1024
;
constexpr
int
ELEMWISE_MAX_BLOCK_DIM
=
1024
;
#endif
#endif
...
@@ -336,43 +337,6 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
...
@@ -336,43 +337,6 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
}
}
#ifdef __NVCC__
#ifdef __NVCC__
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const
int
warpSize
=
32
;
__shared__
T
shm
[
warpSize
];
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
if
(
tid
<
warpSize
)
shm
[
tid
]
=
0
;
__syncthreads
();
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
}
__syncthreads
();
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
if
(
tid
<
warpSize
)
{
val
=
shm
[
tid
];
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
...
@@ -395,7 +359,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
...
@@ -395,7 +359,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
if
(
dy
)
{
if
(
dy
)
{
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
reduceSum
(
val
,
tid
,
h
);
val
=
paddle
::
platform
::
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
dy
[
j
]
=
val
;
}
}
...
@@ -472,7 +436,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
...
@@ -472,7 +436,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
if
(
dy
)
{
if
(
dy
)
{
int
h
=
pre
*
post
;
int
h
=
pre
*
post
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
reduceSum
(
val
,
tid
,
h
);
val
=
paddle
::
platform
::
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
dy
[
j
]
=
val
;
}
}
...
...
paddle/fluid/operators/math/cross_entropy.cu
浏览文件 @
4fbde42c
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -30,66 +31,22 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
...
@@ -30,66 +31,22 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
}
}
}
}
template
<
typename
T
>
__device__
__forceinline__
T
sum_single_warp
(
T
val
)
{
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
16
);
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
8
);
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
4
);
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
2
);
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
1
);
return
val
;
}
// CUDA do not support dynamic arrary in template
// https://stackoverflow.com/questions/20497209
template
<
typename
T
>
struct
SharedMemory
{
// Ensure that we won't compile any un-specialized types
__device__
T
*
GetPointer
()
{
return
NULL
;
}
};
template
<
>
struct
SharedMemory
<
float
>
{
__device__
float
*
GetPointer
()
{
extern
__shared__
float
s_float
[];
return
s_float
;
}
};
template
<
>
struct
SharedMemory
<
double
>
{
__device__
double
*
GetPointer
()
{
extern
__shared__
double
s_double
[];
return
s_double
;
}
};
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
class_num
)
{
const
int
class_num
)
{
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
SharedMemory
<
T
>
d_sum_shared
;
T
val
=
0
;
T
*
d_sum
=
d_sum_shared
.
GetPointer
();
d_sum
[
tid
]
=
0
;
int
cur_idx
=
tid
;
int
idx
=
blockIdx
.
x
*
class_num
+
tid
;
int
next_idx
=
blockIdx
.
x
*
class_num
+
tid
;
int
end
=
blockIdx
.
x
*
class_num
+
class_num
;
while
(
cur_idx
<
class_num
)
{
for
(;
idx
<
end
;
idx
+=
blockDim
.
x
)
{
d_sum
[
tid
]
+=
val
+=
math
::
TolerableValue
<
T
>
()(
std
::
log
(
X
[
idx
]))
*
label
[
idx
];
math
::
TolerableValue
<
T
>
()(
std
::
log
(
X
[
next_idx
]))
*
label
[
next_idx
];
next_idx
+=
blockDim
.
x
;
cur_idx
+=
blockDim
.
x
;
}
}
__syncthreads
();
for
(
unsigned
int
stride
=
blockDim
.
x
>>
1
;
stride
>=
32
;
stride
>>=
1
)
{
val
=
paddle
::
platform
::
reduceSum
(
val
,
tid
,
blockDim
.
x
);
if
(
tid
<
stride
)
d_sum
[
tid
]
+=
d_sum
[
tid
+
stride
];
if
(
threadIdx
.
x
==
0
)
{
__syncthreads
()
;
Y
[
blockIdx
.
x
]
=
-
val
;
}
}
T
val
=
d_sum
[
tid
];
val
=
sum_single_warp
<
T
>
(
val
);
if
(
tid
==
0
)
Y
[
blockIdx
.
x
]
=
-
val
;
}
}
}
// namespace
}
// namespace
...
@@ -113,9 +70,7 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
...
@@ -113,9 +70,7 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
?
512
?
512
:
pow
(
2
,
static_cast
<
int
>
(
std
::
log2
(
class_num
)));
:
pow
(
2
,
static_cast
<
int
>
(
std
::
log2
(
class_num
)));
SoftCrossEntropyKernel
<
T
><<<
SoftCrossEntropyKernel
<
T
><<<
batch_size
,
block
,
0
,
ctx
.
stream
()
>>>
(
batch_size
,
block
,
block
*
sizeof
(
T
),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
loss_data
,
prob_data
,
label_data
,
class_num
);
loss_data
,
prob_data
,
label_data
,
class_num
);
}
else
{
}
else
{
const
int64_t
*
label_data
=
labels
->
data
<
int64_t
>
();
const
int64_t
*
label_data
=
labels
->
data
<
int64_t
>
();
...
...
paddle/fluid/operators/row_conv_op.cu
浏览文件 @
4fbde42c
...
@@ -14,7 +14,7 @@ limitations under the License. */
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/platform/cuda_
primitives
.h"
#include "paddle/fluid/platform/cuda_
device_function
.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/fluid/platform/cuda_device_function.h
0 → 100644
浏览文件 @
4fbde42c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
namespace
paddle
{
namespace
platform
{
// __shfl_down and __shfl have been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_sync
(
unsigned
,
T
val
,
int
src_line
,
int
width
)
{
return
__shfl
(
val
,
src_line
,
width
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const
int
warpSize
=
32
;
__shared__
T
shm
[
warpSize
];
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
if
(
tid
<
warpSize
)
shm
[
tid
]
=
0
;
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
}
__syncthreads
();
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
if
(
tid
<
warpSize
)
{
val
=
shm
[
tid
];
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
}
return
val
;
}
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/cuda_primitives.h
浏览文件 @
4fbde42c
...
@@ -66,18 +66,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
...
@@ -66,18 +66,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
}
}
#endif
#endif
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录