diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index b721df626f9e26d3c0c09f559bd4dd4da1144582..3d3f622f1480d4afbb230739221852ab8432d65d 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -25,7 +25,7 @@ from mindspore.parallel._utils import _get_parallel_mode from ..cell import Cell from ..._checkparam import Validator as validator, Rel -__all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode'] +__all__ = ['Embedding', 'EmbeddingLookup'] class Embedding(Cell): r""" @@ -131,7 +131,7 @@ class EmbeddingLookup(Cell): target (str): Specify the target where the op is executed. The value should in ['DEVICE', 'CPU']. Default: 'CPU'. slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through - nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE. + nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE. manual_shapes (tuple): The accompaniment array in field slice mode. Inputs: @@ -147,6 +147,11 @@ class EmbeddingLookup(Cell): >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) >>> out = nn.EmbeddingLookup(4,2)(input_indices) """ + BATCH_SLICE = "batch_slice" + FIELD_SLICE = "field_slice" + TABLE_ROW_SLICE = "table_row_slice" + TABLE_COLUMN_SLICE = "table_column_slice" + def __init__(self, vocab_size, embedding_size, param_init='normal', target='CPU', slice_mode='batch_slice', manual_shapes=None): super(EmbeddingLookup, self).__init__() @@ -160,7 +165,7 @@ class EmbeddingLookup(Cell): name='embedding_table') parallel_mode = _get_parallel_mode() is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) - if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel: + if slice_mode == "field_slice" and is_auto_parallel: if not manual_shapes: raise ValueError("in slice field mode, the manual_shapes should not be none") if not isinstance(manual_shapes, tuple): @@ -171,18 +176,18 @@ class EmbeddingLookup(Cell): self.embeddinglookup.add_prim_attr("manual_split", manual_shapes) self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size()))) self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size()))) - elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel: + elif slice_mode == "table_row_slice" and is_auto_parallel: self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1))) self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1))) - elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel: + elif slice_mode == "table_column_slice" and is_auto_parallel: self.gatherv2.set_strategy(((1, get_group_size()), (1, 1))) self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1))) - elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel: + elif slice_mode == "batch_slice" and is_auto_parallel: self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1))) self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1))) else: if is_auto_parallel: - raise ValueError("slice_mode should support mode in nn.EmbeddingLookUpSplitMode, but get " + raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get " + str(slice_mode)) def construct(self, indices): @@ -191,25 +196,3 @@ class EmbeddingLookup(Cell): else: out = self.gatherv2(self.embedding_table, indices, 0) return out - - -class EmbeddingLookUpSplitMode: - """ - EmbeddingLookUp slice options in auto parallel and semi auto parallel mode. - - There are five kinds of slice options, "BATCH_SLICE", "FIELD_SLICE", - "TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE". - - - BATCH_SLICE: Slicing batch dimensions of indices. - - FIELD_SLICE: Slicing field dimensions of indices. - - TABLE_ROW_SLICE: Slicing row of table. - - TABLE_COLUMN_SLICE: Slicing column of table. - - MODE_LIST: The list for all supported parallel modes. - """ - - BATCH_SLICE = "batch_slice" - FIELD_SLICE = "field_slice" - TABLE_ROW_SLICE = "table_row_slice" - TABLE_COLUMN_SLICE = "table_column_slice" - MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE] diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index e579b26e1c8e105d919c0c02cdb7f62c5f53cf6d..68cdcb869548c73c1f29c3a48457d694b4cfde96 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -202,9 +202,9 @@ class WideDeepModel(nn.Cell): self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),)) self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1))) self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, - slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE) + slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, - slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE) + slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE) self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1))) self.deep_reshape.add_prim_attr("skip_redistribution", True) self.reduce_sum.add_prim_attr("cross_batch", True) @@ -212,10 +212,10 @@ class WideDeepModel(nn.Cell): elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape: manual_shapes = tuple((s[0] for s in config.manual_shape)) self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, - slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE, + slice_mode=nn.EmbeddingLookup.FIELD_SLICE, manual_shapes=manual_shapes) self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, - slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE, + slice_mode=nn.EmbeddingLookup.FIELD_SLICE, manual_shapes=manual_shapes) self.deep_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1))) self.wide_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))