Repository: magicwindyyd/mindspore (forked from MindSpore/mindspore)

Commit 343889cd (parent: c9b8a8da)
Author: chenjianping
Date: June 23, 2020

building _ms_mpi with mpi_interface

13 changed files with 161 additions and 125 deletions:
cmake/package.cmake                                                   +5    -0
mindspore/ccsrc/CMakeLists.txt                                        +3    -2
mindspore/ccsrc/device/CMakeLists.txt                                 +11   -6
mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc                         +72   -51
mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h                          +28   -12
mindspore/ccsrc/device/cpu/mpi/mpi_interface.cc                       +33   -0
mindspore/ccsrc/device/gpu/mpi/mpi_initializer.cc                     +0    -7
mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc                    +1    -1
mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc  +2    -2
mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc            +3    -3
mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc               +2    -2
mindspore/context.py                                                  +0    -38
mindspore/parallel/mpi/_mpi_config.py                                 +1    -1
cmake/package.cmake

@@ -128,6 +128,11 @@ if (ENABLE_MPI)
         DESTINATION ${INSTALL_BASE_DIR}
         COMPONENT mindspore
     )
+    install(
+        TARGETS mpi_adapter
+        DESTINATION ${INSTALL_LIB_DIR}
+        COMPONENT mindspore
+    )
 endif ()

 if (ENABLE_GPU)
mindspore/ccsrc/CMakeLists.txt

@@ -126,11 +126,12 @@ endforeach ()
 set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
 add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
 target_link_libraries(mindspore proto_input)
-if (ENABLE_CPU AND ENABLE_MPI)
-    target_link_libraries(mindspore securec mindspore::flatbuffers mindspore::ompi)
+if (ENABLE_MPI)
+    target_link_libraries(mindspore securec mindspore::flatbuffers mpi_adapter)
 else ()
     target_link_libraries(mindspore securec mindspore::flatbuffers)
 endif ()
 if (NOT WIN32)
     target_link_libraries(mindspore dl)
 endif ()
mindspore/ccsrc/device/CMakeLists.txt

@@ -14,17 +14,22 @@ endif ()
 if (ENABLE_CPU)
     file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
-    if (NOT ENABLE_MPI)
-        list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
-    endif ()
+    list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
+    list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_interface.cc")
 endif ()

 if (ENABLE_MPI)
     # _ms_mpi
-    set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
+    file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc")
+    set_property(SOURCE ${MPI_SRC_LIST}
             PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
+    add_library(mpi_adapter SHARED ${MPI_SRC_LIST})
+    target_link_libraries(mpi_adapter PRIVATE mindspore::ompi)
+
+    set_property(SOURCE "cpu/mpi/mpi_interface.cc"
+            PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
-    pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
-    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
+    pybind11_add_module(_ms_mpi "cpu/mpi/mpi_interface.cc")
+    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mpi_adapter)
 endif ()

 # gpu
mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc

@@ -15,13 +15,41 @@
  */
 #include "device/cpu/mpi/mpi_adapter.h"
+#ifdef ENABLE_MPI
 #include <algorithm>
-#include "utils/mpi/mpi_config.h"
+#include <sstream>
+#include "pybind11/pybind11.h"
+#endif  // ENABLE_MPI
 #include "utils/log_adapter.h"

 namespace mindspore {
 namespace device {
 namespace cpu {
+std::shared_ptr<MPIAdapter> MPIAdapter::instance_ = nullptr;
+
+std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
+  if (instance_ == nullptr) {
+    MS_LOG(DEBUG) << "Create new mpi adapter instance.";
+    instance_.reset(new (std::nothrow) MPIAdapter());
+  }
+  return instance_;
+}
+
+#ifdef ENABLE_MPI
+#define RAISE_EXCEPTION(message)                                             \
+  {                                                                          \
+    std::ostringstream oss;                                                  \
+    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message;          \
+    pybind11::pybind11_fail(oss.str());                                      \
+  }
+
+#define RAISE_EXCEPTION_WITH_PARAM(message, param)                           \
+  {                                                                          \
+    std::ostringstream oss;                                                  \
+    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \
+    pybind11::pybind11_fail(oss.str());                                      \
+  }
+
 namespace {
 MPI_Op GetMpiOp(const std::string &op_type) {
   if (op_type == "sum") {

@@ -33,7 +61,8 @@ MPI_Op GetMpiOp(const std::string &op_type) {
   } else if (op_type == "prod") {
     return MPI_PROD;
   }
-  MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
+  RAISE_EXCEPTION_WITH_PARAM("unsupport op_type: ", op_type);
   return MPI_SUM;
 }

@@ -46,80 +75,72 @@ int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
     }
   }
   if (scatter_index == -1) {
-    MS_LOG(EXCEPTION) << "process rankid " << rankid << " does not in the input rank group!";
+    RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rankid);
   }
   return scatter_index;
 }
 }  // namespace

-MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
+MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); }

 MPIAdapter::~MPIAdapter() {
+  int finalized;
+  MPI_Finalized(&finalized);
+  if (finalized != 0) {
+    return;
+  }
   for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) {
     MPI_Group_free(&iter->second);
   }
   ranks_group_.clear();
   if (comm_group_world_ != MPI_GROUP_NULL) {
     MPI_Group_free(&comm_group_world_);
     comm_group_world_ = MPI_GROUP_NULL;
   }
-  int finalized;
-  MPI_Finalized(&finalized);
-  if (finalized == 0) {
-    MPI_Finalize();
-  }
+  MPI_Finalize();
 }

-MPIAdapter &MPIAdapter::Instance() {
-  static MPIAdapter instance;
-  return instance;
-}
-
-int MPIAdapter::GetRankId() const { return rank_id_; }
-
 void MPIAdapter::Init() {
   static bool init = false;
   if (init) {
     return;
   }
-  auto mpi_config_ptr = MpiConfig::GetInstance();
-  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
-  if (!mpi_config_ptr->enable_mpi()) {
-    MS_LOG(EXCEPTION) << "MPI is disabled now!Please enable mpi with mpi config first.";
-  }
   int init_flag = 0;
   if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
+    RAISE_EXCEPTION("Check mpi initialized fail!");
   }
   if (init_flag == 0) {
     auto ret = MPI_Init(nullptr, nullptr);
     if (ret != MPI_SUCCESS) {
-      MS_LOG(EXCEPTION) << "Failed to init mpi!";
+      RAISE_EXCEPTION("Failed to init mpi!");
     }
   }

   MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_);
   if (comm_group_world_ == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "comm_group_world_ init fail!";
+    RAISE_EXCEPTION("comm_group_world_ init fail!");
   }
   auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Failed to init mpi rank id!";
+    RAISE_EXCEPTION("Failed to init mpi rank id!");
   }
   ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Failed to init mpi rank size!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_)
   }
   init = true;
 }

 MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   if (ranks.size() > static_cast<size_t>(rank_size_) || ranks.empty()) {
-    MS_LOG(EXCEPTION) << "input rank size: " << ranks.size() << ", max rank size: " << rank_size_;
+    RAISE_EXCEPTION_WITH_PARAM("input rank size:", ranks.size());
   }
   if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) {
-    MS_LOG(ERROR) << "rankid:" << rank_id_ << " is not in the input group.";
-    return MPI_GROUP_NULL;
+    RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rank_id_);
   }
   std::lock_guard<std::mutex> lock(group_mutex_);
   auto iter = ranks_group_.find(ranks);

@@ -135,29 +156,28 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   MPI_Group group = MPI_GROUP_NULL;
   MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_)
   }
   ranks_group_[ranks] = group;
   MS_LOG(INFO) << "rank:" << rank_id_ << " add group:" << group;
   return group;
 }

 bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
                                size_t data_num, const std::string &op_type) {
   if (ranks_group.empty()) {
-    MS_LOG(ERROR) << "input rank group is empty!";
+    RAISE_EXCEPTION("input rank group is empty!");
     return false;
   }
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_)
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
   }
   std::vector<int> receive_count(ranks_group.size(), 0);
   for (size_t i = 0; i < ranks_group.size(); ++i) {

@@ -168,13 +188,13 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec
   auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm);
   bool result = true;
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi reduce_scatter fail!ret = " << ret << ", rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret);
     result = false;
   }
   ret = MPI_Comm_free(&comm);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(WARNING) << "mpi comm free fail! ret = " << ret << ", rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret);
   }
   return result;
 }

@@ -184,19 +204,18 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
   int scatter_index = GetScatterIndex(rank_id_, ranks_group);
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_);
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
   }
   MPI_Win window;
   auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
-    return false;
+    RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! ret = ", ret);
   }
   MPI_Win_fence(0, window);
   for (size_t i = 0; i < ranks_group.size(); ++i) {

@@ -208,18 +227,20 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
     ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num,
                          input_data_num, MPI_FLOAT, op, window);
     if (ret != MPI_SUCCESS) {
-      MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
+      RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret);
     }
   }
   MPI_Win_fence(0, window);
   if (output != nullptr) {
     auto data_size = input_data_num * sizeof(float);
     if (output_size < data_size) {
-      MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size;
+      std::ostringstream exception_msg;
+      exception_msg << "output buffer size " << output_size << " < input size " << data_size;
+      RAISE_EXCEPTION(exception_msg.str())
     }
     auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
     if (copy_ret != 0) {
-      MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret;
+      RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret);
     }
   }
   MPI_Win_free(&window);

@@ -229,31 +250,31 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
 bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
-    MS_LOG(ERROR) << "input rank group is empty!";
+    RAISE_EXCEPTION("input rank group is empty!");
     return false;
   }
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail! rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_);
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail! rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_);
   }
   auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm);
-  bool result = true;
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi allgater fail!ret = " << ret << ", rankid:" << rank_id_;
-    result = false;
+    RAISE_EXCEPTION_WITH_PARAM("mpi allgater fail!ret = ", ret);
   }
   ret = MPI_Comm_free(&comm);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(WARNING) << "mpi comm free fail!ret = " << ret << ",rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret);
   }
-  return result;
+  return true;
 }
+#endif  // ENABLE_MPI
 }  // namespace cpu
 }  // namespace device
 }  // namespace mindspore
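The recurring change in this file is replacing MS_LOG(EXCEPTION)/MS_LOG(ERROR) with the two RAISE_EXCEPTION macros, so failures inside the adapter propagate across the new shared-library boundary as exceptions that pybind11 can translate for Python callers. Below is a minimal sketch of what one macro invocation boils down to; the helper name RaiseToPython is hypothetical and only illustrates the pattern (in the real macros, __FILE__ and __LINE__ expand at the call site):

    #include <sstream>
    #include <string>
    #include "pybind11/pybind11.h"

    // Hypothetical illustration of the RAISE_EXCEPTION_WITH_PARAM body:
    // build a "[file] [line] message<param>" string and throw through pybind11.
    void RaiseToPython(const std::string &message, int param) {
      std::ostringstream oss;
      oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param;
      pybind11::pybind11_fail(oss.str());  // throws std::runtime_error
    }

Since pybind11 translates std::runtime_error at the binding boundary, a Python caller of _ms_mpi sees these failures as a RuntimeError rather than a process-level log abort.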
mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h

@@ -22,37 +22,53 @@
 #include <map>
 #include <string>
 #include <mutex>
+#endif  // ENABLE_MPI
+#include <memory>

 namespace mindspore {
 namespace device {
 namespace cpu {
+#ifndef FUNC_EXPORT
+#define FUNC_EXPORT __attribute__((visibility("default")))
+#endif
+
 constexpr auto kOpTypeSum = "sum";
 class MPIAdapter {
  public:
-  ~MPIAdapter();
-  static MPIAdapter &Instance();
-  int GetRankId() const;
-  bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
-                     const std::string &op_type = kOpTypeSum);
-  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
-                                   size_t output_size, const std::string &op_type = kOpTypeSum,
-                                   float *output = nullptr);
-  bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
+  FUNC_EXPORT static std::shared_ptr<MPIAdapter> Instance();
+  FUNC_EXPORT int GetRankId() const { return rank_id_; }
+  FUNC_EXPORT int GetRankSize() const { return rank_size_; }
+#ifdef ENABLE_MPI
+  FUNC_EXPORT ~MPIAdapter();
+  FUNC_EXPORT bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
+                                 size_t data_num, const std::string &op_type = kOpTypeSum);
+  FUNC_EXPORT bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
+                                               size_t output_size, const std::string &op_type = kOpTypeSum,
+                                               float *output = nullptr);
+  FUNC_EXPORT bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
+#else
+  FUNC_EXPORT ~MPIAdapter() = default;
+#endif  // ENABLE_MPI

  private:
+#ifdef ENABLE_MPI
   MPIAdapter();
   void Init();
   MPI_Group AddGroup(const std::vector<int> &ranks);
-  int rank_id_;
-  int rank_size_;
   MPI_Group comm_group_world_;
   // key:ranks group, value: mpi group
   std::map<std::vector<int>, MPI_Group> ranks_group_;
   std::mutex group_mutex_;
+#else
+  MPIAdapter() = default;
+#endif  // ENABLE_MPI
+  int rank_id_{-1};
+  int rank_size_{0};
+  static std::shared_ptr<MPIAdapter> instance_;
 };
 }  // namespace cpu
 }  // namespace device
 }  // namespace mindspore
-#endif  // ENABLE_MPI
 #endif  // MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_
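Because Instance() now returns std::shared_ptr<MPIAdapter>, call sites move from Instance(). to Instance()->, as the kernel diffs below show. A minimal caller sketch under that assumption (AllGatherExample and its fixed four-rank group are illustrative only, not part of the commit, and require an ENABLE_MPI build):

    #include <vector>
    #include "device/cpu/mpi/mpi_adapter.h"

    // Illustrative only: all-gather data_num floats per rank across a hypothetical group {0, 1, 2, 3}.
    bool AllGatherExample(const float *input, float *output, size_t data_num) {
      auto adapter = mindspore::device::cpu::MPIAdapter::Instance();  // shared_ptr, created lazily
      const std::vector<int> ranks_group = {0, 1, 2, 3};
      // AllGather is only declared when ENABLE_MPI is defined (see the #ifdef block above).
      return adapter->AllGather(input, output, ranks_group, data_num);
    }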
mindspore/ccsrc/device/cpu/mpi/mpi_interface.cc (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <pybind11/operators.h>
#include "device/cpu/mpi/mpi_adapter.h"

namespace mindspore {
namespace device {
namespace cpu {
int get_rank_id() { return MPIAdapter::Instance()->GetRankId(); }

int get_rank_size() { return MPIAdapter::Instance()->GetRankSize(); }

PYBIND11_MODULE(_ms_mpi, mpi_interface) {
  mpi_interface.doc() = "mindspore mpi python wrapper";
  mpi_interface.def("get_rank_id", &get_rank_id, "get rank id");
  mpi_interface.def("get_rank_size", &get_rank_size, "get rank size");
}
}  // namespace cpu
}  // namespace device
}  // namespace mindspore
mindspore/ccsrc/device/gpu/mpi/mpi_initializer.cc

@@ -17,7 +17,6 @@
 #include "device/gpu/mpi/mpi_initializer.h"
 #include <mpi.h>
-#include <pybind11/operators.h>
 #include <iostream>

 namespace mindspore {

@@ -54,12 +53,6 @@ MPIInitializer &MPIInitializer::GetInstance() {
 int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; }

 int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; }
-
-PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
-  mpi_initializer.doc() = "mindspore mpi python wrapper";
-  mpi_initializer.def("get_rank_id", &MPIInitializer::get_rank_id, "get rank id");
-  mpi_initializer.def("get_rank_size", &MPIInitializer::get_rank_size, "get rank size");
-}
 }  // namespace gpu
 }  // namespace device
 }  // namespace mindspore
mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc

@@ -47,7 +47,7 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
   auto input_data_num = inputs[0]->size / sizeof(float);

-  return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num);
+  return device::cpu::MPIAdapter::Instance()->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
 }
 }  // namespace kernel
 }  // namespace mindspore
mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc

@@ -51,8 +51,8 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
   size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
   size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
   for (int i = 0; i < split_num_; i++) {
-    device::cpu::MPIAdapter::Instance().AllGather(input_addr + i * input_split_lens,
-                                                  output_addr + i * output_split_lens, rank_group, input_split_lens);
+    device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens,
+                                                   output_addr + i * output_split_lens, rank_group, input_split_lens);
   }
 #if defined(_WIN32) || defined(_WIN64)
   auto end_time = std::chrono::steady_clock::now();
mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc

@@ -105,9 +105,9 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
     size_t reduce_scatter_out_lens = one_split_lens / 8;
     const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
     for (int i = 0; i < split_num_; i++) {
-      device::cpu::MPIAdapter::Instance().ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
-                                                        output_addr + i * reduce_scatter_out_lens, group,
-                                                        one_split_lens / 8, "sum");
+      device::cpu::MPIAdapter::Instance()->ReduceScatter(
+        reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
+        output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum");
     }
   }
 #endif
mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc

@@ -47,8 +47,8 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
   auto output_data_num = outputs[0]->size / sizeof(float);

-  return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
-                                                           op_type_);
+  return device::cpu::MPIAdapter::Instance()->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
+                                                            op_type_);
 }
 }  // namespace kernel
 }  // namespace mindspore
mindspore/context.py

@@ -25,7 +25,6 @@ from mindspore._c_expression import MSContext
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
-from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config

 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
            'get_auto_parallel_context', 'reset_auto_parallel_context']

@@ -608,40 +607,3 @@ def get_context(attr_key):
         raise ValueError("Get context keyword %s is not recognized!" % attr_key)
     return getattr(_context(), attr_key)
-
-
-@args_type_check(enable_mpi=bool)
-def set_mpi_config(**kwargs):
-    """
-    Sets mpi config for running environment.
-
-    mpi config should be configured before running your program. If there is no configuration,
-    mpi moudle will be disabled by default.
-
-    Note:
-        Attribute name is required for setting attributes.
-
-    Args:
-        enable_mpi (bool): Whether to enable mpi. Default: False.
-
-    Raises:
-        ValueError: If input key is not an attribute in mpi config.
-
-    Examples:
-        >>> mpiconfig.set_mpi_config(enable_mpi=True)
-    """
-    _set_mpi_config(**kwargs)
-
-
-def get_mpi_config(attr_key):
-    """
-    Gets mpi config attribute value according to the input key.
-
-    Args:
-        attr_key (str): The key of the attribute.
-
-    Returns:
-        Object, The value of given attribute key.
-
-    Raises:
-        ValueError: If input key is not an attribute in context.
-    """
-    return _get_mpi_config(attr_key)
mindspore/parallel/mpi/_mpi_config.py

@@ -104,7 +104,7 @@ def _get_mpi_config(attr_key):
         Object, The value of given attribute key.

     Raises:
-        ValueError: If input key is not an attribute in context.
+        ValueError: If input key is not an attribute in config.
     """
     if not hasattr(_mpi_config(), attr_key):
         raise ValueError("Get context keyword %s is not recognized!" % attr_key)