提交 2eedd321 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1472 add operator HostAllGather and HostReduceScatter

Merge pull request !1472 from yihuaijie/master
......@@ -55,7 +55,9 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
const char kNameAllReduce[] = "AllReduce";
const char kNameBroadcast[] = "Broadcast";
const char kNameAllgather[] = "AllGather";
const char kNameHostAllgather[] = "HostAllGather";
const char kNameReduceScatter[] = "ReduceScatter";
const char kNameHostReduceScatter[] = "HostReduceScatter";
const char kNameReduceSum[] = "ReduceSum";
const char kNameIsFinite[] = "isFinite";
const char kNameReciprocal[] = "Reciprocal";
......
......@@ -45,8 +45,10 @@ constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
constexpr auto kGetNextOpName = "GetNext";
constexpr auto kAllReduceOpName = "AllReduce";
constexpr auto kAllGatherOpName = "AllGather";
constexpr auto kHostAllGatherOpName = "HostAllGather";
constexpr auto kBroadcastOpName = "Broadcast";
constexpr auto kReduceScatterOpName = "ReduceScatter";
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
constexpr auto kTopKOpName = "TopK";
constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
......
......@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, AllReduce, _AlltoAll, Broadcast,
from ..operations.comm_ops import (AllGather, HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, ReduceOp,
ReduceScatter, _VirtualDiv)
ReduceScatter, HostReduceScatter, _VirtualDiv)
from .grad_base import bprop_getters
......@@ -79,6 +79,21 @@ def get_bprop_all_gather(self):
return bprop
@bprop_getters.register(HostAllGather)
def get_bprop_host_all_gather(self):
"""Generate bprop for HostAllGather"""
host_all_gather_grad = HostReduceScatter(ReduceOp.SUM, self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
host_all_gather_grad.set_prim_instance_name(instance_name)
def bprop(x, out, dout):
dx = host_all_gather_grad(dout)
return (dx,)
return bprop
@bprop_getters.register(ReduceScatter)
def get_bprop_reduce_scatter(self):
"""Generate bprop for ReduceScatter"""
......@@ -97,6 +112,24 @@ def get_bprop_reduce_scatter(self):
return bprop
@bprop_getters.register(HostReduceScatter)
def get_bprop_host_reduce_scatter(self):
"""Generate bprop for HostReduceScatter"""
host_reduce_scatter_grad = HostAllGather(self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
host_reduce_scatter_grad.set_prim_instance_name(instance_name)
if self.op != ReduceOp.SUM:
raise RuntimeError("The hostreducescatter bprop only support ReduceOp.SUM until now.")
def bprop(x, out, dout):
dx = host_reduce_scatter_grad(dout)
return (dx,)
return bprop
@bprop_getters.register(_AlltoAll)
def get_bprop_all_to_all(self):
"""Generate bprop for AlltoAll."""
......
......@@ -33,7 +33,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
SpaceToBatchND, BatchToSpaceND)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice)
_VirtualDiv, _GetTensorSlice,
HostAllGather, HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print)
from .control_ops import ControlDepend, GeSwitch, Merge
......@@ -220,8 +221,10 @@ __all__ = [
'UnsortedSegmentSum',
'UnsortedSegmentMin',
"AllGather",
"HostAllGather",
"AllReduce",
"ReduceScatter",
"HostReduceScatter",
"Broadcast",
"ReduceOp",
'ScalarCast',
......
......@@ -169,6 +169,72 @@ class AllGather(PrimitiveWithInfer):
raise NotImplementedError
class HostAllGather(PrimitiveWithInfer):
"""
Gathers tensors from the specified communication group on host.
Note:
Tensor must have the same shape and format in all processes participating in the collective.
Args:
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
Raises:
TypeError: If group is not a list nor tuple, or elements of group are not int.
ValueError: If the local rank id of the calling process not in group,
or rank_id from group not in [0, 7].
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor. If the number of devices in the group is N,
then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
Examples:
>>> from mindspore.communication import init
>>> import mindspore.ops.operations as P
>>> init('nccl')
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
>>>
>>> def construct(self, x):
>>> return self.hostallgather(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, group=None):
if group is None:
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
for r in group:
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
validator.check_value_type("rank_id", r, (int,), self.name)
self.group_size = len(group)
self.rank = get_rank()
validator.check('rank', self.rank, 'group', self.group, Rel.IN, self.name)
self.add_prim_attr('group', group)
def infer_shape(self, x_shape):
validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
x_shape[0] = x_shape[0] * self.group_size
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):
raise NotImplementedError
class ReduceScatter(PrimitiveWithInfer):
"""
Reduces and scatters tensors from the specified communication group.
......@@ -226,6 +292,68 @@ class ReduceScatter(PrimitiveWithInfer):
raise NotImplementedError
class HostReduceScatter(PrimitiveWithInfer):
"""
Reduces and scatters tensors from the specified communication group on host.
Note:
Tensor must have the same shape and format in all processes participating in the collective.
Args:
op (str): Specifies an operation used for element-wise reductions,
like sum, max, avg. Default: ReduceOp.SUM.
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
Raise:
TypeError: If op is not a string and group is not a list nor tuple,
or elements of group are not int.
ValueError: If the first dimension of input can not be divided by rank size,
or group is not set, or rank_id not in [1, 7].
Examples:
>>> from mindspore.communication import init
>>> import mindspore.ops.operations as P
>>> init('nccl')
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
>>>
>>> def construct(self, x):
>>> return self.hostreducescatter(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, op=ReduceOp.SUM, group=None):
if group is None:
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
for r in group:
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
validator.check_value_type("rank_id", r, (int,), self.name)
self.op = op
self.group_size = len(group)
self.add_prim_attr('group', group)
def infer_shape(self, x_shape):
if x_shape[0] % self.group_size != 0:
raise ValueError(f"For '{self.name}' the first dimension of x should be divided by group_size.")
x_shape[0] = int(x_shape[0]/self.group_size)
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):
raise NotImplementedError
class Broadcast(PrimitiveWithInfer):
"""
Broadcasts the tensor to the whole group.
......
......@@ -26,6 +26,7 @@ from mindspore.nn import Momentum
from mindspore.nn import ReLU
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter
from mindspore.ops.operations.comm_ops import Broadcast
# pylint: disable=W0212
......@@ -86,6 +87,21 @@ class AllGatherNet(nn.Cell):
return self.relu(x)
class HostAllGatherNet(nn.Cell):
"""HostAllGatherNet definition"""
def __init__(self, input_channel, output_channel):
super(HostAllGatherNet, self).__init__()
self.dense = Dense(input_channel, output_channel)
self.hostallgather = HostAllGather((0, 1))
self.relu = ReLU()
def construct(self, x):
x = self.dense(x)
x = self.hostallgather(x)
return self.relu(x)
class ReduceScatterNet(nn.Cell):
"""ReduceScatterNet definition"""
......@@ -101,6 +117,21 @@ class ReduceScatterNet(nn.Cell):
return self.relu(x)
class HostReduceScatterNet(nn.Cell):
"""HostReduceScatterNet definition"""
def __init__(self, input_channel, out_channel, op):
super(HostReduceScatterNet, self).__init__()
self.dense = Dense(input_channel, out_channel)
self.hostreducescatter = HostReduceScatter(op, (0, 1))
self.relu = ReLU()
def construct(self, x):
x = self.dense(x)
x = self.hostreducescatter(x)
return self.relu(x)
class AlltoAllNet(nn.Cell):
"""AlltoAllNet definition"""
......@@ -154,6 +185,21 @@ def test_allgather():
_executor.compile(network, input_tensor, label_tensor)
def test_hostallgather():
"""test_hostallgather"""
context.set_context(mode=context.GRAPH_MODE)
input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
label_tensor = Tensor(np.array([[1.2], [2.2], [3.2], [4.2]], dtype=np.float32))
network = HostAllGatherNet(2, 1)
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
learning_rate=0.1,
momentum=0.9)
network = WithLossCell(network, loss_fn)
network = TrainOneStepCell(network, optimizer)
_executor.compile(network, input_tensor, label_tensor)
def run_reducescatter(op):
"""run_reducescatter"""
context.set_context(mode=context.GRAPH_MODE)
......@@ -175,6 +221,21 @@ def test_reducescatter():
run_reducescatter(ReduceOp.SUM)
def test_hostreducescatter():
"""test_hostreducescatter"""
context.set_context(mode=context.GRAPH_MODE)
input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
label_tensor = Tensor(np.array([[1.2]], dtype=np.float32))
network = HostReduceScatterNet(2, 1, ReduceOp.SUM)
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
learning_rate=0.1,
momentum=0.9)
network = WithLossCell(network, loss_fn)
network = TrainOneStepCell(network, optimizer)
_executor.compile(network, input_tensor, label_tensor)
def test_broadcast():
"""test_broadcast"""
context.set_context(mode=context.GRAPH_MODE)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册