Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
597345d1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
597345d1
编写于
9月 25, 2020
作者:
Z
Zhong Hui
提交者:
GitHub
9月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix cuda atomic for ARCH<350 for the automic_max
fix cuda atomic for ARCH<350 for the automic_max
上级
dd04b160
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
38 addition
and
0 deletion
+38
-0
paddle/fluid/platform/cuda_primitives.h
paddle/fluid/platform/cuda_primitives.h
+38
-0
未找到文件。
paddle/fluid/platform/cuda_primitives.h
浏览文件 @
597345d1
...
@@ -134,7 +134,26 @@ USE_CUDA_ATOMIC(Max, int);
...
@@ -134,7 +134,26 @@ USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC
(
Max
,
unsigned
int
);
USE_CUDA_ATOMIC
(
Max
,
unsigned
int
);
// CUDA API uses unsigned long long int, we cannot use uint64_t here.
// CUDA API uses unsigned long long int, we cannot use uint64_t here.
// It because unsigned long long int is not necessarily uint64_t
// It because unsigned long long int is not necessarily uint64_t
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC
(
Max
,
unsigned
long
long
int
);
// NOLINT
USE_CUDA_ATOMIC
(
Max
,
unsigned
long
long
int
);
// NOLINT
#else
CUDA_ATOMIC_WRAPPER
(
Max
,
unsigned
long
long
int
)
{
if
(
*
address
>=
val
)
{
return
;
}
unsigned
long
long
int
old
=
*
address
,
assumed
;
do
{
assumed
=
old
;
if
(
assumed
>=
val
)
{
break
;
}
old
=
atomicCAS
(
address
,
assumed
,
val
);
}
while
(
assumed
!=
old
);
}
#endif
CUDA_ATOMIC_WRAPPER
(
Max
,
int64_t
)
{
CUDA_ATOMIC_WRAPPER
(
Max
,
int64_t
)
{
// Here, we check long long int must be int64_t.
// Here, we check long long int must be int64_t.
...
@@ -187,7 +206,26 @@ USE_CUDA_ATOMIC(Min, int);
...
@@ -187,7 +206,26 @@ USE_CUDA_ATOMIC(Min, int);
USE_CUDA_ATOMIC
(
Min
,
unsigned
int
);
USE_CUDA_ATOMIC
(
Min
,
unsigned
int
);
// CUDA API uses unsigned long long int, we cannot use uint64_t here.
// CUDA API uses unsigned long long int, we cannot use uint64_t here.
// It because unsigned long long int is not necessarily uint64_t
// It because unsigned long long int is not necessarily uint64_t
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC
(
Min
,
unsigned
long
long
int
);
// NOLINT
USE_CUDA_ATOMIC
(
Min
,
unsigned
long
long
int
);
// NOLINT
#else
CUDA_ATOMIC_WRAPPER
(
Min
,
unsigned
long
long
int
)
{
if
(
*
address
<=
val
)
{
return
;
}
unsigned
long
long
int
old
=
*
address
,
assumed
;
do
{
assumed
=
old
;
if
(
assumed
<=
val
)
{
break
;
}
old
=
atomicCAS
(
address
,
assumed
,
val
);
}
while
(
assumed
!=
old
);
}
#endif
CUDA_ATOMIC_WRAPPER
(
Min
,
int64_t
)
{
CUDA_ATOMIC_WRAPPER
(
Min
,
int64_t
)
{
// Here, we check long long int must be int64_t.
// Here, we check long long int must be int64_t.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录