Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
57a82568
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
57a82568
编写于
9月 25, 2018
作者:
J
Juncheng
提交者:
Jinhui Yuan
9月 25, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cmake for nccl (#1262)
上级
4b2c4ef0
变更
12
隐藏空白更改
内联
并排
Showing 12 changed files with 60 additions and 5 deletions
+60
-5
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/third_party.cmake
cmake/third_party.cmake
+8
-1
oneflow/core/actor/nccl_actor.cpp
oneflow/core/actor/nccl_actor.cpp
+4
-0
oneflow/core/device/cuda_util.h
oneflow/core/device/cuda_util.h
+0
-1
oneflow/core/device/device_context.h
oneflow/core/device/device_context.h
+6
-0
oneflow/core/device/nccl_device_context.h
oneflow/core/device/nccl_device_context.h
+2
-2
oneflow/core/device/nccl_util.cpp
oneflow/core/device/nccl_util.cpp
+2
-0
oneflow/core/device/nccl_util.h
oneflow/core/device/nccl_util.h
+21
-1
oneflow/core/job/job_desc.cpp
oneflow/core/job/job_desc.cpp
+3
-0
oneflow/core/job/nccl_comm_manager.cpp
oneflow/core/job/nccl_comm_manager.cpp
+4
-0
oneflow/core/job/nccl_comm_manager.h
oneflow/core/job/nccl_comm_manager.h
+5
-0
oneflow/core/job/runtime.cpp
oneflow/core/job/runtime.cpp
+4
-0
未找到文件。
CMakeLists.txt
浏览文件 @
57a82568
...
...
@@ -4,6 +4,7 @@ cmake_minimum_required(VERSION 3.5)
# Build-time switches. Each is a boolean cache option; the value after the
# (possibly empty) help string is its default.
option(BUILD_THIRD_PARTY "Build third party or oneflow" OFF)
option(BUILD_RDMA "" ON)
option(BUILD_CUDA "" ON)
option(BUILD_NCCL "" ON)
option(RELEASE_VERSION "" OFF)
# Project
...
...
cmake/third_party.cmake
浏览文件 @
57a82568
...
...
@@ -31,7 +31,14 @@ if (BUILD_CUDA)
list
(
APPEND CUDA_LIBRARIES
${
cuda_lib_dir
}
/
${
extra_cuda_lib
}
)
endforeach
()
find_package
(
CuDNN REQUIRED
)
find_package
(
NCCL REQUIRED
)
# NCCL support is opt-in via BUILD_NCCL. When enabled, NCCL must be found and
# must be at least version 2.0; only then is WITH_NCCL defined for the sources.
if(BUILD_NCCL)
  find_package(NCCL REQUIRED)
  if(NCCL_VERSION VERSION_LESS 2.0)
    message(FATAL_ERROR "minimum nccl version required is 2.0")
  else()
    add_definitions(-DWITH_NCCL)
  endif()
endif()
endif
()
if
(
NOT WIN32
)
...
...
oneflow/core/actor/nccl_actor.cpp
浏览文件 @
57a82568
...
...
@@ -6,10 +6,14 @@
namespace
oneflow
{
// Installs this actor's device context.
//
// With WITH_NCCL defined: checks that the actor runs on a GPU device, then
// replaces the device context with an NcclDeviceCtx built from the thread's
// CUDA stream handle and the NCCL communicator that NcclCommMgr associates
// with this actor's id.
//
// Without WITH_NCCL: the function is UNIMPLEMENTED().
void NcclActor::InitDeviceCtx(const ThreadCtx& thread_ctx) {
#ifdef WITH_NCCL
  CHECK_EQ(GetDeviceType(), DeviceType::kGPU);
  // CHECK_EQ(GetLocalWorkStreamId(), 0);
  mut_device_ctx().reset(
      new NcclDeviceCtx(thread_ctx.g_cuda_stream.get(),
                        Global<NcclCommMgr>::Get()->NcclComm4ActorId(actor_id())));
#else
  UNIMPLEMENTED();
#endif  // WITH_NCCL
}
REGISTER_ACTOR
(
TaskType
::
kNcclAllReduce
,
NcclActor
);
...
...
oneflow/core/device/cuda_util.h
浏览文件 @
57a82568
...
...
@@ -11,7 +11,6 @@
#include <cuda_runtime.h>
#include <cudnn.h>
#include <curand.h>
#include <nccl.h>
namespace
oneflow
{
...
...
oneflow/core/device/device_context.h
浏览文件 @
57a82568
...
...
@@ -3,6 +3,10 @@
#include "oneflow/core/device/cuda_util.h"
#ifdef WITH_NCCL
#include <nccl.h>
#endif // WITH_NCCL
namespace
oneflow
{
class
DeviceCtx
{
...
...
@@ -15,7 +19,9 @@ class DeviceCtx {
virtual
const
cublasHandle_t
&
cublas_pmh_handle
()
const
{
UNIMPLEMENTED
();
}
virtual
const
cublasHandle_t
&
cublas_pmd_handle
()
const
{
UNIMPLEMENTED
();
}
virtual
const
cudnnHandle_t
&
cudnn_handle
()
const
{
UNIMPLEMENTED
();
}
#ifdef WITH_NCCL
virtual
const
ncclComm_t
&
nccl_handle
()
const
{
UNIMPLEMENTED
();
}
#endif // WITH_NCCL
#endif
virtual
void
AddCallBack
(
std
::
function
<
void
()
>
)
const
=
0
;
...
...
oneflow/core/device/nccl_device_context.h
浏览文件 @
57a82568
...
...
@@ -6,7 +6,7 @@
namespace
oneflow
{
#ifdef WITH_
CUDA
#ifdef WITH_
NCCL
class
NcclDeviceCtx
final
:
public
CudaDeviceCtx
{
public:
...
...
@@ -21,7 +21,7 @@ class NcclDeviceCtx final : public CudaDeviceCtx {
ncclComm_t
nccl_handler_
;
};
#endif // WITH_
CUDA
#endif // WITH_
NCCL
}
// namespace oneflow
...
...
oneflow/core/device/nccl_util.cpp
浏览文件 @
57a82568
...
...
@@ -2,6 +2,8 @@
namespace
oneflow
{
#ifdef WITH_NCCL
void
NcclCheck
(
ncclResult_t
error
)
{
CHECK_EQ
(
error
,
ncclSuccess
)
<<
ncclGetErrorString
(
error
);
}
#endif // WITH_NCCL
}
// namespace oneflow
oneflow/core/device/nccl_util.h
浏览文件 @
57a82568
#ifndef ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_
#define ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_
#include <nccl.h>
#include "oneflow/core/register/blob.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/util.h"
#ifdef WITH_NCCL
#include <nccl.h>
#endif // WITH_NCCL
namespace
oneflow
{
#ifdef WITH_NCCL
inline
ncclDataType_t
GetNcclDataType
(
const
DataType
&
dt
)
{
switch
(
dt
)
{
#define NCCL_DATA_TYPE_CASE(dtype) \
...
...
@@ -22,29 +26,45 @@ inline ncclDataType_t GetNcclDataType(const DataType& dt) {
}
void
NcclCheck
(
ncclResult_t
error
);
#endif // WITH_NCCL
// Static helpers that launch NCCL collectives on a pair of blobs.
//
// Every method uses the communicator and CUDA stream exposed by `ctx`
// (ctx->nccl_handle(), ctx->cuda_stream()) and takes the NCCL element type
// from send_blob's data type. When the build does not define WITH_NCCL each
// method is UNIMPLEMENTED().
class NcclUtil final {
 public:
  // Common signature of the collective helpers below.
  using NcclReduceMthd = void(DeviceCtx*, Blob*, Blob*);

  // Sum-all-reduce: the element count is taken from send_blob's shape.
  static void AllReduce(DeviceCtx* ctx, Blob* send_blob, Blob* recv_blob) {
#ifdef WITH_NCCL
    const size_t count = static_cast<size_t>(send_blob->shape().elem_cnt());
    const ncclDataType_t nccl_dtype = GetNcclDataType(send_blob->data_type());
    NcclCheck(ncclAllReduce(send_blob->dptr(), recv_blob->mut_dptr(), count, nccl_dtype, ncclSum,
                            ctx->nccl_handle(), ctx->cuda_stream()));
#else
    UNIMPLEMENTED();
#endif  // WITH_NCCL
  }

  // Sum-reduce-scatter: the element count is taken from recv_blob's shape,
  // i.e. it is the receive count passed to ncclReduceScatter.
  static void ReduceScatter(DeviceCtx* ctx, Blob* send_blob, Blob* recv_blob) {
#ifdef WITH_NCCL
    const size_t count = static_cast<size_t>(recv_blob->shape().elem_cnt());
    const ncclDataType_t nccl_dtype = GetNcclDataType(send_blob->data_type());
    NcclCheck(ncclReduceScatter(send_blob->dptr(), recv_blob->mut_dptr(), count, nccl_dtype,
                                ncclSum, ctx->nccl_handle(), ctx->cuda_stream()));
#else
    UNIMPLEMENTED();
#endif  // WITH_NCCL
  }

  // All-gather: the element count is taken from send_blob's shape, i.e. it is
  // the send count passed to ncclAllGather.
  static void AllGather(DeviceCtx* ctx, Blob* send_blob, Blob* recv_blob) {
#ifdef WITH_NCCL
    const size_t count = static_cast<size_t>(send_blob->shape().elem_cnt());
    const ncclDataType_t nccl_dtype = GetNcclDataType(send_blob->data_type());
    NcclCheck(ncclAllGather(send_blob->dptr(), recv_blob->mut_dptr(), count, nccl_dtype,
                            ctx->nccl_handle(), ctx->cuda_stream()));
#else
    UNIMPLEMENTED();
#endif  // WITH_NCCL
  }
};
...
...
oneflow/core/job/job_desc.cpp
浏览文件 @
57a82568
...
...
@@ -113,6 +113,9 @@ JobDesc::JobDesc(const std::string& job_conf_filepath) {
#ifndef WITH_RDMA
CHECK_EQ
(
job_conf_
.
other
().
use_rdma
(),
false
)
<<
"Please compile ONEFLOW with RDMA"
;
#endif
#ifndef WITH_NCCL
CHECK_EQ
(
job_conf_
.
other
().
enable_nccl
(),
false
)
<<
"Please compile ONEFLOW with NCCL"
;
#endif // WITH_NCCL
int64_t
piece_exp
=
job_conf_
.
other
().
piece_num_of_experiment_phase
();
if
(
job_conf_
.
other
().
has_train_conf
())
{
TrainConf
*
train_conf
=
job_conf_
.
mutable_other
()
->
mutable_train_conf
();
...
...
oneflow/core/job/nccl_comm_manager.cpp
浏览文件 @
57a82568
...
...
@@ -5,6 +5,8 @@
#include "oneflow/core/device/nccl_util.h"
#include "nccl_comm_manager.h"
#ifdef WITH_NCCL
namespace
oneflow
{
NcclCommMgr
::
NcclCommMgr
(
const
Plan
&
plan
)
{
...
...
@@ -108,3 +110,5 @@ void NcclCommMgr::NcclGetUniqueId4Tasks(const std::vector<TaskProto>& tasks,
}
}
// namespace oneflow
#endif // WITH_NCCL
oneflow/core/job/nccl_comm_manager.h
浏览文件 @
57a82568
...
...
@@ -3,6 +3,9 @@
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/plan.pb.h"
#ifdef WITH_NCCL
#include <nccl.h>
namespace
oneflow
{
...
...
@@ -29,4 +32,6 @@ class NcclCommMgr final {
}
// namespace oneflow
#endif // WITH_NCCL
#endif // ONEFLOW_CORE_JOB_NCCL_COMM_MANAGER_H_
oneflow/core/job/runtime.cpp
浏览文件 @
57a82568
...
...
@@ -112,11 +112,15 @@ void Runtime::NewAllGlobal(const Plan& plan, bool is_experiment_phase) {
Global
<
RegstMgr
>::
New
(
plan
);
Global
<
ActorMsgBus
>::
New
();
Global
<
ThreadMgr
>::
New
(
plan
);
#ifdef WITH_NCCL
Global
<
NcclCommMgr
>::
New
(
plan
);
#endif // WITH_NCCL
}
void
Runtime
::
DeleteAllGlobal
()
{
#ifdef WITH_NCCL
Global
<
NcclCommMgr
>::
Delete
();
#endif // WITH_NCCL
Global
<
ThreadMgr
>::
Delete
();
Global
<
ActorMsgBus
>::
Delete
();
Global
<
RegstMgr
>::
Delete
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录