diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index 3c55a15146bfb4284e483b3cfc124408f25fc1fb..5c074ef1583bd43c2c2265ff2b7f56976281fff7 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -76,7 +76,7 @@ constexpr char DEPEND[] = "depend";
 constexpr char BATCH_PARALLEL[] = "BatchParallel";
 constexpr char ACTIVATION_TYPE[] = "activation_type";
-constexpr char TARGET[] = "target";
+constexpr char TARGET[] = "primitive_target";
 constexpr char CPU[] = "CPU";
 constexpr char TRANSPOSE_A[] = "transpose_a";
 constexpr char TRANSPOSE_B[] = "transpose_b";
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 008f5f0edb999afd75b776c237077bb818e82902..55157216a835cce5d67ada6358b2e9be7e37800d 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -21,6 +21,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 from ..._checkparam import Validator as validator, Rel
 from .._utils import get_concat_offset
 from ...common import dtype as mstype
+from .. import functional as F


 class AbsGrad(PrimitiveWithInfer):
@@ -1121,6 +1122,37 @@ class MirrorPadGrad(PrimitiveWithInfer):
                 'value': None}


+class EmbeddingLookupCommGrad(PrimitiveWithInfer):
+    """
+    Performs the gradient for the communication part of the EmbeddingLookup operator.
+
+    This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
+    this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
+    """
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
+        self.add_prim_attr('primitive_target', 'CPU')
+
+    def __infer__(self, dy, split_num):
+        """
+        This primitive is implemented in three steps:
+            1) Split 'dy' along dimension 0 into 'split_num' parts.
+            2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
+            3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
+               along dimension 0.
+
+        The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
+        """
+        dy_shape = tuple(dy['shape'])
+        split_num_value = split_num['value']
+        validator.check_value_type("split_num_value", split_num_value, [int], self.name)
+        dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
+        return {'shape': dy_shape_all,
+                'dtype': dy['dtype'],
+                'value': None}
+
+
 class RefToEmbed(Primitive):
     r"""
     Make a key from Ref.
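The shape contract encoded by EmbeddingLookupCommGrad.__infer__ can be sanity-checked without MindSpore: 'dy' is split into 'split_num' slices along dimension 0, each slice is all-gathered across the 8-rank host group, and the gathered slices are concatenated back, so dimension 0 grows by a factor of 8. The NumPy sketch below only illustrates that pipeline on a single process; the helper name simulate_comm_grad and the tiling stand-in for HostAllGather are assumptions for illustration, not part of this patch.

# Illustrative sketch only: simulates the StridedSlice -> HostAllGather -> Concat
# shape behaviour of EmbeddingLookupCommGrad on one process. Names are made up.
import numpy as np

GROUP_SIZE = 8  # HostAllGather group (0, 1, 2, 3, 4, 5, 6, 7)

def simulate_comm_grad(dy, split_num):
    # 1) Split 'dy' along dimension 0 into 'split_num' parts.
    parts = np.split(dy, split_num, axis=0)
    # 2) HostAllGather: every rank contributes one copy of each part, so each
    #    part's dimension 0 grows by GROUP_SIZE (tiling mimics the shape only).
    gathered = [np.tile(p, (GROUP_SIZE, 1)) for p in parts]
    # 3) Concat the gathered parts back along dimension 0.
    return np.concatenate(gathered, axis=0)

dy = np.ones((16, 32), dtype=np.float32)
out = simulate_comm_grad(dy, split_num=2)
assert out.shape[0] == dy.shape[0] * GROUP_SIZE  # shape(output)[0] == shape(dy)[0] * 8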
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index a9c856b7c54097ef57fdbfd3ca2a426c3abe2569..ba5c20e06de08cf33a5b6cd6bdaf6717112a8d3c 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -614,7 +614,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
         self.__setattr_flag__ = True
         self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
                                 outputs=['output'])
-        self.add_prim_attr('target', 'CPU')
+        self.add_prim_attr('primitive_target', 'CPU')

     def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py
index c295bf93abe9537ba7e249136911f2a96d4f0668..5d52089cbec101988cd719ad05af05cf4174257c 100644
--- a/tests/ut/python/parallel/test_gather_v2.py
+++ b/tests/ut/python/parallel/test_gather_v2.py
@@ -45,11 +45,11 @@ class GradWrap(nn.Cell):


 class Net(nn.Cell):
-    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
+    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
         super().__init__()
         if shape is None:
             shape = [64, 64]
-        self.gatherv2 = P.GatherV2().set_strategy(strategy1)
+        self.gatherv2 = P.GatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
         self.mul = P.Mul().set_strategy(strategy2)
         self.index = Tensor(np.ones(shape), dtype=ms.int32)
         self.axis = axis
@@ -188,7 +188,7 @@ def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()

     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
@@ -200,7 +200,7 @@ def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()

     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
@@ -212,7 +212,7 @@ def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()

     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
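On the operator side, the only user-visible change in these files is the attribute key: a primitive pinned to host must now carry 'primitive_target' instead of 'target'. A minimal sketch of that usage outside the parallel test harness follows; the HostGather wrapper, the input shapes, and the forward call are illustrative placeholders, and only the 'primitive_target' key and the P.GatherV2() attribute chain come from this patch.

# Minimal sketch, assuming a standalone Cell wrapper (HostGather is a made-up name).
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn
from mindspore.ops import operations as P


class HostGather(nn.Cell):
    def __init__(self):
        super().__init__()
        # 'primitive_target' replaces the old 'target' attribute key for host-side dispatch.
        self.gatherv2 = P.GatherV2().add_prim_attr("primitive_target", "CPU")

    def construct(self, params, indices):
        # Gather rows along axis 0; the primitive itself is placed on host CPU.
        return self.gatherv2(params, indices, 0)


net = HostGather()
params = Tensor(np.ones([64, 64]), dtype=ms.float32)
indices = Tensor(np.ones([64, 64]), dtype=ms.int32)
out = net(params, indices)  # illustrative forward call; needs a configured MindSpore runtime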