diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index 4bc2dc5a6bc5e613ba107e00479ff0edf5ca546d..f6fe6e376ae4e3d2cff9ed385ea76ba498887105 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -203,7 +203,7 @@ def get_bprop_embedding_lookup(self):
             actual_dout = elu_grad(dout, split_num)
         else:
             actual_dout = dout
-        new_indices = host_sub(indices - offset)
+        new_indices = host_sub(indices, offset)
         # Reshape the 'new_indices'
         new_indices_shape_changed = (size_op(new_indices),)
         new_indices = host_reshape(new_indices, new_indices_shape_changed)
@@ -211,7 +211,7 @@ def get_bprop_embedding_lookup(self):
         x_shp_tail = x_shp[1:]
         actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
         actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
-        return (new_indices, actual_dout, x_shp), zeros_like(new_indices), zeros_like(axis), \
+        return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
             zeros_like(reduce_scatter_flag), zeros_like(split_num)
 
     return bprop_sparse
diff --git a/tests/ut/python/parallel/test_embeddinglookup.py b/tests/ut/python/parallel/test_embeddinglookup.py
index b3060619813b86bf556cbd97fcbba39a6c66725c..4ab5f5f878881770cf222ab10d47c4b843359e00 100644
--- a/tests/ut/python/parallel/test_embeddinglookup.py
+++ b/tests/ut/python/parallel/test_embeddinglookup.py
@@ -16,12 +16,20 @@
 import numpy as np
 
 import mindspore as ms
 import mindspore.nn as nn
-from mindspore import Tensor
 from mindspore.common.api import _executor
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.ops.operations import _inner_ops as inner
+from mindspore import Tensor, context
 from tests.ut.python.ops.test_math_ops import VirtualLoss
 
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y):
+        return C.grad_all(self.network)(x, y)
 class NetWithLoss(nn.Cell):
     def __init__(self, network):
@@ -73,3 +81,30 @@ def test_embeddinglookup_reducescatter_true():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
     _executor.compile(net, x, y)
+
+
+def test_embeddinglookup_reducescatter_false_grad():
+    shape = [8, 8]
+    offset = 8
+    reduce_scatter_flag = False
+    split_num = 1
+    net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_embeddinglookup_reducescatter_true_grad():
+    context.set_context(save_graphs=True)
+    shape = [64, 8]
+    offset = 8
+    reduce_scatter_flag = True
+    split_num = 8
+    net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
+    _executor.compile(net, x, y)