Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
20d2012a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
20d2012a
编写于
6月 03, 2020
作者:
X
Xiaoda Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implementing the backward of embeddinglookup
上级
e33ecf31
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
43 addition
and
28 deletion
+43
-28
mindspore/ops/_grad/grad_array_ops.py
mindspore/ops/_grad/grad_array_ops.py
+25
-0
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+8
-15
tests/ut/python/parallel/test_embeddinglookup.py
tests/ut/python/parallel/test_embeddinglookup.py
+4
-7
tests/ut/python/parallel/test_gather_v2.py
tests/ut/python/parallel/test_gather_v2.py
+3
-3
tests/ut/python/parallel/test_sparse_gather_v2.py
tests/ut/python/parallel/test_sparse_gather_v2.py
+3
-3
未找到文件。
mindspore/ops/_grad/grad_array_ops.py
浏览文件 @
20d2012a
...
...
@@ -190,6 +190,31 @@ def get_bprop_tile(self):
return
bprop
@
bprop_getters
.
register
(
P
.
EmbeddingLookup
)
def
get_bprop_embedding_lookup
(
self
):
"""Generate bprop for EmbeddingLookup"""
host_sub
=
P
.
Sub
().
add_prim_attr
(
'primitive_target'
,
'CPU'
)
host_reshape
=
P
.
Reshape
().
add_prim_attr
(
'primitive_target'
,
'CPU'
)
def
bprop_sparse
(
x
,
indices
,
offset
,
reduce_scatter_flag
,
split_num
,
out
,
dout
):
x_shp
=
shape_op
(
x
)
if
reduce_scatter_flag
is
True
:
elu_grad
=
G
.
EmbeddingLookupCommGrad
()
actual_dout
=
elu_grad
(
dout
,
split_num
)
else
:
actual_dout
=
dout
new_indices
=
host_sub
(
indices
-
offset
)
# Reshape the 'new_indices'
new_indices_shape_changed
=
(
size_op
(
new_indices
),)
new_indices
=
host_reshape
(
new_indices
,
new_indices_shape_changed
)
# Reshape the 'actual_dout'
x_shp_tail
=
x_shp
[
1
:]
actual_dout_shape_changed
=
new_indices_shape_changed
+
x_shp_tail
actual_dout
=
host_reshape
(
actual_dout
,
actual_dout_shape_changed
)
return
(
new_indices
,
actual_dout
,
x_shp
),
zeros_like
(
new_indices
),
zeros_like
(
axis
),
\
zeros_like
(
reduce_scatter_flag
),
zeros_like
(
split_num
)
return
bprop_sparse
@
bprop_getters
.
register
(
P
.
Transpose
)
def
get_bprop_transpose
(
self
):
"""Generate bprop for Transpose"""
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
20d2012a
...
...
@@ -616,9 +616,10 @@ class Range(PrimitiveWithInfer):
class
EmbeddingLookup
(
PrimitiveWithInfer
):
"""
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`.
This primitive runs on the host instead of devices.
Returns a slice of input tensor based on the specified indices.
This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs:
`offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.
Inputs:
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
...
...
@@ -626,7 +627,6 @@ class EmbeddingLookup(PrimitiveWithInfer):
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
and the exceeding part will be filled with 0 in the output.
- **axis** (int) - Specifies the dimension index to gather indices.
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
are equal to `input_indices` minus `offset`.
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
...
...
@@ -641,36 +641,29 @@ class EmbeddingLookup(PrimitiveWithInfer):
Examples:
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
>>> axis = 0
>>> offset = 4
>>> reduce_scatter_flag = False
>>> split_num = 1
>>> out = P.EmbeddingLookup()(input_params, input_indices,
axis,
offset, reduce_scatter_flag, split_num)
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
"""
@
prim_attr_register
def
__init__
(
self
):
"""init index_select"""
self
.
__setattr_flag__
=
True
self
.
init_prim_io_names
(
inputs
=
[
'params'
,
'indices'
,
'
axis'
,
'
offset'
,
'reduce_scatter_flag'
,
'split_num'
],
self
.
init_prim_io_names
(
inputs
=
[
'params'
,
'indices'
,
'offset'
,
'reduce_scatter_flag'
,
'split_num'
],
outputs
=
[
'output'
])
self
.
add_prim_attr
(
'primitive_target'
,
'CPU'
)
def
__infer__
(
self
,
params
,
indices
,
axis
,
offset
,
reduce_scatter_flag
=
False
,
split_num
=
2
):
def
__infer__
(
self
,
params
,
indices
,
offset
,
reduce_scatter_flag
=
False
,
split_num
=
2
):
validator
.
check_subclass
(
"params"
,
params
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
validator
.
check_tensor_type_same
({
"indices"
:
indices
[
'dtype'
]},
mstype
.
int_type
,
self
.
name
)
validator
.
check_subclass
(
"axis"
,
axis
[
'dtype'
],
mstype
.
int_
,
self
.
name
)
validator
.
check_subclass
(
"offset"
,
offset
[
'dtype'
],
mstype
.
int_
,
self
.
name
)
validator
.
check_subclass
(
"split_num"
,
split_num
[
'dtype'
],
mstype
.
int_
,
self
.
name
)
if
split_num
[
'value'
]
<
1
:
raise
ValueError
(
"The parameter 'split_num' must be positive, but got %d."
%
split_num
)
axis_v
=
axis
[
'value'
]
params_shp
=
params
[
'shape'
]
rank
=
len
(
params_shp
)
validator
.
check_int_range
(
"axis"
,
axis_v
,
-
rank
,
rank
,
Rel
.
INC_LEFT
,
self
.
name
)
if
axis_v
<
0
:
axis_v
+=
rank
out_shape
=
params_shp
[:
axis_v
]
+
indices
[
'shape'
]
+
params_shp
[
axis_v
+
1
:]
out_shape
=
indices
[
'shape'
]
+
params_shp
[
1
:]
if
reduce_scatter_flag
is
None
:
raise
ValueError
(
"The value of 'reduce_scatter_flag' is None."
)
reduce_scatter_flag_value
=
reduce_scatter_flag
[
'value'
]
...
...
tests/ut/python/parallel/test_embeddinglookup.py
浏览文件 @
20d2012a
...
...
@@ -33,10 +33,9 @@ class NetWithLoss(nn.Cell):
return
self
.
loss
(
predict
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
shape
,
axis
,
offset
,
reduce_scatter_flag
,
split_num
):
def
__init__
(
self
,
shape
,
offset
,
reduce_scatter_flag
,
split_num
):
super
().
__init__
()
self
.
index
=
Tensor
(
np
.
ones
(
shape
),
dtype
=
ms
.
int32
)
self
.
axis
=
axis
self
.
offset
=
offset
self
.
reduce_scatter_flag
=
reduce_scatter_flag
self
.
split_num
=
split_num
...
...
@@ -44,18 +43,17 @@ class Net(nn.Cell):
self
.
mm
=
P
.
BatchMatMul
()
def
construct
(
self
,
x
,
y
):
out
=
self
.
elu
(
x
,
self
.
index
,
self
.
axis
,
self
.
offset
,
self
.
reduce_scatter_flag
,
self
.
split_num
)
out
=
self
.
elu
(
x
,
self
.
index
,
self
.
offset
,
self
.
reduce_scatter_flag
,
self
.
split_num
)
out
=
self
.
mm
(
out
,
y
)
return
out
def
test_embeddinglookup_reducescatter_false
():
shape
=
[
8
,
8
]
axis
=
0
offset
=
8
reduce_scatter_flag
=
False
split_num
=
1
net
=
NetWithLoss
(
Net
(
shape
,
axis
,
offset
,
reduce_scatter_flag
,
split_num
))
net
=
NetWithLoss
(
Net
(
shape
,
offset
,
reduce_scatter_flag
,
split_num
))
net
.
set_auto_parallel
()
x
=
Tensor
(
np
.
ones
([
64
,
32
]),
dtype
=
ms
.
float32
)
...
...
@@ -65,11 +63,10 @@ def test_embeddinglookup_reducescatter_false():
def
test_embeddinglookup_reducescatter_true
():
shape
=
[
64
,
8
]
axis
=
0
offset
=
8
reduce_scatter_flag
=
True
split_num
=
8
net
=
NetWithLoss
(
Net
(
shape
,
axis
,
offset
,
reduce_scatter_flag
,
split_num
))
net
=
NetWithLoss
(
Net
(
shape
,
offset
,
reduce_scatter_flag
,
split_num
))
net
.
set_auto_parallel
()
x
=
Tensor
(
np
.
ones
([
64
,
32
]),
dtype
=
ms
.
float32
)
...
...
tests/ut/python/parallel/test_gather_v2.py
浏览文件 @
20d2012a
...
...
@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
_executor
.
compile
(
net
,
x
,
y
)
def
test_gatherv2_cpu0
():
def
need_fix_
test_gatherv2_cpu0
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
8
,
1
),
(
1
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
...
...
@@ -196,7 +196,7 @@ def test_gatherv2_cpu0():
_executor
.
compile
(
net
,
x
,
y
)
def
test_gatherv2_cpu1
():
def
need_fix_
test_gatherv2_cpu1
():
context
.
set_auto_parallel_context
(
device_num
=
16
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
16
,
1
),
(
1
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
...
...
@@ -208,7 +208,7 @@ def test_gatherv2_cpu1():
_executor
.
compile
(
net
,
x
,
y
)
def
test_gatherv2_cpu2
():
def
need_fix_
test_gatherv2_cpu2
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
1
,
8
),
(
1
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
...
...
tests/ut/python/parallel/test_sparse_gather_v2.py
浏览文件 @
20d2012a
...
...
@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
_executor
.
compile
(
net
,
x
,
y
)
def
test_gatherv2_cpu0
():
def
need_fix_
test_gatherv2_cpu0
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
8
,
1
),
(
1
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
...
...
@@ -196,7 +196,7 @@ def test_gatherv2_cpu0():
_executor
.
compile
(
net
,
x
,
y
)
def
test_gatherv2_cpu1
():
def
need_fix_
test_gatherv2_cpu1
():
context
.
set_auto_parallel_context
(
device_num
=
16
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
16
,
1
),
(
1
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
...
...
@@ -208,7 +208,7 @@ def test_gatherv2_cpu1():
_executor
.
compile
(
net
,
x
,
y
)
def
test_gatherv2_cpu2
():
def
need_fix_
test_gatherv2_cpu2
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
1
,
8
),
(
1
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录