提交 69574f38 编写于 作者: X Xiaoda Zhang

fix the bprob error of embeddinglookup

上级 373832d0
......@@ -203,7 +203,7 @@ def get_bprop_embedding_lookup(self):
actual_dout = elu_grad(dout, split_num)
else:
actual_dout = dout
new_indices = host_sub(indices - offset)
new_indices = host_sub(indices, offset)
# Reshape the 'new_indices'
new_indices_shape_changed = (size_op(new_indices),)
new_indices = host_reshape(new_indices, new_indices_shape_changed)
......@@ -211,7 +211,7 @@ def get_bprop_embedding_lookup(self):
x_shp_tail = x_shp[1:]
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
return (new_indices, actual_dout, x_shp), zeros_like(new_indices), zeros_like(axis), \
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
zeros_like(reduce_scatter_flag), zeros_like(split_num)
return bprop_sparse
......
......@@ -16,12 +16,20 @@ import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations import _inner_ops as inner
from mindspore import Tensor, context
from tests.ut.python.ops.test_math_ops import VirtualLoss
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return C.grad_all(self.network)(x, y)
class NetWithLoss(nn.Cell):
def __init__(self, network):
......@@ -73,3 +81,30 @@ def test_embeddinglookup_reducescatter_true():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_embeddinglookup_reducescatter_false_grad():
shape = [8, 8]
offset = 8
reduce_scatter_flag = False
split_num = 1
net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_embeddinglookup_reducescatter_true_grad():
context.set_context(save_graphs=True)
shape = [64, 8]
offset = 8
reduce_scatter_flag = True
split_num = 8
net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
_executor.compile(net, x, y)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册