Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2eb739de
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2eb739de
编写于
6月 23, 2020
作者:
Y
Yi Huaijie
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change HostAllGather and HostReduceScatter to internal interface
上级
5b14292f
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
19 addition
and
203 deletion
+19
-203
mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h
mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h
+1
-1
mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h
mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h
+1
-1
mindspore/ccsrc/parallel/ops_info/ops_utils.h
mindspore/ccsrc/parallel/ops_info/ops_utils.h
+1
-1
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+0
-2
mindspore/ops/_grad/grad_comm_ops.py
mindspore/ops/_grad/grad_comm_ops.py
+8
-8
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+1
-3
mindspore/ops/operations/_grad_ops.py
mindspore/ops/operations/_grad_ops.py
+3
-3
mindspore/ops/operations/comm_ops.py
mindspore/ops/operations/comm_ops.py
+4
-47
tests/st/ops/cpu/test_reduce_scatter.py
tests/st/ops/cpu/test_reduce_scatter.py
+0
-76
tests/ut/python/communication/test_comm.py
tests/ut/python/communication/test_comm.py
+0
-61
未找到文件。
mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h
浏览文件 @
2eb739de
...
@@ -36,7 +36,7 @@ class AllGatherCPUKernel : public CPUKernel {
...
@@ -36,7 +36,7 @@ class AllGatherCPUKernel : public CPUKernel {
std
::
vector
<
int
>
ranks_group_
;
std
::
vector
<
int
>
ranks_group_
;
};
};
MS_REG_CPU_KERNEL
(
HostAllGather
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
MS_REG_CPU_KERNEL
(
_
HostAllGather
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
AllGatherCPUKernel
);
AllGatherCPUKernel
);
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h
浏览文件 @
2eb739de
...
@@ -37,7 +37,7 @@ class ReduceScatterCPUKernel : public CPUKernel {
...
@@ -37,7 +37,7 @@ class ReduceScatterCPUKernel : public CPUKernel {
std
::
vector
<
int
>
ranks_group_
;
std
::
vector
<
int
>
ranks_group_
;
};
};
MS_REG_CPU_KERNEL
(
HostReduceScatter
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
MS_REG_CPU_KERNEL
(
_
HostReduceScatter
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
ReduceScatterCPUKernel
);
ReduceScatterCPUKernel
);
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/ops_info/ops_utils.h
浏览文件 @
2eb739de
...
@@ -145,7 +145,7 @@ constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
...
@@ -145,7 +145,7 @@ constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
constexpr
char
STRIDED_SLICE
[]
=
"StridedSlice"
;
constexpr
char
STRIDED_SLICE
[]
=
"StridedSlice"
;
constexpr
char
ALL_GATHER
[]
=
"AllGather"
;
constexpr
char
ALL_GATHER
[]
=
"AllGather"
;
constexpr
char
REDUCE_SCATTER
[]
=
"ReduceScatter"
;
constexpr
char
REDUCE_SCATTER
[]
=
"ReduceScatter"
;
constexpr
char
HOST_REDUCE_SCATTER
[]
=
"HostReduceScatter"
;
constexpr
char
HOST_REDUCE_SCATTER
[]
=
"
_
HostReduceScatter"
;
constexpr
char
EMBEDDING_LOOKUP
[]
=
"EmbeddingLookup"
;
constexpr
char
EMBEDDING_LOOKUP
[]
=
"EmbeddingLookup"
;
constexpr
char
CONCAT
[]
=
"Concat"
;
constexpr
char
CONCAT
[]
=
"Concat"
;
constexpr
char
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS
[]
=
"SoftmaxCrossEntropyWithLogits"
;
constexpr
char
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS
[]
=
"SoftmaxCrossEntropyWithLogits"
;
...
...
mindspore/ccsrc/transform/convert.cc
浏览文件 @
2eb739de
...
@@ -55,9 +55,7 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
...
@@ -55,9 +55,7 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
const
char
kNameAllReduce
[]
=
"AllReduce"
;
const
char
kNameAllReduce
[]
=
"AllReduce"
;
const
char
kNameBroadcast
[]
=
"Broadcast"
;
const
char
kNameBroadcast
[]
=
"Broadcast"
;
const
char
kNameAllgather
[]
=
"AllGather"
;
const
char
kNameAllgather
[]
=
"AllGather"
;
const
char
kNameHostAllgather
[]
=
"HostAllGather"
;
const
char
kNameReduceScatter
[]
=
"ReduceScatter"
;
const
char
kNameReduceScatter
[]
=
"ReduceScatter"
;
const
char
kNameHostReduceScatter
[]
=
"HostReduceScatter"
;
const
char
kNameReduceSum
[]
=
"ReduceSum"
;
const
char
kNameReduceSum
[]
=
"ReduceSum"
;
const
char
kNameIsFinite
[]
=
"isFinite"
;
const
char
kNameIsFinite
[]
=
"isFinite"
;
const
char
kNameReciprocal
[]
=
"Reciprocal"
;
const
char
kNameReciprocal
[]
=
"Reciprocal"
;
...
...
mindspore/ops/_grad/grad_comm_ops.py
浏览文件 @
2eb739de
...
@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
...
@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
functional
as
F
from
..
import
operations
as
P
from
..
import
operations
as
P
from
..composite.multitype_ops.zeros_like_impl
import
zeros_like
from
..composite.multitype_ops.zeros_like_impl
import
zeros_like
from
..operations.comm_ops
import
(
AllGather
,
HostAllGather
,
AllReduce
,
_AlltoAll
,
Broadcast
,
from
..operations.comm_ops
import
(
AllGather
,
_
HostAllGather
,
AllReduce
,
_AlltoAll
,
Broadcast
,
_GetTensorSlice
,
_MirrorOperator
,
ReduceOp
,
_GetTensorSlice
,
_MirrorOperator
,
ReduceOp
,
ReduceScatter
,
HostReduceScatter
,
_VirtualDiv
)
ReduceScatter
,
_
HostReduceScatter
,
_VirtualDiv
)
from
.grad_base
import
bprop_getters
from
.grad_base
import
bprop_getters
...
@@ -93,10 +93,10 @@ def get_bprop_all_gather(self):
...
@@ -93,10 +93,10 @@ def get_bprop_all_gather(self):
return
bprop
return
bprop
@
bprop_getters
.
register
(
HostAllGather
)
@
bprop_getters
.
register
(
_
HostAllGather
)
def
get_bprop_host_all_gather
(
self
):
def
get_bprop_host_all_gather
(
self
):
"""Generate bprop for HostAllGather"""
"""Generate bprop for
_
HostAllGather"""
host_all_gather_grad
=
HostReduceScatter
(
ReduceOp
.
SUM
,
self
.
group
)
host_all_gather_grad
=
_
HostReduceScatter
(
ReduceOp
.
SUM
,
self
.
group
)
if
self
.
instance_name
:
if
self
.
instance_name
:
instance_name
=
"grad"
+
self
.
instance_name
instance_name
=
"grad"
+
self
.
instance_name
host_all_gather_grad
.
set_prim_instance_name
(
instance_name
)
host_all_gather_grad
.
set_prim_instance_name
(
instance_name
)
...
@@ -126,10 +126,10 @@ def get_bprop_reduce_scatter(self):
...
@@ -126,10 +126,10 @@ def get_bprop_reduce_scatter(self):
return
bprop
return
bprop
@
bprop_getters
.
register
(
HostReduceScatter
)
@
bprop_getters
.
register
(
_
HostReduceScatter
)
def
get_bprop_host_reduce_scatter
(
self
):
def
get_bprop_host_reduce_scatter
(
self
):
"""Generate bprop for HostReduceScatter"""
"""Generate bprop for
_
HostReduceScatter"""
host_reduce_scatter_grad
=
HostAllGather
(
self
.
group
)
host_reduce_scatter_grad
=
_
HostAllGather
(
self
.
group
)
if
self
.
instance_name
:
if
self
.
instance_name
:
instance_name
=
"grad"
+
self
.
instance_name
instance_name
=
"grad"
+
self
.
instance_name
host_reduce_scatter_grad
.
set_prim_instance_name
(
instance_name
)
host_reduce_scatter_grad
.
set_prim_instance_name
(
instance_name
)
...
...
mindspore/ops/operations/__init__.py
浏览文件 @
2eb739de
...
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
...
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
from
.comm_ops
import
(
AllGather
,
AllReduce
,
_AlltoAll
,
ReduceScatter
,
Broadcast
,
from
.comm_ops
import
(
AllGather
,
AllReduce
,
_AlltoAll
,
ReduceScatter
,
Broadcast
,
_MirrorOperator
,
ReduceOp
,
_VirtualDataset
,
_MirrorOperator
,
ReduceOp
,
_VirtualDataset
,
_VirtualDiv
,
_GetTensorSlice
,
_VirtualDiv
,
_GetTensorSlice
,
HostAllGather
,
HostReduceScatter
)
_HostAllGather
,
_
HostReduceScatter
)
from
.debug_ops
import
(
ImageSummary
,
InsertGradientOf
,
HookBackward
,
ScalarSummary
,
from
.debug_ops
import
(
ImageSummary
,
InsertGradientOf
,
HookBackward
,
ScalarSummary
,
TensorSummary
,
HistogramSummary
,
Debug
,
Print
)
TensorSummary
,
HistogramSummary
,
Debug
,
Print
)
from
.control_ops
import
ControlDepend
,
GeSwitch
,
Merge
from
.control_ops
import
ControlDepend
,
GeSwitch
,
Merge
...
@@ -244,10 +244,8 @@ __all__ = [
...
@@ -244,10 +244,8 @@ __all__ = [
'UnsortedSegmentSum'
,
'UnsortedSegmentSum'
,
'UnsortedSegmentMin'
,
'UnsortedSegmentMin'
,
"AllGather"
,
"AllGather"
,
"HostAllGather"
,
"AllReduce"
,
"AllReduce"
,
"ReduceScatter"
,
"ReduceScatter"
,
"HostReduceScatter"
,
"Broadcast"
,
"Broadcast"
,
"ReduceOp"
,
"ReduceOp"
,
'ScalarCast'
,
'ScalarCast'
,
...
...
mindspore/ops/operations/_grad_ops.py
浏览文件 @
2eb739de
...
@@ -1166,7 +1166,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
...
@@ -1166,7 +1166,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
Perform the gradient for the communication part of EmbeddingLookup operator.
Perform the gradient for the communication part of EmbeddingLookup operator.
This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
this primitive is implemented by StridedSlice -->
_
HostAllGather --> Concat. This primitive runs on host.
"""
"""
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -1177,8 +1177,8 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
...
@@ -1177,8 +1177,8 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
"""
"""
This primitive is implemented by three steps:
This primitive is implemented by three steps:
1) Split the 'dy' along dimension 0 into 'split_num' parts.
1) Split the 'dy' along dimension 0 into 'split_num' parts.
2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
2) For each part, perform
_
HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
3) After
_
HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
along dimension 0.
along dimension 0.
The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
...
...
mindspore/ops/operations/comm_ops.py
浏览文件 @
2eb739de
...
@@ -176,13 +176,13 @@ class AllGather(PrimitiveWithInfer):
...
@@ -176,13 +176,13 @@ class AllGather(PrimitiveWithInfer):
raise
NotImplementedError
raise
NotImplementedError
class
HostAllGather
(
PrimitiveWithInfer
):
class
_
HostAllGather
(
PrimitiveWithInfer
):
"""
"""
Gathers tensors from the specified communication group on host.
Gathers tensors from the specified communication group on host.
Note:
Note:
Tensor must have the same shape and format in all processes participating in the collective.
Tensor must have the same shape and format in all processes participating in the collective.
HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
_
HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
to enable it. Using mpirun command to run it:
to enable it. Using mpirun command to run it:
mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
...
@@ -199,27 +199,6 @@ class HostAllGather(PrimitiveWithInfer):
...
@@ -199,27 +199,6 @@ class HostAllGather(PrimitiveWithInfer):
Outputs:
Outputs:
Tensor. If the number of devices in the group is N,
Tensor. If the number of devices in the group is N,
then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.nn as nn
>>> import mindspore.context as context
>>> import mindspore.ops.operations as P
>>> from mindspore import Tensor
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
>>> context.set_mpi_config(enable_mpi=True)
>>>
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
>>>
>>> def construct(self, x):
>>> return self.hostallgather(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
"""
@
prim_attr_register
@
prim_attr_register
...
@@ -308,13 +287,13 @@ class ReduceScatter(PrimitiveWithInfer):
...
@@ -308,13 +287,13 @@ class ReduceScatter(PrimitiveWithInfer):
raise
NotImplementedError
raise
NotImplementedError
class
HostReduceScatter
(
PrimitiveWithInfer
):
class
_
HostReduceScatter
(
PrimitiveWithInfer
):
"""
"""
Reduces and scatters tensors from the specified communication group on host.
Reduces and scatters tensors from the specified communication group on host.
Note:
Note:
Tensor must have the same shape and format in all processes participating in the collective.
Tensor must have the same shape and format in all processes participating in the collective.
HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
_
HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
-M on to enable it. Using mpirun command to run it:
-M on to enable it. Using mpirun command to run it:
mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py
mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py
...
@@ -328,28 +307,6 @@ class HostReduceScatter(PrimitiveWithInfer):
...
@@ -328,28 +307,6 @@ class HostReduceScatter(PrimitiveWithInfer):
or elements of group are not int.
or elements of group are not int.
ValueError: If the first dimension of input can not be divided by group size,
ValueError: If the first dimension of input can not be divided by group size,
or group is not set, or rank_id not in [0, 7].
or group is not set, or rank_id not in [0, 7].
Examples:
>>> import mindspore.nn as nn
>>> import mindspore.context as context
>>> import mindspore.ops.operations as P
>>> from mindspore import Tensor
>>> from mindspore.ops.operations.comm_ops import ReduceOp
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
>>> context.set_mpi_config(enable_mpi=True)
>>>
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
>>>
>>> def construct(self, x):
>>> return self.hostreducescatter(x)
>>>
>>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
"""
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
op
=
ReduceOp
.
SUM
,
group
=
None
):
def
__init__
(
self
,
op
=
ReduceOp
.
SUM
,
group
=
None
):
...
...
tests/st/ops/cpu/test_reduce_scatter.py
已删除
100644 → 0
浏览文件 @
5b14292f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.ops
import
operations
as
P
import
mindspore._ms_mpi
as
mpi
# run comand:
# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'CPU'
)
context
.
set_mpi_config
(
enable_mpi
=
True
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
op
=
"sum"
self
.
reducescatter
=
P
.
HostReduceScatter
(
op
=
self
.
op
,
group
=
[
0
,
1
,
2
])
def
construct
(
self
,
x
):
return
self
.
reducescatter
(
x
)
class
AllGatherNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
AllGatherNet
,
self
).
__init__
()
self
.
hostallgather
=
P
.
HostAllGather
(
group
=
(
0
,
1
,
2
))
def
construct
(
self
,
x
):
return
self
.
hostallgather
(
x
)
def
test_net_reduce_scatter
():
x
=
np
.
arange
(
12
).
astype
(
np
.
float32
)
*
0.1
reducescatter
=
Net
()
rankid
=
mpi
.
get_rank_id
()
print
(
"self rankid:"
,
rankid
)
output
=
reducescatter
(
Tensor
(
x
,
mstype
.
float32
))
print
(
"output:
\n
"
,
output
)
if
rankid
==
0
:
expect_result
=
np
.
arange
(
4
).
astype
(
np
.
float32
)
*
0.3
if
rankid
==
1
:
expect_result
=
np
.
arange
(
4
,
8
).
astype
(
np
.
float32
)
*
0.3
if
rankid
==
2
:
expect_result
=
np
.
arange
(
8
,
12
).
astype
(
np
.
float32
)
*
0.3
diff
=
abs
(
output
.
asnumpy
()
-
expect_result
)
error
=
np
.
ones
(
shape
=
expect_result
.
shape
)
*
1.0e-6
assert
np
.
all
(
diff
<
error
)
allgather
=
AllGatherNet
()
allgather_output
=
allgather
(
output
)
print
(
"allgather result:
\n
"
,
allgather_output
)
expect_allgather_result
=
np
.
arange
(
12
).
astype
(
np
.
float32
)
*
0.3
diff
=
abs
(
allgather_output
.
asnumpy
()
-
expect_allgather_result
)
error
=
np
.
ones
(
shape
=
expect_allgather_result
.
shape
)
*
1.0e-6
assert
np
.
all
(
diff
<
error
)
if
__name__
==
'__main__'
:
test_net_reduce_scatter
()
tests/ut/python/communication/test_comm.py
浏览文件 @
2eb739de
...
@@ -26,7 +26,6 @@ from mindspore.nn import Momentum
...
@@ -26,7 +26,6 @@ from mindspore.nn import Momentum
from
mindspore.nn
import
ReLU
from
mindspore.nn
import
ReLU
from
mindspore.nn
import
TrainOneStepCell
,
WithLossCell
from
mindspore.nn
import
TrainOneStepCell
,
WithLossCell
from
mindspore.ops.operations.comm_ops
import
AllReduce
,
AllGather
,
_AlltoAll
,
ReduceOp
,
ReduceScatter
from
mindspore.ops.operations.comm_ops
import
AllReduce
,
AllGather
,
_AlltoAll
,
ReduceOp
,
ReduceScatter
from
mindspore.ops.operations.comm_ops
import
HostAllGather
,
HostReduceScatter
from
mindspore.ops.operations.comm_ops
import
Broadcast
from
mindspore.ops.operations.comm_ops
import
Broadcast
# pylint: disable=W0212
# pylint: disable=W0212
...
@@ -87,21 +86,6 @@ class AllGatherNet(nn.Cell):
...
@@ -87,21 +86,6 @@ class AllGatherNet(nn.Cell):
return
self
.
relu
(
x
)
return
self
.
relu
(
x
)
class
HostAllGatherNet
(
nn
.
Cell
):
"""HostAllGatherNet definition"""
def
__init__
(
self
,
input_channel
,
output_channel
):
super
(
HostAllGatherNet
,
self
).
__init__
()
self
.
dense
=
Dense
(
input_channel
,
output_channel
)
self
.
hostallgather
=
HostAllGather
((
0
,
1
))
self
.
relu
=
ReLU
()
def
construct
(
self
,
x
):
x
=
self
.
dense
(
x
)
x
=
self
.
hostallgather
(
x
)
return
self
.
relu
(
x
)
class
ReduceScatterNet
(
nn
.
Cell
):
class
ReduceScatterNet
(
nn
.
Cell
):
"""ReduceScatterNet definition"""
"""ReduceScatterNet definition"""
...
@@ -117,21 +101,6 @@ class ReduceScatterNet(nn.Cell):
...
@@ -117,21 +101,6 @@ class ReduceScatterNet(nn.Cell):
return
self
.
relu
(
x
)
return
self
.
relu
(
x
)
class
HostReduceScatterNet
(
nn
.
Cell
):
"""HostReduceScatterNet definition"""
def
__init__
(
self
,
input_channel
,
out_channel
,
op
):
super
(
HostReduceScatterNet
,
self
).
__init__
()
self
.
dense
=
Dense
(
input_channel
,
out_channel
)
self
.
hostreducescatter
=
HostReduceScatter
(
op
,
(
0
,
1
))
self
.
relu
=
ReLU
()
def
construct
(
self
,
x
):
x
=
self
.
dense
(
x
)
x
=
self
.
hostreducescatter
(
x
)
return
self
.
relu
(
x
)
class
AlltoAllNet
(
nn
.
Cell
):
class
AlltoAllNet
(
nn
.
Cell
):
"""AlltoAllNet definition"""
"""AlltoAllNet definition"""
...
@@ -185,21 +154,6 @@ def test_allgather():
...
@@ -185,21 +154,6 @@ def test_allgather():
_executor
.
compile
(
network
,
input_tensor
,
label_tensor
)
_executor
.
compile
(
network
,
input_tensor
,
label_tensor
)
def
test_hostallgather
():
"""test_hostallgather"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
input_tensor
=
Tensor
(
np
.
array
([[
1.2
,
2.1
],
[
2.2
,
3.2
]],
dtype
=
np
.
float32
))
label_tensor
=
Tensor
(
np
.
array
([[
1.2
],
[
2.2
],
[
3.2
],
[
4.2
]],
dtype
=
np
.
float32
))
network
=
HostAllGatherNet
(
2
,
1
)
loss_fn
=
nn
.
SoftmaxCrossEntropyWithLogits
()
optimizer
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
network
.
get_parameters
()),
learning_rate
=
0.1
,
momentum
=
0.9
)
network
=
WithLossCell
(
network
,
loss_fn
)
network
=
TrainOneStepCell
(
network
,
optimizer
)
_executor
.
compile
(
network
,
input_tensor
,
label_tensor
)
def
run_reducescatter
(
op
):
def
run_reducescatter
(
op
):
"""run_reducescatter"""
"""run_reducescatter"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
@@ -221,21 +175,6 @@ def test_reducescatter():
...
@@ -221,21 +175,6 @@ def test_reducescatter():
run_reducescatter
(
ReduceOp
.
SUM
)
run_reducescatter
(
ReduceOp
.
SUM
)
def
test_hostreducescatter
():
"""test_hostreducescatter"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
input_tensor
=
Tensor
(
np
.
array
([[
1.2
,
2.1
],
[
2.2
,
3.2
]],
dtype
=
np
.
float32
))
label_tensor
=
Tensor
(
np
.
array
([[
1.2
]],
dtype
=
np
.
float32
))
network
=
HostReduceScatterNet
(
2
,
1
,
ReduceOp
.
SUM
)
loss_fn
=
nn
.
SoftmaxCrossEntropyWithLogits
()
optimizer
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
network
.
get_parameters
()),
learning_rate
=
0.1
,
momentum
=
0.9
)
network
=
WithLossCell
(
network
,
loss_fn
)
network
=
TrainOneStepCell
(
network
,
optimizer
)
_executor
.
compile
(
network
,
input_tensor
,
label_tensor
)
def
test_broadcast
():
def
test_broadcast
():
"""test_broadcast"""
"""test_broadcast"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录