Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
9948ddfe
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
9948ddfe
编写于
6月 13, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cudnn
上级
d2649862
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
54 addition
and
6 deletion
+54
-6
cmake/third_party.cmake
cmake/third_party.cmake
+3
-0
oneflow/core/actor/copy_hd_actor.cpp
oneflow/core/actor/copy_hd_actor.cpp
+1
-1
oneflow/core/actor/fw_data_comp_actor.cpp
oneflow/core/actor/fw_data_comp_actor.cpp
+1
-1
oneflow/core/actor/model_update_comp_actor.cpp
oneflow/core/actor/model_update_comp_actor.cpp
+1
-1
oneflow/core/common/unique_cudnn_handle.h
oneflow/core/common/unique_cudnn_handle.h
+31
-0
oneflow/core/common/util.h
oneflow/core/common/util.h
+1
-0
oneflow/core/kernel/cuda_kernel_context.h
oneflow/core/kernel/cuda_kernel_context.h
+3
-1
oneflow/core/kernel/kernel_context.h
oneflow/core/kernel/kernel_context.h
+7
-1
oneflow/core/thread/gpu_thread.cpp
oneflow/core/thread/gpu_thread.cpp
+3
-0
oneflow/core/thread/thread_context.h
oneflow/core/thread/thread_context.h
+3
-1
未找到文件。
cmake/third_party.cmake
浏览文件 @
9948ddfe
...
...
@@ -11,6 +11,7 @@ include(grpc)
include
(
tensorflow
)
find_package
(
CUDA REQUIRED
)
find_package
(
CuDNN REQUIRED
)
set
(
oneflow_third_party_libs
${
tensorflow_STATIC_LIBRARIES
}
...
...
@@ -28,6 +29,7 @@ set(oneflow_third_party_libs
${
PNG_STATIC_LIBRARIES
}
${
JSONCPP_STATIC_LIBRARIES
}
${
CUDA_CUBLAS_LIBRARIES
}
${
CUDNN_LIBRARIES
}
)
if
(
WIN32
)
...
...
@@ -81,4 +83,5 @@ include_directories(
${
PNG_INCLUDE_DIR
}
${
JSONCPP_INCLUDE_DIR
}
${
EIGEN_INCLUDE_DIRS
}
${
CUDNN_INCLUDE_DIRS
}
)
oneflow/core/actor/copy_hd_actor.cpp
浏览文件 @
9948ddfe
...
...
@@ -7,7 +7,7 @@ namespace oneflow {
// need review
void
CopyHdActor
::
ProcessMsg
(
const
ActorMsg
&
msg
,
const
ThreadContext
&
thread_ctx
)
{
CudaKernelCtx
kernel_ctx
(
thread_ctx
.
copy_hd_cuda_stream
,
nullptr
);
CudaKernelCtx
kernel_ctx
(
thread_ctx
.
copy_hd_cuda_stream
,
nullptr
,
nullptr
);
ProcessMsgWithKernelCtx
(
msg
,
kernel_ctx
);
}
...
...
oneflow/core/actor/fw_data_comp_actor.cpp
浏览文件 @
9948ddfe
...
...
@@ -26,7 +26,7 @@ bool FwDataCompActor::IsReadReady() {
void
FwDataCompActor
::
ProcessMsg
(
const
ActorMsg
&
msg
,
const
ThreadContext
&
thread_ctx
)
{
CudaKernelCtx
kernel_ctx
(
thread_ctx
.
compute_cuda_stream
,
nullptr
);
CudaKernelCtx
kernel_ctx
(
thread_ctx
.
compute_cuda_stream
,
nullptr
,
nullptr
);
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
TODO
();
}
...
...
oneflow/core/actor/model_update_comp_actor.cpp
浏览文件 @
9948ddfe
...
...
@@ -13,7 +13,7 @@ void MdUpdtCompActor::Init(const TaskProto& task_proto) {
void
MdUpdtCompActor
::
ProcessMsg
(
const
ActorMsg
&
actor_msg
,
const
ThreadContext
&
thread_ctx
)
{
CudaKernelCtx
kernel_ctx
(
thread_ctx
.
compute_cuda_stream
,
nullptr
);
CudaKernelCtx
kernel_ctx
(
thread_ctx
.
compute_cuda_stream
,
nullptr
,
nullptr
);
(
this
->*
cur_handle_
)(
actor_msg
,
kernel_ctx
);
}
...
...
oneflow/core/common/unique_cudnn_handle.h
0 → 100644
浏览文件 @
9948ddfe
#ifndef ONEFLOW_CORE_UNIQUE_CUDNN_HANDLE_H_
#define ONEFLOW_CORE_UNIQUE_CUDNN_HANDLE_H_
#include "oneflow/core/common/util.h"
namespace
oneflow
{
// Owns a single cuDNN handle for the lifetime of the object.
//
// The handle is created in the constructor, attached to the CUDA stream the
// caller supplies, and destroyed in the destructor. The class cannot be
// default-constructed; OF_DISALLOW_COPY_AND_MOVE (defined elsewhere) is
// presumably what forbids copying and moving, so that exactly one object is
// responsible for destroying the handle -- confirm against the macro.
class UniqueCudnnHandle final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(UniqueCudnnHandle);
  UniqueCudnnHandle() = delete;

  // Creates a cuDNN handle and binds it to `*cuda_stream`.
  //
  // `cuda_stream` is dereferenced unconditionally, so it must not be null.
  // Both cuDNN calls are checked against CUDNN_STATUS_SUCCESS with CHECK_EQ.
  //
  // Marked `explicit` so that a `const cudaStream_t*` is never silently
  // converted into a handle-owning object (for example when passed to a
  // function that takes a UniqueCudnnHandle); construction must be spelled
  // out at the call site.
  explicit UniqueCudnnHandle(const cudaStream_t* cuda_stream) {
    CHECK_EQ(cudnnCreate(&handle_), CUDNN_STATUS_SUCCESS);
    CHECK_EQ(cudnnSetStream(handle_, *cuda_stream), CUDNN_STATUS_SUCCESS);
  }

  // Destroys the owned handle.
  // NOTE(review): the status returned by cudnnDestroy is not inspected here,
  // unlike the calls in the constructor -- confirm whether that is intended.
  ~UniqueCudnnHandle() { cudnnDestroy(handle_); }

  // Returns the address of the owned handle. The pointer refers to a member
  // of this object, so it is only usable while this object is alive.
  const cudnnHandle_t* get() const { return &handle_; }

 private:
  cudnnHandle_t handle_;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_UNIQUE_CUDNN_HANDLE_H_
oneflow/core/common/util.h
浏览文件 @
9948ddfe
...
...
@@ -12,6 +12,7 @@
#include "cuda.h"
#include "cuda_runtime.h"
#include "cublas_v2.h"
#include "cudnn.h"
namespace
oneflow
{
...
...
oneflow/core/kernel/cuda_kernel_context.h
浏览文件 @
9948ddfe
...
...
@@ -12,9 +12,11 @@ class CudaKernelCtx final : public KernelCtx {
~
CudaKernelCtx
()
=
default
;
CudaKernelCtx
(
const
cudaStream_t
*
cuda_stream
,
const
cublasHandle_t
*
cublas_handle
)
{
const
cublasHandle_t
*
cublas_handle
,
const
cudnnHandle_t
*
cudnn_handle
)
{
set_cuda_stream
(
cuda_stream
);
set_cublas_handle
(
cublas_handle
);
set_cudnn_handle
(
cudnn_handle
);
}
void
AddCallBack
(
std
::
function
<
void
()
>
callback
)
const
override
;
...
...
oneflow/core/kernel/kernel_context.h
浏览文件 @
9948ddfe
...
...
@@ -14,13 +14,15 @@ class KernelCtx {
Channel
<
std
::
function
<
void
()
>>*
cpu_stream
()
const
{
return
cpu_stream_
;
}
const
cudaStream_t
&
cuda_stream
()
const
{
return
*
cuda_stream_
;
}
const
cublasHandle_t
&
cublas_handle
()
const
{
return
*
cublas_handle_
;
}
const
cudnnHandle_t
&
cudnn_handle
()
const
{
return
*
cudnn_handle_
;
}
virtual
void
AddCallBack
(
std
::
function
<
void
()
>
)
const
=
0
;
protected:
KernelCtx
()
:
cpu_stream_
(
nullptr
),
cuda_stream_
(
nullptr
),
cublas_handle_
(
nullptr
)
{}
cublas_handle_
(
nullptr
),
cudnn_handle_
(
nullptr
)
{}
void
set_cpu_stream
(
Channel
<
std
::
function
<
void
()
>>*
val
)
{
cpu_stream_
=
val
;
...
...
@@ -31,11 +33,15 @@ class KernelCtx {
void
set_cublas_handle
(
const
cublasHandle_t
*
val
)
{
cublas_handle_
=
val
;
}
void
set_cudnn_handle
(
const
cudnnHandle_t
*
val
)
{
cudnn_handle_
=
val
;
}
private:
Channel
<
std
::
function
<
void
()
>>*
cpu_stream_
;
const
cudaStream_t
*
cuda_stream_
;
const
cublasHandle_t
*
cublas_handle_
;
const
cudnnHandle_t
*
cudnn_handle_
;
};
...
...
oneflow/core/thread/gpu_thread.cpp
浏览文件 @
9948ddfe
...
...
@@ -2,6 +2,7 @@
#include "cuda_runtime.h"
#include "oneflow/core/common/unique_cuda_stream.h"
#include "oneflow/core/common/unique_cublas_handle.h"
#include "oneflow/core/common/unique_cudnn_handle.h"
namespace
oneflow
{
...
...
@@ -12,10 +13,12 @@ GpuThread::GpuThread(int device_phy_id) {
UniqueCudaStream
compute_cuda_stream
;
{
UniqueCublasHandle
cublas_handle
(
compute_cuda_stream
.
get
());
UniqueCudnnHandle
cudnn_handle
(
compute_cuda_stream
.
get
());
ThreadContext
ctx
;
ctx
.
copy_hd_cuda_stream
=
copy_hd_cuda_stream
.
get
();
ctx
.
compute_cuda_stream
=
compute_cuda_stream
.
get
();
ctx
.
cublas_handle
=
cublas_handle
.
get
();
ctx
.
cudnn_handle
=
cudnn_handle
.
get
();
PollMsgChannel
(
ctx
);
}
});
...
...
oneflow/core/thread/thread_context.h
浏览文件 @
9948ddfe
...
...
@@ -9,12 +9,14 @@ struct ThreadContext {
ThreadContext
()
:
cpu_stream
(
nullptr
),
copy_hd_cuda_stream
(
nullptr
),
compute_cuda_stream
(
nullptr
),
cublas_handle
(
nullptr
)
{}
cublas_handle
(
nullptr
),
cudnn_handle
(
nullptr
)
{}
Channel
<
std
::
function
<
void
()
>>*
cpu_stream
;
const
cudaStream_t
*
copy_hd_cuda_stream
;
const
cudaStream_t
*
compute_cuda_stream
;
const
cublasHandle_t
*
cublas_handle
;
const
cudnnHandle_t
*
cudnn_handle
;
};
}
// namespace oneflow
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录