提交 fc906f7f 编写于 作者: X Xiaoda Zhang

move embeddinglookup to external

上级 a7fc7e50
......@@ -191,13 +191,12 @@ def get_bprop_tile(self):
return bprop
@bprop_getters.register(inner.EmbeddingLookup)
@bprop_getters.register(P.EmbeddingLookup)
def get_bprop_embedding_lookup(self):
"""Generate bprop for EmbeddingLookup"""
sub_op = P.Sub()
reshape_op = P.Reshape()
host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU')
def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout):
def bprop_sparse(x, indices, offset, out, dout):
x_shp = shape_op(x)
new_indices = sub_op(indices, offset)
# Reshape the 'new_indices'
......@@ -205,17 +204,9 @@ def get_bprop_embedding_lookup(self):
new_indices = reshape_op(new_indices, new_indices_shape_changed)
x_shp_tail = x_shp[1:]
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
if reduce_scatter_flag is True:
# On host
elu_grad = G.EmbeddingLookupCommGrad()
actual_dout = elu_grad(dout, split_num)
# Reshape the 'actual_dout' on host
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
else:
# Reshape the 'actual_dout' on device
actual_dout = reshape_op(dout, actual_dout_shape_changed)
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
zeros_like(reduce_scatter_flag), zeros_like(split_num)
# Reshape the 'actual_dout' on device
actual_dout = reshape_op(dout, actual_dout_shape_changed)
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
return bprop_sparse
......
......@@ -32,7 +32,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence)
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
......@@ -333,6 +333,7 @@ __all__ = [
"Mod",
"PopulationCount",
"ParallelConcat",
"EmbeddingLookup"
]
__all__.sort()
......@@ -263,76 +263,6 @@ class AscendDequant(PrimitiveWithInfer):
return mstype.float16
class EmbeddingLookup(PrimitiveWithInfer):
"""
Returns a slice of input tensor based on the specified indices.
This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs:
`offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.
Inputs:
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
The Tensor slice, instead of the entire Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
and the exceeding part will be filled with 0 in the output.
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
are equal to `input_indices` minus `offset`.
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
Only constant value is allowed.
- **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
is used only if `reduce_scatter_flag` is True. Only constant value is allowed.
Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Examples:
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
>>> offset = 4
>>> reduce_scatter_flag = False
>>> split_num = 1
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
"""
@prim_attr_register
def __init__(self):
"""init index_select"""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'],
outputs=['output'])
self.add_prim_attr('primitive_target', 'CPU')
def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
if split_num['value'] < 1:
raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num)
params_shp = params['shape']
out_shape = indices['shape'] + params_shp[1:]
if reduce_scatter_flag is None:
raise ValueError("The value of 'reduce_scatter_flag' is None.")
reduce_scatter_flag_value = reduce_scatter_flag['value']
if split_num is None:
raise ValueError("The value of 'split_num_value' is None.")
split_num_value = split_num['value']
if reduce_scatter_flag_value is True:
# Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
# (split_num * 8)
if out_shape[0] % (split_num_value * 8) != 0:
raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
(out_shape[0], (split_num_value * 8)))
# After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
out_shape[0] = out_shape[0] // 8
out = {'shape': out_shape,
'dtype': params['dtype'],
'value': None}
return out
class SparseApplyFtrlNoReturn(PrimitiveWithInfer):
"""
Update relevant entries according to the FTRL-proximal scheme.
......
......@@ -3236,3 +3236,50 @@ class TransShape(PrimitiveWithInfer):
return {'shape': shp,
'dtype': dtype,
'value': None}
class EmbeddingLookup(PrimitiveWithInfer):
"""
Returns a slice of input tensor based on the specified indices.
This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has one more inputs:
`offset`.
Inputs:
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
The Tensor slice, instead of the entire Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
and the exceeding part will be filled with 0 in the output.
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
are equal to `input_indices` minus `offset`.
Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Examples:
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
>>> offset = 4
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset)
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
"""
@prim_attr_register
def __init__(self):
"""init index_select"""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
outputs=['output'])
def __infer__(self, params, indices, offset):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
params_shp = params['shape']
if len(params_shp) != 2:
raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d." % len(params_shp))
out_shape = indices['shape'] + params_shp[1:]
out = {'shape': out_shape,
'dtype': params['dtype'],
'value': None}
return out
......@@ -19,7 +19,6 @@ import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations import _inner_ops as inner
from mindspore import Tensor, context
from tests.ut.python.ops.test_math_ops import VirtualLoss
......@@ -42,17 +41,15 @@ class NetWithLoss(nn.Cell):
return self.loss(predict)
class Net(nn.Cell):
def __init__(self, shape, offset, reduce_scatter_flag, split_num):
def __init__(self, shape, offset):
super().__init__()
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.offset = offset
self.reduce_scatter_flag = reduce_scatter_flag
self.split_num = split_num
self.elu = inner.EmbeddingLookup()
self.elu = P.EmbeddingLookup()
self.mm = P.BatchMatMul()
def construct(self, x, y):
out = self.elu(x, self.index, self.offset, self.reduce_scatter_flag, self.split_num)
out = self.elu(x, self.index, self.offset)
out = self.mm(out, y)
return out
......@@ -60,9 +57,7 @@ class Net(nn.Cell):
def test_embeddinglookup_reducescatter_false():
shape = [8, 8]
offset = 8
reduce_scatter_flag = False
split_num = 1
net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num))
net = NetWithLoss(Net(shape, offset))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
......@@ -71,11 +66,9 @@ def test_embeddinglookup_reducescatter_false():
def test_embeddinglookup_reducescatter_true():
shape = [64, 8]
shape = [8, 8]
offset = 8
reduce_scatter_flag = True
split_num = 8
net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num))
net = NetWithLoss(Net(shape, offset))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
......@@ -86,9 +79,7 @@ def test_embeddinglookup_reducescatter_true():
def test_embeddinglookup_reducescatter_false_grad():
shape = [8, 8]
offset = 8
reduce_scatter_flag = False
split_num = 1
net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
net = GradWrap(NetWithLoss(Net(shape, offset)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
......@@ -98,11 +89,9 @@ def test_embeddinglookup_reducescatter_false_grad():
def test_embeddinglookup_reducescatter_true_grad():
context.set_context(save_graphs=True)
shape = [64, 8]
shape = [8, 8]
offset = 8
reduce_scatter_flag = True
split_num = 8
net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
net = GradWrap(NetWithLoss(Net(shape, offset)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
......@@ -184,6 +185,7 @@ def test_gatherv2_auto1():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
......@@ -196,6 +198,7 @@ def test_gatherv2_cpu0():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu1():
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((16, 1), (1, 1))
......@@ -208,6 +211,7 @@ def test_gatherv2_cpu1():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
......@@ -184,6 +185,7 @@ def test_gatherv2_auto1():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
......@@ -196,6 +198,7 @@ def test_gatherv2_cpu0():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu1():
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((16, 1), (1, 1))
......@@ -208,6 +211,7 @@ def test_gatherv2_cpu1():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册