Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2eedd321
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2eedd321
编写于
5月 27, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 27, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1472 add operator HostAllGather and HostReduceScatter
Merge pull request !1472 from yihuaijie/master
上级
b94949ea
2f8e7ff6
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
232 addition
and
3 deletion
+232
-3
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+2
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+2
-0
mindspore/ops/_grad/grad_comm_ops.py
mindspore/ops/_grad/grad_comm_ops.py
+35
-2
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+4
-1
mindspore/ops/operations/comm_ops.py
mindspore/ops/operations/comm_ops.py
+128
-0
tests/ut/python/communication/test_comm.py
tests/ut/python/communication/test_comm.py
+61
-0
未找到文件。
mindspore/ccsrc/transform/convert.cc
浏览文件 @
2eedd321
...
...
@@ -55,7 +55,9 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
const
char
kNameAllReduce
[]
=
"AllReduce"
;
const
char
kNameBroadcast
[]
=
"Broadcast"
;
const
char
kNameAllgather
[]
=
"AllGather"
;
const
char
kNameHostAllgather
[]
=
"HostAllGather"
;
const
char
kNameReduceScatter
[]
=
"ReduceScatter"
;
const
char
kNameHostReduceScatter
[]
=
"HostReduceScatter"
;
const
char
kNameReduceSum
[]
=
"ReduceSum"
;
const
char
kNameIsFinite
[]
=
"isFinite"
;
const
char
kNameReciprocal
[]
=
"Reciprocal"
;
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
2eedd321
...
...
@@ -45,8 +45,10 @@ constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
constexpr
auto
kGetNextOpName
=
"GetNext"
;
constexpr
auto
kAllReduceOpName
=
"AllReduce"
;
constexpr
auto
kAllGatherOpName
=
"AllGather"
;
constexpr
auto
kHostAllGatherOpName
=
"HostAllGather"
;
constexpr
auto
kBroadcastOpName
=
"Broadcast"
;
constexpr
auto
kReduceScatterOpName
=
"ReduceScatter"
;
constexpr
auto
kHostReduceScatterOpName
=
"HostReduceScatter"
;
constexpr
auto
kMemCpyAsyncOpName
=
"memcpy_async"
;
constexpr
auto
kTopKOpName
=
"TopK"
;
constexpr
auto
kExtractImagePatchesOpName
=
"ExtractImagePatches"
;
...
...
mindspore/ops/_grad/grad_comm_ops.py
浏览文件 @
2eedd321
...
...
@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
from
mindspore.ops
import
functional
as
F
from
..
import
operations
as
P
from
..composite.multitype_ops.zeros_like_impl
import
zeros_like
from
..operations.comm_ops
import
(
AllGather
,
AllReduce
,
_AlltoAll
,
Broadcast
,
from
..operations.comm_ops
import
(
AllGather
,
HostAllGather
,
AllReduce
,
_AlltoAll
,
Broadcast
,
_GetTensorSlice
,
_MirrorOperator
,
ReduceOp
,
ReduceScatter
,
_VirtualDiv
)
ReduceScatter
,
HostReduceScatter
,
_VirtualDiv
)
from
.grad_base
import
bprop_getters
...
...
@@ -79,6 +79,21 @@ def get_bprop_all_gather(self):
return
bprop
@
bprop_getters
.
register
(
HostAllGather
)
def
get_bprop_host_all_gather
(
self
):
"""Generate bprop for HostAllGather"""
host_all_gather_grad
=
HostReduceScatter
(
ReduceOp
.
SUM
,
self
.
group
)
if
self
.
instance_name
:
instance_name
=
"grad"
+
self
.
instance_name
host_all_gather_grad
.
set_prim_instance_name
(
instance_name
)
def
bprop
(
x
,
out
,
dout
):
dx
=
host_all_gather_grad
(
dout
)
return
(
dx
,)
return
bprop
@
bprop_getters
.
register
(
ReduceScatter
)
def
get_bprop_reduce_scatter
(
self
):
"""Generate bprop for ReduceScatter"""
...
...
@@ -97,6 +112,24 @@ def get_bprop_reduce_scatter(self):
return
bprop
@
bprop_getters
.
register
(
HostReduceScatter
)
def
get_bprop_host_reduce_scatter
(
self
):
"""Generate bprop for HostReduceScatter"""
host_reduce_scatter_grad
=
HostAllGather
(
self
.
group
)
if
self
.
instance_name
:
instance_name
=
"grad"
+
self
.
instance_name
host_reduce_scatter_grad
.
set_prim_instance_name
(
instance_name
)
if
self
.
op
!=
ReduceOp
.
SUM
:
raise
RuntimeError
(
"The hostreducescatter bprop only support ReduceOp.SUM until now."
)
def
bprop
(
x
,
out
,
dout
):
dx
=
host_reduce_scatter_grad
(
dout
)
return
(
dx
,)
return
bprop
@
bprop_getters
.
register
(
_AlltoAll
)
def
get_bprop_all_to_all
(
self
):
"""Generate bprop for AlltoAll."""
...
...
mindspore/ops/operations/__init__.py
浏览文件 @
2eedd321
...
...
@@ -33,7 +33,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
SpaceToBatchND
,
BatchToSpaceND
)
from
.comm_ops
import
(
AllGather
,
AllReduce
,
_AlltoAll
,
ReduceScatter
,
Broadcast
,
_MirrorOperator
,
ReduceOp
,
_VirtualDataset
,
_VirtualDiv
,
_GetTensorSlice
)
_VirtualDiv
,
_GetTensorSlice
,
HostAllGather
,
HostReduceScatter
)
from
.debug_ops
import
(
ImageSummary
,
InsertGradientOf
,
HookBackward
,
ScalarSummary
,
TensorSummary
,
HistogramSummary
,
Print
)
from
.control_ops
import
ControlDepend
,
GeSwitch
,
Merge
...
...
@@ -220,8 +221,10 @@ __all__ = [
'UnsortedSegmentSum'
,
'UnsortedSegmentMin'
,
"AllGather"
,
"HostAllGather"
,
"AllReduce"
,
"ReduceScatter"
,
"HostReduceScatter"
,
"Broadcast"
,
"ReduceOp"
,
'ScalarCast'
,
...
...
mindspore/ops/operations/comm_ops.py
浏览文件 @
2eedd321
...
...
@@ -169,6 +169,72 @@ class AllGather(PrimitiveWithInfer):
raise
NotImplementedError
class
HostAllGather
(
PrimitiveWithInfer
):
"""
Gathers tensors from the specified communication group on host.
Note:
Tensor must have the same shape and format in all processes participating in the collective.
Args:
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
Raises:
TypeError: If group is not a list nor tuple, or elements of group are not int.
ValueError: If the local rank id of the calling process not in group,
or rank_id from group not in [0, 7].
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor. If the number of devices in the group is N,
then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
Examples:
>>> from mindspore.communication import init
>>> import mindspore.ops.operations as P
>>> init('nccl')
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
>>>
>>> def construct(self, x):
>>> return self.hostallgather(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@
prim_attr_register
def
__init__
(
self
,
group
=
None
):
if
group
is
None
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' group must be set."
)
validator
.
check_value_type
(
'group'
,
group
,
(
tuple
,
list
),
self
.
name
)
validator
.
check_integer
(
"group size"
,
len
(
group
),
2
,
Rel
.
GE
,
self
.
name
)
for
r
in
group
:
validator
.
check_int_range
(
"rank_id"
,
r
,
0
,
7
,
Rel
.
INC_BOTH
,
self
.
name
)
validator
.
check_value_type
(
"rank_id"
,
r
,
(
int
,),
self
.
name
)
self
.
group_size
=
len
(
group
)
self
.
rank
=
get_rank
()
validator
.
check
(
'rank'
,
self
.
rank
,
'group'
,
self
.
group
,
Rel
.
IN
,
self
.
name
)
self
.
add_prim_attr
(
'group'
,
group
)
def
infer_shape
(
self
,
x_shape
):
validator
.
check_integer
(
"x shape"
,
len
(
x_shape
),
0
,
Rel
.
GT
,
self
.
name
)
x_shape
[
0
]
=
x_shape
[
0
]
*
self
.
group_size
return
x_shape
def
infer_dtype
(
self
,
x_dtype
):
validator
.
check_tensor_type_same
({
'x'
:
x_dtype
},
target_dtypes
,
self
.
name
)
return
x_dtype
def
__call__
(
self
,
tensor
):
raise
NotImplementedError
class
ReduceScatter
(
PrimitiveWithInfer
):
"""
Reduces and scatters tensors from the specified communication group.
...
...
@@ -226,6 +292,68 @@ class ReduceScatter(PrimitiveWithInfer):
raise
NotImplementedError
class
HostReduceScatter
(
PrimitiveWithInfer
):
"""
Reduces and scatters tensors from the specified communication group on host.
Note:
Tensor must have the same shape and format in all processes participating in the collective.
Args:
op (str): Specifies an operation used for element-wise reductions,
like sum, max, avg. Default: ReduceOp.SUM.
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
Raise:
TypeError: If op is not a string and group is not a list nor tuple,
or elements of group are not int.
ValueError: If the first dimension of input can not be divided by rank size,
or group is not set, or rank_id not in [1, 7].
Examples:
>>> from mindspore.communication import init
>>> import mindspore.ops.operations as P
>>> init('nccl')
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
>>>
>>> def construct(self, x):
>>> return self.hostreducescatter(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@
prim_attr_register
def
__init__
(
self
,
op
=
ReduceOp
.
SUM
,
group
=
None
):
if
group
is
None
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' group must be set."
)
validator
.
check_value_type
(
'op'
,
op
,
(
type
(
ReduceOp
.
SUM
),),
self
.
name
)
validator
.
check_value_type
(
'group'
,
group
,
(
tuple
,
list
),
self
.
name
)
validator
.
check_integer
(
"group size"
,
len
(
group
),
2
,
Rel
.
GE
,
self
.
name
)
for
r
in
group
:
validator
.
check_int_range
(
"rank_id"
,
r
,
0
,
7
,
Rel
.
INC_BOTH
,
self
.
name
)
validator
.
check_value_type
(
"rank_id"
,
r
,
(
int
,),
self
.
name
)
self
.
op
=
op
self
.
group_size
=
len
(
group
)
self
.
add_prim_attr
(
'group'
,
group
)
def
infer_shape
(
self
,
x_shape
):
if
x_shape
[
0
]
%
self
.
group_size
!=
0
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' the first dimension of x should be divided by group_size."
)
x_shape
[
0
]
=
int
(
x_shape
[
0
]
/
self
.
group_size
)
return
x_shape
def
infer_dtype
(
self
,
x_dtype
):
validator
.
check_tensor_type_same
({
'x'
:
x_dtype
},
target_dtypes
,
self
.
name
)
return
x_dtype
def
__call__
(
self
,
tensor
):
raise
NotImplementedError
class
Broadcast
(
PrimitiveWithInfer
):
"""
Broadcasts the tensor to the whole group.
...
...
tests/ut/python/communication/test_comm.py
浏览文件 @
2eedd321
...
...
@@ -26,6 +26,7 @@ from mindspore.nn import Momentum
from
mindspore.nn
import
ReLU
from
mindspore.nn
import
TrainOneStepCell
,
WithLossCell
from
mindspore.ops.operations.comm_ops
import
AllReduce
,
AllGather
,
_AlltoAll
,
ReduceOp
,
ReduceScatter
from
mindspore.ops.operations.comm_ops
import
HostAllGather
,
HostReduceScatter
from
mindspore.ops.operations.comm_ops
import
Broadcast
# pylint: disable=W0212
...
...
@@ -86,6 +87,21 @@ class AllGatherNet(nn.Cell):
return
self
.
relu
(
x
)
class
HostAllGatherNet
(
nn
.
Cell
):
"""HostAllGatherNet definition"""
def
__init__
(
self
,
input_channel
,
output_channel
):
super
(
HostAllGatherNet
,
self
).
__init__
()
self
.
dense
=
Dense
(
input_channel
,
output_channel
)
self
.
hostallgather
=
HostAllGather
((
0
,
1
))
self
.
relu
=
ReLU
()
def
construct
(
self
,
x
):
x
=
self
.
dense
(
x
)
x
=
self
.
hostallgather
(
x
)
return
self
.
relu
(
x
)
class
ReduceScatterNet
(
nn
.
Cell
):
"""ReduceScatterNet definition"""
...
...
@@ -101,6 +117,21 @@ class ReduceScatterNet(nn.Cell):
return
self
.
relu
(
x
)
class
HostReduceScatterNet
(
nn
.
Cell
):
"""HostReduceScatterNet definition"""
def
__init__
(
self
,
input_channel
,
out_channel
,
op
):
super
(
HostReduceScatterNet
,
self
).
__init__
()
self
.
dense
=
Dense
(
input_channel
,
out_channel
)
self
.
hostreducescatter
=
HostReduceScatter
(
op
,
(
0
,
1
))
self
.
relu
=
ReLU
()
def
construct
(
self
,
x
):
x
=
self
.
dense
(
x
)
x
=
self
.
hostreducescatter
(
x
)
return
self
.
relu
(
x
)
class
AlltoAllNet
(
nn
.
Cell
):
"""AlltoAllNet definition"""
...
...
@@ -154,6 +185,21 @@ def test_allgather():
_executor
.
compile
(
network
,
input_tensor
,
label_tensor
)
def
test_hostallgather
():
"""test_hostallgather"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
input_tensor
=
Tensor
(
np
.
array
([[
1.2
,
2.1
],
[
2.2
,
3.2
]],
dtype
=
np
.
float32
))
label_tensor
=
Tensor
(
np
.
array
([[
1.2
],
[
2.2
],
[
3.2
],
[
4.2
]],
dtype
=
np
.
float32
))
network
=
HostAllGatherNet
(
2
,
1
)
loss_fn
=
nn
.
SoftmaxCrossEntropyWithLogits
()
optimizer
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
network
.
get_parameters
()),
learning_rate
=
0.1
,
momentum
=
0.9
)
network
=
WithLossCell
(
network
,
loss_fn
)
network
=
TrainOneStepCell
(
network
,
optimizer
)
_executor
.
compile
(
network
,
input_tensor
,
label_tensor
)
def
run_reducescatter
(
op
):
"""run_reducescatter"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
...
@@ -175,6 +221,21 @@ def test_reducescatter():
run_reducescatter
(
ReduceOp
.
SUM
)
def
test_hostreducescatter
():
"""test_hostreducescatter"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
input_tensor
=
Tensor
(
np
.
array
([[
1.2
,
2.1
],
[
2.2
,
3.2
]],
dtype
=
np
.
float32
))
label_tensor
=
Tensor
(
np
.
array
([[
1.2
]],
dtype
=
np
.
float32
))
network
=
HostReduceScatterNet
(
2
,
1
,
ReduceOp
.
SUM
)
loss_fn
=
nn
.
SoftmaxCrossEntropyWithLogits
()
optimizer
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
network
.
get_parameters
()),
learning_rate
=
0.1
,
momentum
=
0.9
)
network
=
WithLossCell
(
network
,
loss_fn
)
network
=
TrainOneStepCell
(
network
,
optimizer
)
_executor
.
compile
(
network
,
input_tensor
,
label_tensor
)
def
test_broadcast
():
"""test_broadcast"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录