From e4de26d5bc1451fcdc955f90c449b2a6e35ba493 Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Tue, 4 Aug 2020 15:49:25 +0800
Subject: [PATCH] embeddinglookup wrap

---
 mindspore/nn/layer/embedding.py               | 89 ++++++++++++++++---
 .../wide_and_deep/src/wide_and_deep.py        | 23 ++---
 .../test_cmp_sparse_embedding.py              | 12 +--
 tests/ut/python/ir/test_row_tensor.py         | 10 +--
 .../parallel/test_sparse_feature_bprop.py     |  6 +-
 5 files changed, 98 insertions(+), 42 deletions(-)

diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py
index 3c4245d70..581ef9a72 100755
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -18,10 +18,14 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore._checkparam import Validator
+from mindspore.communication.management import get_group_size
+from mindspore.train.parallel_utils import ParallelMode
+from mindspore.parallel._utils import _get_parallel_mode
 from ..cell import Cell
-from ..._checkparam import Validator as validator
+from ..._checkparam import Validator as validator, Rel

-__all__ = ['Embedding', 'EmbeddingLookup']
+__all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode']

 class Embedding(Cell):
     r"""
@@ -114,29 +118,36 @@ class EmbeddingLookup(Cell):
         When 'target' is set to 'CPU', this module will use
         P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
         specified 'offset = 0' to lookup table.
-        when 'target' is set to 'DEVICE', this module will use P.GatherV2() which
+        When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
         specified 'axis = 0' to lookup table.
+        In field slice mode, manual_shapes should be given. It is a tuple, where
+        each element is (vocab[i], offset[i]): vocab[i] is the number of rows of the i-th
+        part and offset[i] is the feature id offset of the i-th part. The feature ids in
+        the i-th part will be subtracted by offset[i] so that they start from 0.

     Args:
+        vocab_size (int): Size of the dictionary of embeddings.
+        embedding_size (int): The size of each embedding vector.
+        param_init (str): The initializer for the embedding table. Default: 'normal'.
         target (str): Specify the target where the op is executed. Default: 'CPU'.
+        slice_mode (str): The slicing mode in semi auto parallel/auto parallel. Default: 'batch_slice'.
+        manual_shapes (tuple): The accompaniment array in field slice mode. Default: None.

     Inputs:
-        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
-          The Tensor slice, instead of the entire Tensor.
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
-          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
-          and the exceeding part will be filled with 0 in the output.
+          Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
+          and the exceeding part will be filled with 0 in the output. Input_indices must be a 2d tensor in
+          this interface.

     Outputs:
         Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
     Examples:
-        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
         >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
-        >>> out = nn.EmbeddingLookup()(input_params, input_indices)
-        [[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
+        >>> out = nn.EmbeddingLookup(4, 2)(input_indices)
     """
-    def __init__(self, target='CPU'):
+    def __init__(self, vocab_size, embedding_size, param_init='normal',
+                 target='CPU', slice_mode='batch_slice', manual_shapes=None):
         super(EmbeddingLookup, self).__init__()
         self.target = target
         if target not in ('CPU', 'DEVICE'):
@@ -144,10 +155,60 @@ class EmbeddingLookup(Cell):
                              + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
         self.gatherv2 = P.GatherV2()
         self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+        self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
+                                         name='embedding_table')
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel:
+            if not manual_shapes:
+                raise ValueError("in field slice mode, manual_shapes must not be None")
+            if not isinstance(manual_shapes, tuple):
+                raise TypeError("manual_shapes must be a tuple(int), but got {}!".format(type(manual_shapes)))
+            for dim in manual_shapes:
+                Validator.check_integer('manual shape dim', dim, 0, Rel.GT, self.cls_name)
+            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
+            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
+            self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size())))
+            self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size())))
+        elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1)))
+            self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
+        elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((1, get_group_size()), (1, 1)))
+            self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
+        elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1)))
+            self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1)))
+        else:
+            if is_auto_parallel:
+                raise ValueError("slice_mode should be one of the modes in nn.EmbeddingLookUpSplitMode, but got "
+                                 + str(slice_mode))

-    def construct(self, params, indices):
+    def construct(self, indices):
         if self.target == "CPU":
-            out = self.embeddinglookup(params, indices, 0)
+            out = self.embeddinglookup(self.embedding_table, indices, 0)
         else:
-            out = self.gatherv2(params, indices, 0)
+            out = self.gatherv2(self.embedding_table, indices, 0)
         return out
+
+
+class EmbeddingLookUpSplitMode:
+    """
+    EmbeddingLookUp slice options in auto parallel and semi auto parallel mode.
+
+    There are four kinds of slice options: "BATCH_SLICE", "FIELD_SLICE",
+    "TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE".
+
+    - BATCH_SLICE: Slicing the batch dimension of the indices.
+    - FIELD_SLICE: Slicing the field dimension of the indices.
+    - TABLE_ROW_SLICE: Slicing the rows of the table.
+    - TABLE_COLUMN_SLICE: Slicing the columns of the table.
+
+    MODE_LIST: The list of all supported slice modes.
+    """
+
+    BATCH_SLICE = "batch_slice"
+    FIELD_SLICE = "field_slice"
+    TABLE_ROW_SLICE = "table_row_slice"
+    TABLE_COLUMN_SLICE = "table_column_slice"
+    MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE]
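With this change the lookup cell owns its [vocab_size, embedding_size] table and is called with indices only. A minimal usage sketch of the patched interface follows; the sizes and indices are illustrative, not taken from the patch, and the semi auto parallel variant additionally assumes distributed communication has been initialized:

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor

    # The cell creates and owns the embedding table, so only indices are passed at call time.
    lookup = nn.EmbeddingLookup(vocab_size=16, embedding_size=4, param_init='normal', target='CPU')
    indices = Tensor(np.array([[1, 0], [3, 2]]), ms.int32)  # a 2d tensor of row ids
    embeddings = lookup(indices)                            # shape: (2, 2, 4)

    # Under semi auto parallel / auto parallel (after communication init()), a slice mode
    # selects how the table and the indices are sharded across devices, for example:
    # nn.EmbeddingLookup(16, 4, slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE)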
+ """ + + BATCH_SLICE = "batch_slice" + FIELD_SLICE = "field_slice" + TABLE_ROW_SLICE = "table_row_slice" + TABLE_COLUMN_SLICE = "table_column_slice" + MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE] diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index b00b40905..151ecbdeb 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -209,19 +209,22 @@ class WideDeepModel(nn.Cell): if is_auto_parallel and host_device_mix: self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),)) self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1))) - self.deep_embeddinglookup = nn.EmbeddingLookup() - self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1))) - self.wide_embeddinglookup = nn.EmbeddingLookup() - self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1))) + self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, + slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE) + self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, + slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE) self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1))) self.deep_reshape.add_prim_attr("skip_redistribution", True) self.reduce_sum.add_prim_attr("cross_batch", True) + self.embedding_table = self.deep_embeddinglookup.embedding_table elif parameter_server: - self.deep_embeddinglookup = nn.EmbeddingLookup() - self.wide_embeddinglookup = nn.EmbeddingLookup() + self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim) + self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1) + self.embedding_table = self.deep_embeddinglookup.embedding_table else: - self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE') - self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE') + self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE') + self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE') + self.embedding_table = self.deep_embeddinglookup.embedding_table def construct(self, id_hldr, wt_hldr): """ @@ -231,11 +234,11 @@ class WideDeepModel(nn.Cell): """ mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) # Wide layer - wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr) + wide_id_weight = self.wide_embeddinglookup(id_hldr) wx = self.wide_mul(wide_id_weight, mask) wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) # Deep layer - deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr) + deep_id_embs = self.deep_embeddinglookup(id_hldr) vx = self.deep_mul(deep_id_embs, mask) deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim)) deep_in = self.dense_layer_1(deep_in) diff --git a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py index c08b5b993..0a7034c71 100644 --- a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py +++ b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py @@ -24,8 +24,7 @@ from mindspore.common import dtype as mstype from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Adam from mindspore.ops import operations as P -from 
diff --git a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
index c08b5b993..0a7034c71 100644
--- a/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
+++ b/tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
@@ -24,8 +24,7 @@ from mindspore.common import dtype as mstype
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam
 from mindspore.ops import operations as P
-from mindspore.common.initializer import TruncatedNormal, initializer
-from mindspore import Parameter
+from mindspore.common.initializer import TruncatedNormal

 parser = argparse.ArgumentParser(description="test_sparse_embedding")
 parser.add_argument("--device_target", type=str, default="Ascend")
@@ -53,16 +52,13 @@ class LeNet5(nn.Cell):
         super(LeNet5, self).__init__()
         self.cast = P.Cast()
         self.flatten = nn.Flatten()
-        self.embedding_table = Parameter(
-            initializer("normal", (16, 4), mstype.float32), name="embedding_table"
-        )
-        self.embedding = nn.EmbeddingLookup()
+        self.embedding = nn.EmbeddingLookup(16, 4)
         self.relu = nn.ReLU()
         self.fc = fc_with_initialize(12, num_class)

     def construct(self, x):
         x = self.cast(x, mstype.int32)
-        x = self.embedding(self.embedding_table, x)
+        x = self.embedding(x)
         x = self.flatten(x)
         x = self.fc(x)
         return x
@@ -72,7 +68,7 @@ def do_sparse_embedding(ps=False):
     epoch = 10
     net = LeNet5(10)
     if ps:
-        net.embedding_table.set_param_ps()
+        net.embedding.embedding_table.set_param_ps()

     optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
     optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
diff --git a/tests/ut/python/ir/test_row_tensor.py b/tests/ut/python/ir/test_row_tensor.py
index c2097a831..cdfcf55bd 100644
--- a/tests/ut/python/ir/test_row_tensor.py
+++ b/tests/ut/python/ir/test_row_tensor.py
@@ -421,17 +421,16 @@ def test_row_tensor_with_control_flow_if():


 class EmbeddingLookUpBnNet(nn.Cell):
-    def __init__(self, param_np, target='CPU'):
+    def __init__(self, vocab_size, embedding_size, target='CPU'):
         super().__init__()
-        self.param = Parameter(Tensor(param_np), name="w1")
-        self.embedding_lookup = nn.EmbeddingLookup(target=target)
+        self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target)
         self.bn = nn.BatchNorm2d(num_features=3)
         self.mul = P.Mul()
         self.reshape = P.Reshape()
         self.relu = nn.PReLU()

     def construct(self, indices):
-        x = self.embedding_lookup(self.param, indices)
+        x = self.embedding_lookup(indices)
         x = self.reshape(x, (2, 3, 2, 2))
         x = self.relu(x)
         x = self.bn(x)
@@ -439,10 +438,9 @@ class EmbeddingLookUpBnNet(nn.Cell):


 def test_embedding_lookup_with_mix_precision():
-    param_np = np.ones([8, 8]).astype(np.float32)
     data = Tensor(np.array([0, 1, 2]).astype(np.int32))
     label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32))
-    net = EmbeddingLookUpBnNet(param_np, target='CPU')
+    net = EmbeddingLookUpBnNet(8, 8, target='CPU')

     criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
     optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
diff --git a/tests/ut/python/parallel/test_sparse_feature_bprop.py b/tests/ut/python/parallel/test_sparse_feature_bprop.py
index 407044202..32ebe2859 100644
--- a/tests/ut/python/parallel/test_sparse_feature_bprop.py
+++ b/tests/ut/python/parallel/test_sparse_feature_bprop.py
@@ -69,14 +69,12 @@ def test_bprop_with_sparse_feature_mirror():
             super(Net, self).__init__()
             if shape is None:
                 shape = [8, 8]
-            weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
-            self.weight = Parameter(weight, "w")
             self.index = Tensor(np.ones(shape), dtype=ms.int32)
-            self.embeddinglookup = nn.EmbeddingLookup()
+            self.embeddinglookup = nn.EmbeddingLookup(64, 64, param_init='ones')
             self.embeddinglookup.embeddinglookup.set_strategy(((1, 1), (8, 1)))

         def construct(self, x, b):
-            out = self.embeddinglookup(self.weight, self.index)
+            out = self.embeddinglookup(self.index)
             return out

--
GitLab
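Taken together, the test updates illustrate the migration every caller makes: stop owning a separate Parameter for the table and stop passing it on each call. A condensed before/after sketch (names and sizes are illustrative, not taken from the tests):

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor

    # Before this patch: the caller created the table and threaded it through every call.
    #   table = Parameter(initializer('normal', (16, 4), ms.float32), name='embedding_table')
    #   embedding = nn.EmbeddingLookup()
    #   out = embedding(table, indices)

    # After this patch: the cell owns the table and is called with indices only.
    embedding = nn.EmbeddingLookup(16, 4)
    indices = Tensor(np.array([[0, 1, 2]]), ms.int32)
    out = embedding(indices)

    # Parameter-server and checkpoint code now reaches the table through the cell;
    # set_param_ps() is only meaningful when actually running with the parameter server.
    embedding.embedding_table.set_param_ps()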