Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0251de84
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0251de84
编写于
9月 09, 2020
作者:
Y
yao_yf
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
EmbeddingLookupSplitMode modify
上级
174de814
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
16 addition
and
33 deletion
+16
-33
mindspore/nn/layer/embedding.py
mindspore/nn/layer/embedding.py
+12
-29
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
...zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
+4
-4
未找到文件。
mindspore/nn/layer/embedding.py
浏览文件 @
0251de84
...
...
@@ -25,7 +25,7 @@ from mindspore.parallel._utils import _get_parallel_mode
from
..cell
import
Cell
from
..._checkparam
import
Validator
as
validator
,
Rel
__all__
=
[
'Embedding'
,
'EmbeddingLookup'
,
'EmbeddingLookUpSplitMode'
]
__all__
=
[
'Embedding'
,
'EmbeddingLookup'
]
class
Embedding
(
Cell
):
r
"""
...
...
@@ -131,7 +131,7 @@ class EmbeddingLookup(Cell):
target (str): Specify the target where the op is executed. The value should in
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through
nn.EmbeddingLook
UpSplitMode. Default: nn.EmbeddingLookUpSplitMode
.BATCH_SLICE.
nn.EmbeddingLook
up. Default: nn.EmbeddingLookup
.BATCH_SLICE.
manual_shapes (tuple): The accompaniment array in field slice mode.
Inputs:
...
...
@@ -147,6 +147,11 @@ class EmbeddingLookup(Cell):
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
>>> out = nn.EmbeddingLookup(4,2)(input_indices)
"""
BATCH_SLICE
=
"batch_slice"
FIELD_SLICE
=
"field_slice"
TABLE_ROW_SLICE
=
"table_row_slice"
TABLE_COLUMN_SLICE
=
"table_column_slice"
def
__init__
(
self
,
vocab_size
,
embedding_size
,
param_init
=
'normal'
,
target
=
'CPU'
,
slice_mode
=
'batch_slice'
,
manual_shapes
=
None
):
super
(
EmbeddingLookup
,
self
).
__init__
()
...
...
@@ -160,7 +165,7 @@ class EmbeddingLookup(Cell):
name
=
'embedding_table'
)
parallel_mode
=
_get_parallel_mode
()
is_auto_parallel
=
parallel_mode
in
(
ParallelMode
.
SEMI_AUTO_PARALLEL
,
ParallelMode
.
AUTO_PARALLEL
)
if
slice_mode
==
EmbeddingLookUpSplitMode
.
FIELD_SLICE
and
is_auto_parallel
:
if
slice_mode
==
"field_slice"
and
is_auto_parallel
:
if
not
manual_shapes
:
raise
ValueError
(
"in slice field mode, the manual_shapes should not be none"
)
if
not
isinstance
(
manual_shapes
,
tuple
):
...
...
@@ -171,18 +176,18 @@ class EmbeddingLookup(Cell):
self
.
embeddinglookup
.
add_prim_attr
(
"manual_split"
,
manual_shapes
)
self
.
gatherv2
.
set_strategy
(((
get_group_size
(),
1
),
(
1
,
get_group_size
())))
self
.
embeddinglookup
.
set_strategy
(((
get_group_size
(),
1
),
(
1
,
get_group_size
())))
elif
slice_mode
==
EmbeddingLookUpSplitMode
.
TABLE_ROW_SLICE
and
is_auto_parallel
:
elif
slice_mode
==
"table_row_slice"
and
is_auto_parallel
:
self
.
gatherv2
.
set_strategy
(((
get_group_size
(),
1
),
(
1
,
1
)))
self
.
embeddinglookup
.
set_strategy
(((
get_group_size
(),
1
),
(
1
,
1
)))
elif
slice_mode
==
EmbeddingLookUpSplitMode
.
TABLE_COLUMN_SLICE
and
is_auto_parallel
:
elif
slice_mode
==
"table_column_slice"
and
is_auto_parallel
:
self
.
gatherv2
.
set_strategy
(((
1
,
get_group_size
()),
(
1
,
1
)))
self
.
embeddinglookup
.
set_strategy
(((
1
,
get_group_size
()),
(
1
,
1
)))
elif
slice_mode
==
EmbeddingLookUpSplitMode
.
BATCH_SLICE
and
is_auto_parallel
:
elif
slice_mode
==
"batch_slice"
and
is_auto_parallel
:
self
.
gatherv2
.
set_strategy
(((
1
,
1
),
(
get_group_size
(),
1
)))
self
.
embeddinglookup
.
set_strategy
(((
1
,
1
),
(
get_group_size
(),
1
)))
else
:
if
is_auto_parallel
:
raise
ValueError
(
"slice_mode should support mode in nn.EmbeddingLook
UpSplitMode
, but get "
raise
ValueError
(
"slice_mode should support mode in nn.EmbeddingLook
up
, but get "
+
str
(
slice_mode
))
def
construct
(
self
,
indices
):
...
...
@@ -191,25 +196,3 @@ class EmbeddingLookup(Cell):
else
:
out
=
self
.
gatherv2
(
self
.
embedding_table
,
indices
,
0
)
return
out
class
EmbeddingLookUpSplitMode
:
"""
EmbeddingLookUp slice options in auto parallel and semi auto parallel mode.
There are five kinds of slice options, "BATCH_SLICE", "FIELD_SLICE",
"TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE".
- BATCH_SLICE: Slicing batch dimensions of indices.
- FIELD_SLICE: Slicing field dimensions of indices.
- TABLE_ROW_SLICE: Slicing row of table.
- TABLE_COLUMN_SLICE: Slicing column of table.
MODE_LIST: The list for all supported parallel modes.
"""
BATCH_SLICE
=
"batch_slice"
FIELD_SLICE
=
"field_slice"
TABLE_ROW_SLICE
=
"table_row_slice"
TABLE_COLUMN_SLICE
=
"table_column_slice"
MODE_LIST
=
[
BATCH_SLICE
,
FIELD_SLICE
,
TABLE_ROW_SLICE
,
TABLE_COLUMN_SLICE
]
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
浏览文件 @
0251de84
...
...
@@ -202,9 +202,9 @@ class WideDeepModel(nn.Cell):
self
.
dense_layer_1
.
dropout
.
dropout
.
set_strategy
(((
1
,
get_group_size
()),))
self
.
dense_layer_1
.
matmul
.
set_strategy
(((
1
,
get_group_size
()),
(
get_group_size
(),
1
)))
self
.
deep_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
self
.
emb_dim
,
slice_mode
=
nn
.
EmbeddingLook
UpSplitMode
.
TABLE_COLUMN_SLICE
)
slice_mode
=
nn
.
EmbeddingLook
up
.
TABLE_COLUMN_SLICE
)
self
.
wide_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
1
,
slice_mode
=
nn
.
EmbeddingLook
UpSplitMode
.
TABLE_ROW_SLICE
)
slice_mode
=
nn
.
EmbeddingLook
up
.
TABLE_ROW_SLICE
)
self
.
deep_mul
.
set_strategy
(((
1
,
1
,
get_group_size
()),
(
1
,
1
,
1
)))
self
.
deep_reshape
.
add_prim_attr
(
"skip_redistribution"
,
True
)
self
.
reduce_sum
.
add_prim_attr
(
"cross_batch"
,
True
)
...
...
@@ -212,10 +212,10 @@ class WideDeepModel(nn.Cell):
elif
is_auto_parallel
and
host_device_mix
and
is_field_slice
and
config
.
full_batch
and
config
.
manual_shape
:
manual_shapes
=
tuple
((
s
[
0
]
for
s
in
config
.
manual_shape
))
self
.
deep_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
self
.
emb_dim
,
slice_mode
=
nn
.
EmbeddingLook
UpSplitMode
.
FIELD_SLICE
,
slice_mode
=
nn
.
EmbeddingLook
up
.
FIELD_SLICE
,
manual_shapes
=
manual_shapes
)
self
.
wide_embeddinglookup
=
nn
.
EmbeddingLookup
(
self
.
vocab_size
,
1
,
slice_mode
=
nn
.
EmbeddingLook
UpSplitMode
.
FIELD_SLICE
,
slice_mode
=
nn
.
EmbeddingLook
up
.
FIELD_SLICE
,
manual_shapes
=
manual_shapes
)
self
.
deep_mul
.
set_strategy
(((
1
,
get_group_size
(),
1
),
(
1
,
get_group_size
(),
1
)))
self
.
wide_mul
.
set_strategy
(((
1
,
get_group_size
(),
1
),
(
1
,
get_group_size
(),
1
)))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录