Commit e4de26d5
Authored Aug 04, 2020 by yao_yf

embeddinglookup wrap

Parent: 5adba834
Changes: 5 changed files with 98 additions and 42 deletions (+98 -42)
mindspore/nn/layer/embedding.py  (+75 -14)
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py  (+13 -10)
tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py  (+4 -8)
tests/ut/python/ir/test_row_tensor.py  (+4 -6)
tests/ut/python/parallel/test_sparse_feature_bprop.py  (+2 -4)
mindspore/nn/layer/embedding.py

@@ -18,10 +18,14 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore._checkparam import Validator
+from mindspore.communication.management import get_group_size
+from mindspore.train.parallel_utils import ParallelMode
+from mindspore.parallel._utils import _get_parallel_mode
 from ..cell import Cell
-from ..._checkparam import Validator as validator
+from ..._checkparam import Validator as validator, Rel
 
-__all__ = ['Embedding', 'EmbeddingLookup']
+__all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode']
 
 
 class Embedding(Cell):
     r"""
@@ -114,29 +118,36 @@ class EmbeddingLookup(Cell):
         When 'target' is set to 'CPU', this module will use
         P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
         specified 'offset = 0' to lookup table.
-        when 'target' is set to 'DEVICE', this module will use P.GatherV2() which
+        When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
         specified 'axis = 0' to lookup table.
+        In field slice mode, manual_shapes must be given. It is a tuple whose elements are
+        (vocab[i], offset[i]), where vocab[i] is the number of rows of the i-th part and
+        offset[i] is the feature id offset of the i-th part. The feature ids of the i-th
+        part will be subtracted by offset[i] so that the ids start from 0.
 
     Args:
+        vocab_size (int): Size of the dictionary of embeddings.
+        embedding_size (int): The size of each embedding vector.
+        param_init (str): The initialization method for the embedding table. Default: 'normal'.
         target (str): Specify the target where the op is executed. Default: 'CPU'.
+        slice_mode (str): The slicing scheme in semi_auto_parallel/auto_parallel mode. Default: 'batch_slice'.
+        manual_shapes (tuple): The accompaniment array in field slice mode. Default: None.
 
     Inputs:
-        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
-          The Tensor slice, instead of the entire Tensor.
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
-          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
-          and the exceeding part will be filled with 0 in the output.
+          Specifies the indices of elements of the original Tensor. Values can be out of range of
+          embedding_table, and the exceeding part will be filled with 0 in the output. input_indices
+          must be a 2d tensor in this interface.
 
     Outputs:
         Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
 
     Examples:
-        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
         >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
-        >>> out = nn.EmbeddingLookup()(input_params, input_indices)
-        [[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
+        >>> out = nn.EmbeddingLookup(4, 2)(input_indices)
     """
-    def __init__(self, target='CPU'):
+    def __init__(self, vocab_size, embedding_size, param_init='normal', target='CPU',
+                 slice_mode='batch_slice', manual_shapes=None):
         super(EmbeddingLookup, self).__init__()
         self.target = target
         if target not in ('CPU', 'DEVICE'):
@@ -144,10 +155,60 @@ class EmbeddingLookup(Cell):
                              + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
         self.gatherv2 = P.GatherV2()
         self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+        self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
+                                         name='embedding_table')
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel:
+            if not manual_shapes:
+                raise ValueError("in field slice mode, manual_shapes should not be none")
+            if not isinstance(manual_shapes, tuple):
+                raise TypeError("manual_shapes type must be tuple(int), cannot be {}!".format(type(manual_shapes)))
+            for dim in manual_shapes:
+                Validator.check_integer('manual shape dim', dim, 0, Rel.GT, self.cls_name)
+            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
+            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
+            self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size())))
+            self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size())))
+        elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1)))
+            self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
+        elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((1, get_group_size()), (1, 1)))
+            self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
+        elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1)))
+            self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1)))
+        else:
+            if is_auto_parallel:
+                raise ValueError("slice_mode should be one of the modes in nn.EmbeddingLookUpSplitMode, but got "
+                                 + str(slice_mode))
 
-    def construct(self, params, indices):
+    def construct(self, indices):
         if self.target == "CPU":
-            out = self.embeddinglookup(params, indices, 0)
+            out = self.embeddinglookup(self.embedding_table, indices, 0)
         else:
-            out = self.gatherv2(params, indices, 0)
+            out = self.gatherv2(self.embedding_table, indices, 0)
         return out
+
+
+class EmbeddingLookUpSplitMode:
+    """
+    EmbeddingLookUp slice options in auto parallel and semi auto parallel mode.
+
+    There are four kinds of slice options: "BATCH_SLICE", "FIELD_SLICE",
+    "TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE".
+
+    - BATCH_SLICE: slices the batch dimension of the indices.
+    - FIELD_SLICE: slices the field dimension of the indices.
+    - TABLE_ROW_SLICE: slices the rows of the embedding table.
+    - TABLE_COLUMN_SLICE: slices the columns of the embedding table.
+
+    MODE_LIST: the list of all supported slice modes.
+    """
+    BATCH_SLICE = "batch_slice"
+    FIELD_SLICE = "field_slice"
+    TABLE_ROW_SLICE = "table_row_slice"
+    TABLE_COLUMN_SLICE = "table_column_slice"
+    MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE]
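Taken together, these changes turn nn.EmbeddingLookup from a stateless wrapper that expected the caller to pass the table into a cell that owns an embedding_table Parameter and can configure its own parallel strategy. A minimal usage sketch of the new interface, following the docstring example above (assuming a MindSpore build at this revision; the 4x2 sizes come from that example):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# Before: out = nn.EmbeddingLookup()(input_params, input_indices)
# After:  the cell builds its own [vocab_size, embedding_size] table from
#         param_init, and construct() takes only the indices.
input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
lookup = nn.EmbeddingLookup(4, 2)   # vocab_size=4, embedding_size=2
out = lookup(input_indices)         # one 2-element embedding per index
# Under semi/auto parallel, a slice mode can be requested instead of the
# default 'batch_slice', e.g.:
#   nn.EmbeddingLookup(4, 2, slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE)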
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py

@@ -209,19 +209,22 @@ class WideDeepModel(nn.Cell):
         if is_auto_parallel and host_device_mix:
             self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
-            self.deep_embeddinglookup = nn.EmbeddingLookup()
-            self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
-            self.wide_embeddinglookup = nn.EmbeddingLookup()
-            self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
+            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
+                                                           slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
+            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
+                                                           slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE)
             self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
             self.deep_reshape.add_prim_attr("skip_redistribution", True)
             self.reduce_sum.add_prim_attr("cross_batch", True)
+            self.embedding_table = self.deep_embeddinglookup.embedding_table
         elif parameter_server:
-            self.deep_embeddinglookup = nn.EmbeddingLookup()
-            self.wide_embeddinglookup = nn.EmbeddingLookup()
+            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
+            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
+            self.embedding_table = self.deep_embeddinglookup.embedding_table
         else:
-            self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
-            self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
+            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
+            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
+            self.embedding_table = self.deep_embeddinglookup.embedding_table
 
     def construct(self, id_hldr, wt_hldr):
         """

@@ -231,11 +234,11 @@ class WideDeepModel(nn.Cell):
         """
         mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
         # Wide layer
-        wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr)
+        wide_id_weight = self.wide_embeddinglookup(id_hldr)
         wx = self.wide_mul(wide_id_weight, mask)
         wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
         # Deep layer
-        deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr)
+        deep_id_embs = self.deep_embeddinglookup(id_hldr)
         vx = self.deep_mul(deep_id_embs, mask)
         deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
         deep_in = self.dense_layer_1(deep_in)
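The reworked construct() is easiest to sanity-check as a shape trace. A numpy sketch of the deep path (batch_size, field_size and emb_dim are the model's real attribute names, but the sizes below are illustrative stand-ins, not values from the commit):

import numpy as np

batch_size, field_size, emb_dim = 2, 3, 4                  # illustrative only
id_hldr = np.zeros((batch_size, field_size), np.int32)     # feature ids
wt_hldr = np.ones((batch_size, field_size), np.float32)    # feature weights

mask = wt_hldr.reshape((batch_size, field_size, 1))
deep_id_embs = np.ones((batch_size, field_size, emb_dim))  # lookup result shape
vx = deep_id_embs * mask                                   # broadcast over emb_dim
deep_in = vx.reshape((-1, field_size * emb_dim))           # (2, 12) into dense_layer_1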
tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py

@@ -24,8 +24,7 @@ from mindspore.common import dtype as mstype
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam
 from mindspore.ops import operations as P
-from mindspore.common.initializer import TruncatedNormal, initializer
-from mindspore import Parameter
+from mindspore.common.initializer import TruncatedNormal
 
 parser = argparse.ArgumentParser(description="test_sparse_embedding")
 parser.add_argument("--device_target", type=str, default="Ascend")

@@ -53,16 +52,13 @@ class LeNet5(nn.Cell):
         super(LeNet5, self).__init__()
         self.cast = P.Cast()
         self.flatten = nn.Flatten()
-        self.embedding_table = Parameter(initializer("normal", (16, 4), mstype.float32),
-                                         name="embedding_table")
-        self.embedding = nn.EmbeddingLookup()
+        self.embedding = nn.EmbeddingLookup(16, 4)
         self.relu = nn.ReLU()
         self.fc = fc_with_initialize(12, num_class)
 
     def construct(self, x):
         x = self.cast(x, mstype.int32)
-        x = self.embedding(self.embedding_table, x)
+        x = self.embedding(x)
         x = self.flatten(x)
         x = self.fc(x)
         return x

@@ -72,7 +68,7 @@ def do_sparse_embedding(ps=False):
     epoch = 10
     net = LeNet5(10)
     if ps:
-        net.embedding_table.set_param_ps()
+        net.embedding.embedding_table.set_param_ps()
     optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
     optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
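The test change mirrors where the table now lives: the parameter-server flag moves onto the Parameter created inside the cell instead of one defined by the test. A before/after sketch using only names from this test:

# Before the commit, the test owned the table:
#     net.embedding_table.set_param_ps()
# After it, the table belongs to the EmbeddingLookup cell:
net = LeNet5(10)
net.embedding.embedding_table.set_param_ps()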
tests/ut/python/ir/test_row_tensor.py

@@ -421,17 +421,16 @@ def test_row_tensor_with_control_flow_if():
 
 class EmbeddingLookUpBnNet(nn.Cell):
-    def __init__(self, param_np, target='CPU'):
+    def __init__(self, vocab_size, embedding_size, target='CPU'):
         super().__init__()
-        self.param = Parameter(Tensor(param_np), name="w1")
-        self.embedding_lookup = nn.EmbeddingLookup(target=target)
+        self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target)
         self.bn = nn.BatchNorm2d(num_features=3)
         self.mul = P.Mul()
         self.reshape = P.Reshape()
         self.relu = nn.PReLU()
 
     def construct(self, indices):
-        x = self.embedding_lookup(self.param, indices)
+        x = self.embedding_lookup(indices)
         x = self.reshape(x, (2, 3, 2, 2))
         x = self.relu(x)
         x = self.bn(x)

@@ -439,10 +438,9 @@ class EmbeddingLookUpBnNet(nn.Cell):
 
 def test_embedding_lookup_with_mix_precision():
-    param_np = np.ones([8, 8]).astype(np.float32)
     data = Tensor(np.array([0, 1, 2]).astype(np.int32))
     label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32))
-    net = EmbeddingLookUpBnNet(param_np, target='CPU')
+    net = EmbeddingLookUpBnNet(8, 8, target='CPU')
     criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
     optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
tests/ut/python/parallel/test_sparse_feature_bprop.py

@@ -69,14 +69,12 @@ def test_bprop_with_sparse_feature_mirror():
             super(Net, self).__init__()
             if shape is None:
                 shape = [8, 8]
-            weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
-            self.weight = Parameter(weight, "w")
             self.index = Tensor(np.ones(shape), dtype=ms.int32)
-            self.embeddinglookup = nn.EmbeddingLookup()
+            self.embeddinglookup = nn.EmbeddingLookup(64, 64, param_init='ones')
             self.embeddinglookup.embeddinglookup.set_strategy(((1, 1), (8, 1)))
 
         def construct(self, x, b):
-            out = self.embeddinglookup(self.weight, self.index)
+            out = self.embeddinglookup(self.index)
             return out
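One detail this parallel test relies on: the wrapped cell still exposes its inner primitive, so a strategy can be pinned manually on top of whatever slice_mode configures. A sketch grounded in the test above (the ((1, 1), (8, 1)) strategy is taken from it; the 8-way device group it implies is an assumption about the test harness):

import mindspore.nn as nn

# The cell now owns a 64x64 table initialized to ones ...
lookup = nn.EmbeddingLookup(64, 64, param_init='ones')
# ... and its inner P.EmbeddingLookup primitive stays reachable, so the test
# can override the strategy that slice_mode would otherwise pick:
lookup.embeddinglookup.set_strategy(((1, 1), (8, 1)))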