Repository: magicwindyyd / mindspore (fork of MindSpore / mindspore)
8593c4d8
编写于
6月 15, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 15, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2051 support host mpi
Merge pull request !2051 from chenjianping/host_reduce
上级
842f8231
6034f9c1
12 changed files with 355 additions and 21 deletions (+355 −21):

    build.sh                                                  +1    −1
    mindspore/ccsrc/device/CMakeLists.txt                     +9    −7
    mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc    +24   −7
    mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc             +66   −1
    mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h              +3    −1
    mindspore/ccsrc/pipeline/init.cc                          +6    −0
    mindspore/ccsrc/utils/mpi/mpi_config.cc                   +31   −0
    mindspore/ccsrc/utils/mpi/mpi_config.h                    +42   −0
    mindspore/context.py                                      +38   −0
    mindspore/parallel/mpi/__init__.py                        +14   −0
    mindspore/parallel/mpi/_mpi_config.py                     +111  −0
    tests/st/ops/cpu/test_reduce_scatter.py                   +10   −4
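Taken together, the change set adds a host-side (CPU) MPI path: a new MpiConfig singleton in C++ exposed to Python, context.set_mpi_config()/get_mpi_config() front-end functions, a window-based ReduceScatterOverwriteInput in MPIAdapter, and an updated system test. A minimal sketch of the intended flow, modelled on tests/st/ops/cpu/test_reduce_scatter.py (the reduce-scatter network itself lives in that test file, so this driver is illustrative only, not part of the commit):

    from mindspore import context
    import mindspore._ms_mpi as mpi  # rank helper built from gpu/mpi/mpi_initializer.cc

    # Host MPI must be switched on before any network runs; it is off by default.
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    context.set_mpi_config(enable_mpi=True)
    print("self rankid:", mpi.get_rank_id())

    # Launch under mpirun as the updated test comment describes:
    #   mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py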
build.sh
@@ -49,7 +49,7 @@ usage()
     echo "    -Q Enable dump memory, default off"
     echo "    -D Enable dumping of function graph ir, default on"
     echo "    -z Compile dataset & mindrecord, default on"
-    echo "    -M Enable MPI and NCCL for GPU training, default on"
+    echo "    -M Enable MPI and NCCL for GPU training, gpu default on"
     echo "    -V Specify the minimum required cuda version, default CUDA 9.2"
     echo "    -I Compile predict, default off"
     echo "    -K Compile with AKG, default off"
mindspore/ccsrc/device/CMakeLists.txt
@@ -14,17 +14,19 @@ endif ()
 if (ENABLE_CPU)
     file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
-    if (ENABLE_MPI)
-        # _ms_mpi
-        set_property(SOURCE "gpu/mpi/mpi_initializer.cc" PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
-        pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
-        target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
-    else ()
+    if (NOT ENABLE_MPI)
         list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
     endif ()
 endif ()

+if (ENABLE_MPI)
+    # _ms_mpi
+    set_property(SOURCE "gpu/mpi/mpi_initializer.cc" PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
+    pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
+    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
+endif ()
+
 # gpu
 if (ENABLE_GPU)
     file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu")
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
@@ -25,6 +25,7 @@
 #include "device/ascend/ascend_device_address.h"
 #include "device/cpu/mpi/mpi_adapter.h"
 #include "utils/context/ms_context.h"
+#include "utils/mpi/mpi_config.h"
 #include "device/ascend/profiling/profiling_manager.h"
 #include "hccl/hcom.h"
 #include "common/trans.h"
@@ -510,19 +511,35 @@ bool AscendKernelRuntime::HcclInit() {
     MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
     return false;
   }
+  const char *identify = nullptr;
 #ifdef ENABLE_MPI
-  int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
-  const char *offset = std::getenv("RANK_OFFSET");
-  if (offset != nullptr) {
-    int rank_offset = std::stoi(offset);
-    rank_id += rank_offset;
-  }
-  const char *identify = reinterpret_cast<const char *>(std::to_string(rank_id).c_str());
+  std::string rank_id_tmp;
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (mpi_config_ptr->enable_mpi()) {
+    int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
+    const char *offset = std::getenv("RANK_OFFSET");
+    if (offset != nullptr) {
+      try {
+        int rank_offset = std::stoi(offset);
+        rank_id += rank_offset;
+      } catch (std::invalid_argument) {
+        MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
+      } catch (std::out_of_range) {
+        MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
+      }
+    }
+    rank_id_tmp = std::to_string(rank_id);
+    identify = rank_id_tmp.c_str();
+  } else {
+    identify = std::getenv("RANK_ID");
+  }
 #else
-  const char *identify = std::getenv("RANK_ID");
+  identify = std::getenv("RANK_ID");
 #endif
   if (identify == nullptr) {
     MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
     free(full_path);
     return false;
   }
   MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;
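The reworked HcclInit only derives the HCCL rank identifier from MPI when MpiConfig has been enabled; otherwise it falls back to the RANK_ID environment variable as before, and RANK_OFFSET parsing is now guarded against bad input. A small Python illustration of that decision path (not part of the commit; names are descriptive only):

    import os
    from typing import Optional

    def resolve_hccl_rank(mpi_enabled: bool, mpi_rank: int) -> Optional[str]:
        """Mirrors the control flow added to AscendKernelRuntime::HcclInit."""
        if mpi_enabled:
            offset = os.getenv("RANK_OFFSET")
            if offset is not None:
                mpi_rank += int(offset)  # the C++ side wraps std::stoi in try/catch
            return str(mpi_rank)
        return os.getenv("RANK_ID")     # None triggers the "please set env RANK_ID" error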
mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc
@@ -16,6 +16,7 @@
 #include "device/cpu/mpi/mpi_adapter.h"
 #include <algorithm>
+#include "utils/mpi/mpi_config.h"
 #include "utils/log_adapter.h"

 namespace mindspore {
@@ -35,6 +36,20 @@ MPI_Op GetMpiOp(const std::string &op_type) {
   MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
   return MPI_SUM;
 }
+
+int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
+  int scatter_index = -1;
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    if (ranks_group[i] == rankid) {
+      scatter_index = static_cast<int>(i);
+      break;
+    }
+  }
+  if (scatter_index == -1) {
+    MS_LOG(EXCEPTION) << "process rankid " << rankid << " does not in the input rank group!";
+  }
+  return scatter_index;
+}
 }  // namespace

 MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
@@ -65,6 +80,11 @@ void MPIAdapter::Init() {
   if (init) {
     return;
   }
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (!mpi_config_ptr->enable_mpi()) {
+    MS_LOG(EXCEPTION) << "MPI is disabled now!Please enable mpi with mpi config first.";
+  }
   int init_flag = 0;
   if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
     MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
@@ -123,7 +143,7 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   return group;
 }

-bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
+bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                                const std::string &op_type) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
@@ -159,6 +179,51 @@ bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<in
   return result;
 }

+bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
+                                             const std::string &op_type, float *output) {
+  int scatter_index = GetScatterIndex(rank_id_, ranks_group);
+  auto group = AddGroup(ranks_group);
+  if (group == MPI_GROUP_NULL) {
+    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+  }
+  MPI_Comm comm;
+  MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
+  if (comm == MPI_COMM_NULL) {
+    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+  }
+  MPI_Win window;
+  auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
+  if (ret != MPI_SUCCESS) {
+    MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
+    return false;
+  }
+  MPI_Win_fence(0, window);
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    int remote_rank = ranks_group[i];
+    if (rank_id_ == remote_rank) {
+      continue;
+    }
+    auto op = GetMpiOp(op_type);
+    ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT,
+                         op, window);
+    if (ret != MPI_SUCCESS) {
+      MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
+    }
+  }
+  MPI_Win_fence(0, window);
+  if (output != nullptr) {
+    auto data_size = data_num * sizeof(float);
+    auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size);
+    if (copy_ret != 0) {
+      MS_LOG(EXCEPTION) << "copy output memory fail!";
+    }
+  }
+  MPI_Win_free(&window);
+  MPI_Comm_free(&comm);
+  return true;
+}
+
 bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
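ReduceScatterOverwriteInput realises the host reduce-scatter with one-sided MPI: every rank exposes its input buffer as an RMA window, accumulates segment i of its local buffer into rank i's window between two fences, and finally copies its own (now reduced) segment to output when a separate output buffer is supplied. A numpy sketch of the resulting semantics for the default sum op (illustrative only, not the MPI code path):

    import numpy as np

    def reduce_scatter_sum(inputs):
        """Each of the num_ranks inputs holds num_ranks * data_num floats; rank i
        ends up with the element-wise sum of segment i across all ranks."""
        num_ranks = len(inputs)
        data_num = inputs[0].size // num_ranks
        total = np.sum(np.stack(inputs), axis=0)
        return [total[i * data_num:(i + 1) * data_num] for i in range(num_ranks)]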
mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h
@@ -32,8 +32,10 @@ class MPIAdapter {
   ~MPIAdapter();
   static MPIAdapter &Instance();
   int GetRankId() const;
-  bool ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
+  bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                      const std::string &op_type = kOpTypeSum);
+  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
+                                   const std::string &op_type = kOpTypeSum, float *output = nullptr);
   bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);

 private:
mindspore/ccsrc/pipeline/init.cc
@@ -26,6 +26,7 @@
 #include "pipeline/parse/python_adapter.h"
 #include "utils/summary/event_writer.h"
 #include "utils/config_manager.h"
+#include "utils/mpi/mpi_config.h"
 #include "parallel/context.h"
 #include "parallel/device_manager.h"
 #include "parallel/costmodel_context.h"
@@ -147,6 +148,11 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
     .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.");

+  (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
+    .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
+    .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
+    .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.");
+
   (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
     .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
     .def("get_device_num", &ParallelContext::device_num, "Get device num.")
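The binding above is what makes the C++ singleton reachable from Python as mindspore._c_expression.MpiConfig; the private wrapper added later in this diff (mindspore/parallel/mpi/_mpi_config.py) drives it essentially like this short sketch:

    from mindspore._c_expression import MpiConfig

    handle = MpiConfig.get_instance()  # shared_ptr singleton, created lazily on first use
    handle.set_enable_mpi(True)
    assert handle.get_enable_mpi()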
mindspore/ccsrc/utils/mpi/mpi_config.cc (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "utils/mpi/mpi_config.h"

namespace mindspore {
std::shared_ptr<MpiConfig> MpiConfig::instance_ = nullptr;

std::shared_ptr<MpiConfig> MpiConfig::GetInstance() {
  if (instance_ == nullptr) {
    MS_LOG(DEBUG) << "Create new mpi config instance.";
    instance_.reset(new (std::nothrow) MpiConfig());
  }
  return instance_;
}
}  // namespace mindspore
mindspore/ccsrc/utils/mpi/mpi_config.h (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * (standard Apache 2.0 license header, as in mpi_config.cc above)
 */
#ifndef MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
#define MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
#include <memory>
#include "utils/log_adapter.h"

namespace mindspore {
class MpiConfig {
 public:
  ~MpiConfig() = default;
  MpiConfig(const MpiConfig &) = delete;
  MpiConfig &operator=(const MpiConfig &) = delete;

  static std::shared_ptr<MpiConfig> GetInstance();

  void set_enable_mpi(bool flag) { enable_mpi_ = flag; }
  bool enable_mpi() const { return enable_mpi_; }

 private:
  MpiConfig() : enable_mpi_(false) {}

  static std::shared_ptr<MpiConfig> instance_;
  bool enable_mpi_;
};
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
mindspore/context.py
@@ -25,6 +25,7 @@ from mindspore._c_expression import MSContext
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
+from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config

 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
            'get_auto_parallel_context', 'reset_auto_parallel_context']
@@ -566,3 +567,40 @@ def get_context(attr_key):
     if not hasattr(_context(), attr_key):
         raise ValueError("Get context keyword %s is not recognized!" % attr_key)
     return getattr(_context(), attr_key)
+
+
+@args_type_check(enable_mpi=bool)
+def set_mpi_config(**kwargs):
+    """
+    Sets mpi config for running environment.
+
+    mpi config should be configured before running your program. If there is no configuration,
+    mpi moudle will be disabled by default.
+
+    Note:
+        Attribute name is required for setting attributes.
+
+    Args:
+        enable_mpi (bool): Whether to enable mpi. Default: False.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+
+    Examples:
+        >>> mpiconfig.set_mpi_config(enable_mpi=True)
+    """
+    _set_mpi_config(**kwargs)
+
+
+def get_mpi_config(attr_key):
+    """
+    Gets mpi config attribute value according to the input key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Object, The value of given attribute key.
+
+    Raises:
+        ValueError: If input key is not an attribute in context.
+    """
+    return _get_mpi_config(attr_key)
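These two functions are the public entry points for host MPI and mirror set_context/get_context: keyword arguments are type-checked (enable_mpi must be a bool) and unknown keys raise ValueError. A short usage sketch, assuming a build with ENABLE_MPI:

    from mindspore import context

    context.set_mpi_config(enable_mpi=True)
    assert context.get_mpi_config("enable_mpi")
    # context.get_mpi_config("no_such_key") would raise ValueError("Get context keyword ...")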
mindspore/parallel/mpi/__init__.py
0 → 100644
浏览文件 @
8593c4d8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
mindspore/parallel/mpi/_mpi_config.py
0 → 100644
浏览文件 @
8593c4d8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The MPI config, used to configure the MPI environment.
"""
import
threading
from
mindspore._c_expression
import
MpiConfig
from
mindspore._checkparam
import
args_type_check
class
_MpiConfig
:
"""
_MpiConfig is the config tool for controlling MPI
Note:
Create a config through instantiating MpiConfig object is not recommended.
should use MpiConfig() to get the config since MpiConfig is singleton.
"""
_instance
=
None
_instance_lock
=
threading
.
Lock
()
def
__init__
(
self
):
self
.
_mpiconfig_handle
=
MpiConfig
.
get_instance
()
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
cls
.
_instance_lock
.
acquire
()
cls
.
_instance
=
object
.
__new__
(
cls
)
cls
.
_instance_lock
.
release
()
return
cls
.
_instance
def
__getattribute__
(
self
,
attr
):
value
=
object
.
__getattribute__
(
self
,
attr
)
if
attr
==
"_mpiconfig_handle"
and
value
is
None
:
raise
ValueError
(
"mpiconfig handle is none in MpiConfig!!!"
)
return
value
@
property
def
enable_mpi
(
self
):
return
self
.
_mpiconfig_handle
.
get_enable_mpi
()
@
enable_mpi
.
setter
def
enable_mpi
(
self
,
enable_mpi
):
self
.
_mpiconfig_handle
.
set_enable_mpi
(
enable_mpi
)
_k_mpi_config
=
None
def
_mpi_config
():
"""
Get the global mpi config, if mpi config is not created, create a new one.
Returns:
_MpiConfig, the global mpi config.
"""
global
_k_mpi_config
if
_k_mpi_config
is
None
:
_k_mpi_config
=
_MpiConfig
()
return
_k_mpi_config
@
args_type_check
(
enable_mpi
=
bool
)
def
_set_mpi_config
(
**
kwargs
):
"""
Sets mpi config for running environment.
mpi config should be configured before running your program. If there is no configuration,
mpi moudle will be disabled by default.
Note:
Attribute name is required for setting attributes.
Args:
enable_mpi (bool): Whether to enable mpi. Default: False.
Raises:
ValueError: If input key is not an attribute in mpi config.
Examples:
>>> mpiconfig.set_mpi_config(enable_mpi=True)
"""
for
key
,
value
in
kwargs
.
items
():
if
not
hasattr
(
_mpi_config
(),
key
):
raise
ValueError
(
"Set mpi config keyword %s is not recognized!"
%
key
)
setattr
(
_mpi_config
(),
key
,
value
)
def
_get_mpi_config
(
attr_key
):
"""
Gets mpi config attribute value according to the input key.
Args:
attr_key (str): The key of the attribute.
Returns:
Object, The value of given attribute key.
Raises:
ValueError: If input key is not an attribute in context.
"""
if
not
hasattr
(
_mpi_config
(),
attr_key
):
raise
ValueError
(
"Get context keyword %s is not recognized!"
%
attr_key
)
return
getattr
(
_mpi_config
(),
attr_key
)
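The wrapper keeps one _MpiConfig object per process (the lock in __new__ guards first construction) and forwards attribute access to the pybind handle, so repeated calls to _mpi_config() always reach the same C++ MpiConfig singleton. A brief check of that behaviour (illustrative; uses the private module directly):

    from mindspore.parallel.mpi._mpi_config import _mpi_config, _set_mpi_config, _get_mpi_config

    assert _mpi_config() is _mpi_config()   # one shared wrapper instance
    _set_mpi_config(enable_mpi=True)        # validated with hasattr, then setattr on the singleton
    assert _get_mpi_config("enable_mpi")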
tests/st/ops/cpu/test_reduce_scatter.py
@@ -23,9 +23,10 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 import mindspore._ms_mpi as mpi

 # run comand:
-# mpirun -np 3 python test_reduce_scatter.py
+# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py

 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+context.set_mpi_config(enable_mpi=True)

 class Net(nn.Cell):
     def __init__(self):
@@ -46,14 +47,19 @@ class AllGatherNet(nn.Cell):
         return self.hostallgather(x)

 def test_net_reduce_scatter():
-    x = np.ones(12).astype(np.float32) * 0.1
+    x = np.arange(12).astype(np.float32) * 0.1
     reducescatter = Net()
     rankid = mpi.get_rank_id()
     print("self rankid:", rankid)
     output = reducescatter(Tensor(x, mstype.float32))
     print("output:\n", output)
-    expect_result = np.ones(4).astype(np.float32) * 0.3
+    if rankid == 0:
+        expect_result = np.arange(4).astype(np.float32) * 0.3
+    if rankid == 1:
+        expect_result = np.arange(4, 8).astype(np.float32) * 0.3
+    if rankid == 2:
+        expect_result = np.arange(8, 12).astype(np.float32) * 0.3
     diff = abs(output.asnumpy() - expect_result)
     error = np.ones(shape=expect_result.shape) * 1.0e-6
     assert np.all(diff < error)
@@ -61,7 +67,7 @@ def test_net_reduce_scatter():
     allgather = AllGatherNet()
     allgather_output = allgather(output)
     print("allgather result:\n", allgather_output)
-    expect_allgather_result = np.ones(12).astype(np.float32) * 0.3
+    expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
     diff = abs(allgather_output.asnumpy() - expect_allgather_result)
     error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
     assert np.all(diff < error)
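The test now feeds each rank a distinguishable ramp instead of a constant vector, so both the reduce-scatter slice and the follow-up host allgather can be checked per rank. The expected numbers can be verified offline with plain numpy, no MPI required:

    import numpy as np

    # Three ranks each contribute arange(12) * 0.1, so the element-wise sum is arange(12) * 0.3.
    full = 3 * (np.arange(12, dtype=np.float32) * 0.1)
    segments = [full[4 * r:4 * (r + 1)] for r in range(3)]   # what rank r keeps after reduce-scatter
    for r in range(3):
        assert np.allclose(segments[r], np.arange(4 * r, 4 * r + 4, dtype=np.float32) * 0.3)
    # The host allgather of the three segments rebuilds the full reduced vector.
    assert np.allclose(np.concatenate(segments), np.arange(12, dtype=np.float32) * 0.3)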