Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
033ebe7e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
033ebe7e
编写于
12月 09, 2021
作者:
S
sneaxiy
提交者:
GitHub
12月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine CUDA atomicAdd for FP16 by CUDA primitive methods (#37895)
* fix cuda atomicAdd for FP16 * try to fix ci
上级
491d4f01
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
15 addition
and
0 deletion
+15
-0
paddle/fluid/platform/device/gpu/gpu_primitives.h
paddle/fluid/platform/device/gpu/gpu_primitives.h
+15
-0
未找到文件。
paddle/fluid/platform/device/gpu/gpu_primitives.h
浏览文件 @
033ebe7e
...
@@ -101,6 +101,20 @@ inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
...
@@ -101,6 +101,20 @@ inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
}
}
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
static
__device__
__forceinline__
float16
CUDAFP16ToPDFP16
(
__half
x
)
{
return
*
reinterpret_cast
<
float16
*>
(
&
x
);
}
static
__device__
__forceinline__
__half
PDFP16ToCUDAFP16
(
float16
x
)
{
return
*
reinterpret_cast
<
__half
*>
(
&
x
);
}
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
return
CUDAFP16ToPDFP16
(
atomicAdd
(
reinterpret_cast
<
__half
*>
(
address
),
PDFP16ToCUDAFP16
(
val
)));
}
#else
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
// concrete packed float16 value may exsits in lower or higher 16bits
// concrete packed float16 value may exsits in lower or higher 16bits
// of the 32bits address.
// of the 32bits address.
...
@@ -133,6 +147,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
...
@@ -133,6 +147,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
}
}
}
}
#endif
#endif
#endif
CUDA_ATOMIC_WRAPPER
(
Add
,
complex
<
float
>
)
{
CUDA_ATOMIC_WRAPPER
(
Add
,
complex
<
float
>
)
{
float
*
real
=
reinterpret_cast
<
float
*>
(
address
);
float
*
real
=
reinterpret_cast
<
float
*>
(
address
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录