Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2673798d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2673798d
编写于
8月 17, 2018
作者:
D
dzhwinter
提交者:
GitHub
8月 17, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"fix float16 ShuffleDownSync Bug" (#12756)
* "fix bug" * "add test case"
上级
6fe5547d
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
93 addition
and
5 deletion
+93
-5
paddle/fluid/platform/cuda_device_function.h
paddle/fluid/platform/cuda_device_function.h
+11
-4
paddle/fluid/platform/cuda_helper_test.cu
paddle/fluid/platform/cuda_helper_test.cu
+82
-1
未找到文件。
paddle/fluid/platform/cuda_device_function.h
浏览文件 @
2673798d
...
...
@@ -36,7 +36,7 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
#if CUDA_VERSION < 9000
return
__shfl_down
(
val
,
delta
,
width
);
#else
return
__shfl_down_sync
(
mask
,
val
,
delta
,
width
);
return
__shfl_down_sync
(
mask
,
val
,
static_cast
<
unsigned
>
(
delta
)
,
width
);
#endif
}
...
...
@@ -46,9 +46,16 @@ template <>
__forceinline__
__device__
float16
CudaShuffleDownSync
(
unsigned
mask
,
float16
val
,
int
delta
,
int
width
)
{
half
tmp
=
static_cast
<
half
>
(
val
);
__shfl_down
(
tmp
,
static_cast
<
unsigned
>
(
delta
),
width
);
return
float16
(
tmp
);
return
float16
(
__shfl_down
(
static_cast
<
half
>
(
val
),
static_cast
<
unsigned
>
(
delta
),
width
));
}
#else
// float16 specialization of CudaShuffleDownSync for the `_sync` shuffle API.
// The value is converted to the native `half` type, passed through
// __shfl_down_sync with the caller's mask, delta and width, and the result is
// wrapped back into a float16.
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                       float16 val, int delta,
                                                       int width) {
  const half native_val = static_cast<half>(val);
  const unsigned lane_delta = static_cast<unsigned>(delta);
  const half shuffled = __shfl_down_sync(mask, native_val, lane_delta, width);
  return float16(shuffled);
}
#endif
...
...
paddle/fluid/platform/cuda_helper_test.cu
浏览文件 @
2673798d
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <random>
...
...
@@ -123,7 +124,7 @@ void TestUnalign(size_t num, const int shift_bit) {
cudaMemcpy
(
out
,
d_in2
,
array_size
,
cudaMemcpyDeviceToHost
);
cudaDeviceSynchronize
();
for
(
size_t
i
=
0
;
i
<
num
/
2
;
++
i
)
{
// NOTE(dzhwinter): the float16 add has small
underflow/overflow
// NOTE(dzhwinter): the float16 add has small
truncate error.
// so we use EXPECT_NEAR to check the result.
EXPECT_NEAR
(
static_cast
<
float
>
(
out
[
i
]),
static_cast
<
float
>
(
AddFunctor
<
float16
>
()(
r_in1
[
i
],
r_in2
[
i
])),
...
...
@@ -151,3 +152,83 @@ TEST(CudaAtomic, float16Unalign) {
TestUnalign
(
static_cast
<
size_t
>
(
1024
),
/*shift_bit*/
3
);
TestUnalign
(
static_cast
<
size_t
>
(
1024
*
1024
),
/*shift_bit*/
3
);
}
// https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
// Adds up `val` across the calling warp with a halving sequence of
// shuffle-down exchanges (offsets warpSize/2, warpSize/4, ..., 1).
// The participation mask is produced by CREATE_SHFL_MASK.
// Per the NVIDIA article referenced for this code, lane 0 is the lane that
// ends up with the complete warp sum; the value returned on other lanes
// should be treated as a partial result.
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val) {
  unsigned lane_mask = 0u;
  CREATE_SHFL_MASK(lane_mask, true);

  int step = warpSize / 2;
  while (step > 0) {
    val += paddle::platform::CudaShuffleDownSync(lane_mask, val, step);
    step /= 2;
  }
  return val;
}
// Reduces `val` over the threads of a block in two stages:
//   1. every warp reduces its own values with WarpReduceSum and its lane 0
//      stores the partial sum into shared memory (one slot per warp);
//   2. after a block-wide barrier, the first warp loads those partial sums
//      and reduces them again.
// The shared buffer has 32 slots, so at most 32 warps per block are supported.
// Only threads with index below blockDim.x / warpSize load a partial sum; all
// others contribute zero to the second stage.
template <typename T>
__forceinline__ __device__ T BlockReduce(T val) {
  // Shared mem for 32 partial sums.
  static __shared__ T warp_sums[32];

  const int lane_id = threadIdx.x % warpSize;
  const int warp_id = threadIdx.x / warpSize;

  // Stage 1: partial reduction inside each warp.
  val = WarpReduceSum(val);
  if (lane_id == 0) {
    warp_sums[warp_id] = val;
  }

  // Every partial sum must be visible before any thread reads it.
  __syncthreads();

  // Read from shared memory only if that warp existed.
  const bool has_partial = threadIdx.x < blockDim.x / warpSize;
  if (has_partial) {
    val = warp_sums[lane_id];
  } else {
    val = static_cast<T>(0);
  }

  // Stage 2: final reduce within the first warp.
  if (warp_id == 0) {
    val = WarpReduceSum(val);
  }
  return val;
}
// Kernel: sums the N elements of `in` and writes one partial result per block.
//
// Each thread accumulates the elements at index
//   blockIdx.x * blockDim.x + threadIdx.x, advancing by blockDim.x * gridDim.x
// until the index reaches N, then the per-thread sums are combined with
// BlockReduce. Thread 0 of every block stores the block's result in
// out[blockIdx.x], so `out` must have room for gridDim.x elements.
//
// The index and stride are computed in size_t so that they match the type of
// N; the previous `int` index was compared against an unsigned bound and the
// products were evaluated in 32-bit arithmetic.
template <typename T>
__global__ void DeviceReduceSum(T* in, T* out, size_t N) {
  T sum(0);

  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (size_t i = first; i < N; i += stride) {
    sum += in[i];
  }

  sum = BlockReduce<T>(sum);
  __syncthreads();
  if (threadIdx.x == 0) {
    out[blockIdx.x] = sum;
  }
}
// Checks DeviceReduceSum<T> against a host-side std::accumulate reference.
//
// Fills `num` host values with uniform random numbers in [0, 1), copies them
// to the device, launches DeviceReduceSum with a single block of
// PADDLE_CUDA_NUM_THREADS threads, copies the single result back and compares
// it with the host sum using EXPECT_NEAR.
//
// num  - number of elements to reduce.
// atol - absolute tolerance passed to EXPECT_NEAR.
//
// Every CUDA runtime call is checked so that an allocation, copy or launch
// failure is reported as such instead of showing up as a confusing numeric
// mismatch. EXPECT_* is used for CUDA calls so the cleanup at the end always
// runs.
template <typename T>
void TestReduce(size_t num, float atol = 0.01f) {
  T* in1 = nullptr;
  T* d_in1 = nullptr;
  T* d_in2 = nullptr;
  const size_t size = sizeof(T) * num;

  // The host buffer is also reused to receive the single reduced value, so it
  // must hold at least one element even when num == 0.
  const size_t host_bytes = std::max(size, sizeof(T));
  in1 = reinterpret_cast<T*>(malloc(host_bytes));
  ASSERT_TRUE(in1 != nullptr);

  EXPECT_EQ(cudaSuccess, cudaMalloc(reinterpret_cast<void**>(&d_in1), size));
  EXPECT_EQ(cudaSuccess,
            cudaMalloc(reinterpret_cast<void**>(&d_in2), sizeof(T)));

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num; ++i) {
    in1[i] = static_cast<T>(dist(engine));
  }
  // Host reference, accumulated in the same type T as the device result.
  auto out = std::accumulate(in1, in1 + num, static_cast<T>(0));

  EXPECT_EQ(cudaSuccess,
            cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice));
  EXPECT_EQ(cudaSuccess, cudaDeviceSynchronize());

  DeviceReduceSum<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);
  // A kernel launch returns no status; launch errors are read back here.
  EXPECT_EQ(cudaSuccess, cudaGetLastError());

  EXPECT_EQ(cudaSuccess,
            cudaMemcpy(in1, d_in2, sizeof(T), cudaMemcpyDeviceToHost));
  EXPECT_EQ(cudaSuccess, cudaDeviceSynchronize());

  // NOTE(dzhwinter): the float16 add has small truncate error,
  // so we use EXPECT_NEAR to check the result.
  EXPECT_NEAR(static_cast<float>(in1[0]), static_cast<float>(out), atol);

  free(in1);
  EXPECT_EQ(cudaSuccess, cudaFree(d_in1));
  EXPECT_EQ(cudaSuccess, cudaFree(d_in2));
}
// Runs the shuffle-based reduction test for float and float16 inputs.
TEST(CudaShuffleSync, float16) {
  const size_t tiny_count = 10;
  const size_t medium_count = 100;
  const size_t large_count = 1000;

  // float: default tolerance for both sizes.
  TestReduce<float>(tiny_count);
  TestReduce<float>(large_count);

  // float16 will overflow or accumulate truncate errors in big size, so the
  // larger case runs with a relaxed absolute tolerance.
  const float relaxed_atol = 1.0f;
  TestReduce<float16>(tiny_count);
  TestReduce<float16>(medium_count, /*atol error*/ relaxed_atol);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录