提交 0251de84 编写于 作者: Y yao_yf

EmbeddingLookupSplitMode modify

上级 174de814
...@@ -25,7 +25,7 @@ from mindspore.parallel._utils import _get_parallel_mode ...@@ -25,7 +25,7 @@ from mindspore.parallel._utils import _get_parallel_mode
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Validator as validator, Rel from ..._checkparam import Validator as validator, Rel
__all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode'] __all__ = ['Embedding', 'EmbeddingLookup']
class Embedding(Cell): class Embedding(Cell):
r""" r"""
...@@ -131,7 +131,7 @@ class EmbeddingLookup(Cell): ...@@ -131,7 +131,7 @@ class EmbeddingLookup(Cell):
target (str): Specify the target where the op is executed. The value should be in target (str): Specify the target where the op is executed. The value should be in
['DEVICE', 'CPU']. Default: 'CPU'. ['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through
nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE. nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
manual_shapes (tuple): The accompaniment array in field slice mode. manual_shapes (tuple): The accompaniment array in field slice mode.
Inputs: Inputs:
...@@ -147,6 +147,11 @@ class EmbeddingLookup(Cell): ...@@ -147,6 +147,11 @@ class EmbeddingLookup(Cell):
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
>>> out = nn.EmbeddingLookup(4,2)(input_indices) >>> out = nn.EmbeddingLookup(4,2)(input_indices)
""" """
BATCH_SLICE = "batch_slice"
FIELD_SLICE = "field_slice"
TABLE_ROW_SLICE = "table_row_slice"
TABLE_COLUMN_SLICE = "table_column_slice"
def __init__(self, vocab_size, embedding_size, param_init='normal', def __init__(self, vocab_size, embedding_size, param_init='normal',
target='CPU', slice_mode='batch_slice', manual_shapes=None): target='CPU', slice_mode='batch_slice', manual_shapes=None):
super(EmbeddingLookup, self).__init__() super(EmbeddingLookup, self).__init__()
...@@ -160,7 +165,7 @@ class EmbeddingLookup(Cell): ...@@ -160,7 +165,7 @@ class EmbeddingLookup(Cell):
name='embedding_table') name='embedding_table')
parallel_mode = _get_parallel_mode() parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel: if slice_mode == "field_slice" and is_auto_parallel:
if not manual_shapes: if not manual_shapes:
raise ValueError("in slice field mode, the manual_shapes should not be none") raise ValueError("in slice field mode, the manual_shapes should not be none")
if not isinstance(manual_shapes, tuple): if not isinstance(manual_shapes, tuple):
...@@ -171,18 +176,18 @@ class EmbeddingLookup(Cell): ...@@ -171,18 +176,18 @@ class EmbeddingLookup(Cell):
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes) self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size()))) self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size())))
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size()))) self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size())))
elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel: elif slice_mode == "table_row_slice" and is_auto_parallel:
self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1))) self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1)))
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1))) self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel: elif slice_mode == "table_column_slice" and is_auto_parallel:
self.gatherv2.set_strategy(((1, get_group_size()), (1, 1))) self.gatherv2.set_strategy(((1, get_group_size()), (1, 1)))
self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1))) self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel: elif slice_mode == "batch_slice" and is_auto_parallel:
self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1))) self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1)))
self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1))) self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1)))
else: else:
if is_auto_parallel: if is_auto_parallel:
raise ValueError("slice_mode should support mode in nn.EmbeddingLookUpSplitMode, but get " raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
+ str(slice_mode)) + str(slice_mode))
def construct(self, indices): def construct(self, indices):
...@@ -191,25 +196,3 @@ class EmbeddingLookup(Cell): ...@@ -191,25 +196,3 @@ class EmbeddingLookup(Cell):
else: else:
out = self.gatherv2(self.embedding_table, indices, 0) out = self.gatherv2(self.embedding_table, indices, 0)
return out return out
class EmbeddingLookUpSplitMode:
    """
    EmbeddingLookUp slice options in auto parallel and semi auto parallel mode.

    There are four kinds of slice options, "BATCH_SLICE", "FIELD_SLICE",
    "TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE".

        - BATCH_SLICE: Slicing batch dimensions of indices.
        - FIELD_SLICE: Slicing field dimensions of indices.
        - TABLE_ROW_SLICE: Slicing row of table.
        - TABLE_COLUMN_SLICE: Slicing column of table.

    MODE_LIST: The list of all supported slice modes.
    """

    # String values that callers pass as the `slice_mode` argument.
    BATCH_SLICE = "batch_slice"
    FIELD_SLICE = "field_slice"
    TABLE_ROW_SLICE = "table_row_slice"
    TABLE_COLUMN_SLICE = "table_column_slice"
    # Every slice mode defined above, in declaration order.
    MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE]
...@@ -202,9 +202,9 @@ class WideDeepModel(nn.Cell): ...@@ -202,9 +202,9 @@ class WideDeepModel(nn.Cell):
self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),)) self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1))) self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE) slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE) slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1))) self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
self.deep_reshape.add_prim_attr("skip_redistribution", True) self.deep_reshape.add_prim_attr("skip_redistribution", True)
self.reduce_sum.add_prim_attr("cross_batch", True) self.reduce_sum.add_prim_attr("cross_batch", True)
...@@ -212,10 +212,10 @@ class WideDeepModel(nn.Cell): ...@@ -212,10 +212,10 @@ class WideDeepModel(nn.Cell):
elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape: elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape:
manual_shapes = tuple((s[0] for s in config.manual_shape)) manual_shapes = tuple((s[0] for s in config.manual_shape))
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE, slice_mode=nn.EmbeddingLookup.FIELD_SLICE,
manual_shapes=manual_shapes) manual_shapes=manual_shapes)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE, slice_mode=nn.EmbeddingLookup.FIELD_SLICE,
manual_shapes=manual_shapes) manual_shapes=manual_shapes)
self.deep_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1))) self.deep_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
self.wide_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1))) self.wide_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册