Commit 2eb739de authored by Yi Huaijie

change HostAllGather and HostReduceScatter to internal interface

Parent 5b14292f
@@ -36,7 +36,7 @@ class AllGatherCPUKernel : public CPUKernel {
   std::vector<int> ranks_group_;
 };
 
-MS_REG_CPU_KERNEL(HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   AllGatherCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
...
@@ -37,7 +37,7 @@ class ReduceScatterCPUKernel : public CPUKernel {
   std::vector<int> ranks_group_;
 };
 
-MS_REG_CPU_KERNEL(HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_CPU_KERNEL(_HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ReduceScatterCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
...
@@ -145,7 +145,7 @@ constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
 constexpr char STRIDED_SLICE[] = "StridedSlice";
 constexpr char ALL_GATHER[] = "AllGather";
 constexpr char REDUCE_SCATTER[] = "ReduceScatter";
-constexpr char HOST_REDUCE_SCATTER[] = "HostReduceScatter";
+constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter";
 constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
 constexpr char CONCAT[] = "Concat";
 constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
...
@@ -55,9 +55,7 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
 const char kNameAllReduce[] = "AllReduce";
 const char kNameBroadcast[] = "Broadcast";
 const char kNameAllgather[] = "AllGather";
-const char kNameHostAllgather[] = "HostAllGather";
 const char kNameReduceScatter[] = "ReduceScatter";
-const char kNameHostReduceScatter[] = "HostReduceScatter";
 const char kNameReduceSum[] = "ReduceSum";
 const char kNameIsFinite[] = "isFinite";
 const char kNameReciprocal[] = "Reciprocal";
...
@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from .. import operations as P
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
-from ..operations.comm_ops import (AllGather, HostAllGather, AllReduce, _AlltoAll, Broadcast,
-                                   _GetTensorSlice, _MirrorOperator, ReduceOp,
-                                   ReduceScatter, HostReduceScatter, _VirtualDiv)
+from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
+                                   _GetTensorSlice, _MirrorOperator, ReduceOp,
+                                   ReduceScatter, _HostReduceScatter, _VirtualDiv)
 from .grad_base import bprop_getters
@@ -93,10 +93,10 @@ def get_bprop_all_gather(self):
     return bprop
 
 
-@bprop_getters.register(HostAllGather)
+@bprop_getters.register(_HostAllGather)
 def get_bprop_host_all_gather(self):
-    """Generate bprop for HostAllGather"""
-    host_all_gather_grad = HostReduceScatter(ReduceOp.SUM, self.group)
+    """Generate bprop for _HostAllGather"""
+    host_all_gather_grad = _HostReduceScatter(ReduceOp.SUM, self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         host_all_gather_grad.set_prim_instance_name(instance_name)
@@ -126,10 +126,10 @@ def get_bprop_reduce_scatter(self):
     return bprop
 
 
-@bprop_getters.register(HostReduceScatter)
+@bprop_getters.register(_HostReduceScatter)
 def get_bprop_host_reduce_scatter(self):
-    """Generate bprop for HostReduceScatter"""
-    host_reduce_scatter_grad = HostAllGather(self.group)
+    """Generate bprop for _HostReduceScatter"""
+    host_reduce_scatter_grad = _HostAllGather(self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         host_reduce_scatter_grad.set_prim_instance_name(instance_name)
...
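The bprop pairing above (the gradient of _HostAllGather is a summed _HostReduceScatter, and the gradient of _HostReduceScatter is a _HostAllGather) can be illustrated with a small single-process NumPy sketch that simulates three ranks; it only mimics the collective semantics and is not MindSpore code.

import numpy as np

# Simulate 3 ranks, each holding a (4,) slice of the gathered tensor.
ranks = [np.arange(4, dtype=np.float32) + 4 * r for r in range(3)]

# Forward _HostAllGather: every rank ends up with the concatenation of all slices.
gathered = np.concatenate(ranks)                        # shape (12,)

# Backward: each rank produces an upstream gradient for the full gathered tensor.
upstream = [np.full(12, r + 1.0, dtype=np.float32) for r in range(3)]

# _HostReduceScatter with ReduceOp.SUM on those gradients: sum across ranks,
# then hand rank r the slice that corresponds to its own forward input.
summed = np.sum(upstream, axis=0)                       # shape (12,)
grad_per_rank = [summed[4 * r:4 * (r + 1)] for r in range(3)]
print(grad_per_rank[0])                                 # gradient w.r.t. rank 0's slice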
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
-                       HostAllGather, HostReduceScatter)
+                       _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Debug, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
@@ -244,10 +244,8 @@ __all__ = [
     'UnsortedSegmentSum',
     'UnsortedSegmentMin',
     "AllGather",
-    "HostAllGather",
     "AllReduce",
     "ReduceScatter",
-    "HostReduceScatter",
     "Broadcast",
     "ReduceOp",
     'ScalarCast',
...
@@ -1166,7 +1166,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
     Perform the gradient for the communication part of EmbeddingLookup operator.
 
     This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
-    this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
+    this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
     """
 
     @prim_attr_register
     def __init__(self):
@@ -1177,8 +1177,8 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
         """
         This primitive is implemented by three steps:
             1) Split the 'dy' along dimension 0 into 'split_num' parts.
-            2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
-            3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
+            2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
+            3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
                along dimension 0.
 
         The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
...
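A rough single-process NumPy sketch of those three steps and the resulting shape relation follows; the all-gather across 8 host processes is faked here by replication (only one process exists in the sketch), so it illustrates the shapes, not the primitive's implementation.

import numpy as np

split_num, world_size = 4, 8                     # 8 host processes, as in the docstring
dy = np.random.rand(16, 32).astype(np.float32)

# 1) Split 'dy' along dimension 0 into 'split_num' parts.
parts = np.split(dy, split_num, axis=0)          # each part: (4, 32)

# 2) All-gather each part across the 8 processes.  Replication stands in for the
#    distinct per-rank contributions; the first dimension still grows by world_size.
gathered = [np.concatenate([p] * world_size, axis=0) for p in parts]   # (32, 32) each

# 3) Concat the 'split_num' gathered parts along dimension 0.
output = np.concatenate(gathered, axis=0)

assert output.shape[0] == dy.shape[0] * world_size   # shape(output)[0] == shape(dy)[0] * 8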
@@ -176,13 +176,13 @@ class AllGather(PrimitiveWithInfer):
         raise NotImplementedError
 
 
-class HostAllGather(PrimitiveWithInfer):
+class _HostAllGather(PrimitiveWithInfer):
     """
     Gathers tensors from the specified communication group on host.
 
     Note:
         Tensor must have the same shape and format in all processes participating in the collective.
-        HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
+        _HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
         to enable it. Using mpirun command to run it:
         mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
@@ -199,27 +199,6 @@ class HostAllGather(PrimitiveWithInfer):
     Outputs:
         Tensor. If the number of devices in the group is N,
         then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> import mindspore.context as context
-        >>> import mindspore.ops.operations as P
-        >>> from mindspore import Tensor
-        >>>
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-        >>> context.set_mpi_config(enable_mpi=True)
-        >>>
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
-        >>>
-        >>>     def construct(self, x):
-        >>>         return self.hostallgather(x)
-        >>>
-        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
     """
 
     @prim_attr_register
@@ -308,13 +287,13 @@ class ReduceScatter(PrimitiveWithInfer):
         raise NotImplementedError
 
 
-class HostReduceScatter(PrimitiveWithInfer):
+class _HostReduceScatter(PrimitiveWithInfer):
     """
     Reduces and scatters tensors from the specified communication group on host.
 
     Note:
         Tensor must have the same shape and format in all processes participating in the collective.
-        HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
+        _HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
         -M on to enable it. Using mpirun command to run it:
         mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py
@@ -328,28 +307,6 @@ class HostReduceScatter(PrimitiveWithInfer):
             or elements of group are not int.
         ValueError: If the first dimension of input can not be divided by group size,
             or group is not set, or rank_id not in [0, 7].
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> import mindspore.context as context
-        >>> import mindspore.ops.operations as P
-        >>> from mindspore import Tensor
-        >>> from mindspore.ops.operations.comm_ops import ReduceOp
-        >>>
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-        >>> context.set_mpi_config(enable_mpi=True)
-        >>>
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
-        >>>
-        >>>     def construct(self, x):
-        >>>         return self.hostreducescatter(x)
-        >>>
-        >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
     """
 
     @prim_attr_register
     def __init__(self, op=ReduceOp.SUM, group=None):
...
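With the docstring examples removed and the names dropped from the public operations namespace, the two host collectives are reached through the comm_ops module, as the bprop import above already does. A minimal sketch of that access path, with illustrative constructor arguments taken from the signatures shown in this diff:

# Internal access path after the rename; these primitives are framework-internal.
from mindspore.ops.operations.comm_ops import _HostAllGather, _HostReduceScatter, ReduceOp

host_all_gather = _HostAllGather(group=(0, 1, 2, 3))
host_reduce_scatter = _HostReduceScatter(ReduceOp.SUM, group=(0, 1, 2, 3))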
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
import mindspore._ms_mpi as mpi
# run command:
# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
context.set_mpi_config(enable_mpi=True)
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = "sum"
        self.reducescatter = P.HostReduceScatter(op=self.op, group=[0, 1, 2])

    def construct(self, x):
        return self.reducescatter(x)


class AllGatherNet(nn.Cell):
    def __init__(self):
        super(AllGatherNet, self).__init__()
        self.hostallgather = P.HostAllGather(group=(0, 1, 2))

    def construct(self, x):
        return self.hostallgather(x)


def test_net_reduce_scatter():
    x = np.arange(12).astype(np.float32) * 0.1
    reducescatter = Net()
    rankid = mpi.get_rank_id()
    print("self rankid:", rankid)
    output = reducescatter(Tensor(x, mstype.float32))
    print("output:\n", output)

    if rankid == 0:
        expect_result = np.arange(4).astype(np.float32) * 0.3
    if rankid == 1:
        expect_result = np.arange(4, 8).astype(np.float32) * 0.3
    if rankid == 2:
        expect_result = np.arange(8, 12).astype(np.float32) * 0.3
    diff = abs(output.asnumpy() - expect_result)
    error = np.ones(shape=expect_result.shape) * 1.0e-6
    assert np.all(diff < error)

    allgather = AllGatherNet()
    allgather_output = allgather(output)
    print("allgather result:\n", allgather_output)
    expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
    diff = abs(allgather_output.asnumpy() - expect_allgather_result)
    error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
    assert np.all(diff < error)


if __name__ == '__main__':
    test_net_reduce_scatter()
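The expected values in this test follow from plain reduce-scatter arithmetic; the NumPy check below replays it in one process under the assumption of three ranks that each contribute arange(12) * 0.1.

import numpy as np

world_size = 3
per_rank_input = np.arange(12).astype(np.float32) * 0.1

# SUM reduce-scatter: three identical contributions sum to arange(12) * 0.3,
# and rank r keeps the r-th chunk of 12 / 3 = 4 elements.
reduced = per_rank_input * world_size
chunks = np.split(reduced, world_size)
assert np.allclose(chunks[1], np.arange(4, 8).astype(np.float32) * 0.3)

# All-gathering those chunks back recovers the full reduced tensor,
# which is what the second half of the test asserts.
assert np.allclose(np.concatenate(chunks), np.arange(12).astype(np.float32) * 0.3)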
@@ -26,7 +26,6 @@ from mindspore.nn import Momentum
 from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
-from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast
 
 # pylint: disable=W0212
@@ -87,21 +86,6 @@ class AllGatherNet(nn.Cell):
         return self.relu(x)
 
 
-class HostAllGatherNet(nn.Cell):
-    """HostAllGatherNet definition"""
-
-    def __init__(self, input_channel, output_channel):
-        super(HostAllGatherNet, self).__init__()
-        self.dense = Dense(input_channel, output_channel)
-        self.hostallgather = HostAllGather((0, 1))
-        self.relu = ReLU()
-
-    def construct(self, x):
-        x = self.dense(x)
-        x = self.hostallgather(x)
-        return self.relu(x)
-
-
 class ReduceScatterNet(nn.Cell):
     """ReduceScatterNet definition"""
@@ -117,21 +101,6 @@ class ReduceScatterNet(nn.Cell):
         return self.relu(x)
 
 
-class HostReduceScatterNet(nn.Cell):
-    """HostReduceScatterNet definition"""
-
-    def __init__(self, input_channel, out_channel, op):
-        super(HostReduceScatterNet, self).__init__()
-        self.dense = Dense(input_channel, out_channel)
-        self.hostreducescatter = HostReduceScatter(op, (0, 1))
-        self.relu = ReLU()
-
-    def construct(self, x):
-        x = self.dense(x)
-        x = self.hostreducescatter(x)
-        return self.relu(x)
-
-
 class AlltoAllNet(nn.Cell):
     """AlltoAllNet definition"""
@@ -185,21 +154,6 @@ def test_allgather():
     _executor.compile(network, input_tensor, label_tensor)
 
 
-def test_hostallgather():
-    """test_hostallgather"""
-    context.set_context(mode=context.GRAPH_MODE)
-
-    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
-    label_tensor = Tensor(np.array([[1.2], [2.2], [3.2], [4.2]], dtype=np.float32))
-    network = HostAllGatherNet(2, 1)
-
-    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
-                         learning_rate=0.1,
-                         momentum=0.9)
-    network = WithLossCell(network, loss_fn)
-    network = TrainOneStepCell(network, optimizer)
-    _executor.compile(network, input_tensor, label_tensor)
-
-
 def run_reducescatter(op):
     """run_reducescatter"""
     context.set_context(mode=context.GRAPH_MODE)
@@ -221,21 +175,6 @@ def test_reducescatter():
     run_reducescatter(ReduceOp.SUM)
 
 
-def test_hostreducescatter():
-    """test_hostreducescatter"""
-    context.set_context(mode=context.GRAPH_MODE)
-
-    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
-    label_tensor = Tensor(np.array([[1.2]], dtype=np.float32))
-    network = HostReduceScatterNet(2, 1, ReduceOp.SUM)
-
-    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
-                         learning_rate=0.1,
-                         momentum=0.9)
-    network = WithLossCell(network, loss_fn)
-    network = TrainOneStepCell(network, optimizer)
-    _executor.compile(network, input_tensor, label_tensor)
-
-
 def test_broadcast():
     """test_broadcast"""
     context.set_context(mode=context.GRAPH_MODE)
...