Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
51fd48a2
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
51fd48a2
编写于
10月 25, 2021
作者:
J
Juncheng
提交者:
GitHub
10月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
CudaDriverGetPrimaryCtxActive (#6604)
上级
f7b8bb8a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
43 addition
and
0 deletion
+43
-0
oneflow/core/device/cuda_util.cpp
oneflow/core/device/cuda_util.cpp
+41
-0
oneflow/core/device/cuda_util.h
oneflow/core/device/cuda_util.h
+2
-0
未找到文件。
oneflow/core/device/cuda_util.cpp
浏览文件 @
51fd48a2
...
...
@@ -22,6 +22,12 @@ limitations under the License.
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/job/lazy_mode.h"
#ifdef WITH_CUDA
#include <cuda.h>
#endif // WITH_CUDA
namespace
oneflow
{
#ifdef WITH_CUDA
...
...
@@ -200,6 +206,41 @@ void InitCudaContextOnce(int device_id) {
});
}
cudaError_t
CudaDriverGetPrimaryCtxActive
(
int
dev
,
int
*
active
)
{
#if CUDA_VERSION >= 11030
CUdevice
cu_device
{};
{
CUresult
(
*
fnCuDeviceGet
)(
CUdevice
*
,
int
)
=
nullptr
;
cudaError_t
err
=
cudaGetDriverEntryPoint
(
"cuDeviceGet"
,
(
void
**
)
&
fnCuDeviceGet
,
cudaEnableDefault
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
CUresult
result
=
fnCuDeviceGet
(
&
cu_device
,
dev
);
if
(
result
==
CUDA_SUCCESS
)
{
// do nothing
}
else
if
(
result
==
CUresult
::
CUDA_ERROR_INVALID_DEVICE
)
{
return
cudaErrorInvalidDevice
;
}
else
{
return
cudaErrorUnknown
;
}
}
{
CUresult
(
*
fnCuDevicePrimaryCtxGetState
)(
CUdevice
,
unsigned
int
*
,
int
*
)
=
nullptr
;
cudaError_t
err
=
cudaGetDriverEntryPoint
(
"cuDevicePrimaryCtxGetState"
,
(
void
**
)
&
fnCuDevicePrimaryCtxGetState
,
cudaEnableDefault
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
unsigned
int
flags
{};
CUresult
result
=
fnCuDevicePrimaryCtxGetState
(
cu_device
,
&
flags
,
active
);
if
(
result
==
CUDA_SUCCESS
)
{
return
cudaSuccess
;
}
else
{
return
cudaErrorUnknown
;
}
}
#else
return
cudaErrorNotSupported
;
#endif // CUDA_VERSION < 11030
}
#endif // WITH_CUDA
}
// namespace oneflow
oneflow/core/device/cuda_util.h
浏览文件 @
51fd48a2
...
...
@@ -168,6 +168,8 @@ int GetCudaDeviceCount();
void
InitCudaContextOnce
(
int
device_id
);
cudaError_t
CudaDriverGetPrimaryCtxActive
(
int
dev
,
int
*
active
);
}
// namespace oneflow
#endif // WITH_CUDA
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录