magicwindyyd / mindspore · commit b9ba99bb (forked from MindSpore / mindspore; in sync with the fork source)
Commit b9ba99bb
Authored on Jun 01, 2020 by mindspore-ci-bot; committed via Gitee on Jun 01, 2020

!1685 [Auto parallel] Fix the bugs in Embeddinglookup forward operator

Merge pull request !1685 from Xiaoda/fix-the-embeddinglookup-bug

Parents: b250b087, 55ef468a
2 changed files with 21 additions and 11 deletions (+21 / -11):

mindspore/ops/operations/array_ops.py (+19 / -9)
tests/ut/python/parallel/test_embeddinglookup.py (+2 / -2)
mindspore/ops/operations/array_ops.py
@@ -576,19 +576,21 @@ class EmbeddingLookup(PrimitiveWithInfer):
     """
     Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
     functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`.
     This primitive runs on the host instead of devices.

     Inputs:
         - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
           The Tensor slice, instead of the entire Tensor.
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
-          Specifies the indices of elements of the original Tensor. Must be in the range
-          `[0, input_param.shape()[axis])`.
+          Specifies the indices of elements of the original Tensor. Values can be out of range of
+          `input_params`, and the exceeding part will be filled with 0 in the output.
         - **axis** (int) - Specifies the dimension index to gather indices.
         - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
           are equal to `input_indices` minus `offset`.
         - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
+          Only constant value is allowed.
         - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
-          is used only if `reduce_scatter_flag` is True.
+          is used only if `reduce_scatter_flag` is True.
+          Only constant value is allowed.
Outputs:
...
...
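The updated docstring above says the real indices are `input_indices` minus `offset`, and that indices falling outside this host's slice produce zero-filled rows in the output. A minimal NumPy sketch of those documented semantics (the helper name and the NumPy modeling are illustrative, not the MindSpore implementation):

```python
import numpy as np

def embedding_lookup_host(params_slice, indices, offset):
    """Illustrative model: shift indices by `offset`, gather from this slice,
    and fill rows whose shifted index is out of range with zeros."""
    real = indices - offset                                   # real indices into this slice
    valid = (real >= 0) & (real < params_slice.shape[0])
    out = np.zeros(indices.shape + params_slice.shape[1:], params_slice.dtype)
    out[valid] = params_slice[real[valid]]                    # invalid rows stay zero
    return out

table = np.arange(12, dtype=np.float32).reshape(6, 2)  # this host's slice of the table
idx = np.array([1, 3, 20])                             # 1 - 2 and 20 - 2 fall outside the slice
print(embedding_lookup_host(table, idx, offset=2))     # [[0. 0.] [2. 3.] [0. 0.]]
```

Only the middle index survives: `3 - 2 = 1` lands inside the slice, while the other two rows come back as zeros rather than raising, which is exactly the behavior change the docstring diff records.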
@@ -627,12 +629,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
         if axis_v < 0:
             axis_v += rank
         out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
-        if reduce_scatter_flag:
-            # partition the tensor along the dimension 0.
-            if out_shape[0] % split_num['value'] != 0:
-                raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d."
-                                 % (out_shape[0], split_num['value']))
-            out_shape[0] = out_shape[0] // split_num['value']
+        if reduce_scatter_flag is None:
+            raise ValueError("The value of 'reduce_scatter_flag' is None.")
+        reduce_scatter_flag_value = reduce_scatter_flag['value']
+        if split_num is None:
+            raise ValueError("The value of 'split_num_value' is None.")
+        split_num_value = split_num['value']
+        if reduce_scatter_flag_value is True:
+            # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
+            # (split_num * 8)
+            if out_shape[0] % (split_num_value * 8) != 0:
+                raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d."
+                                 % (out_shape[0], (split_num_value * 8)))
+            # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
+            out_shape[0] = out_shape[0] // 8
         out = {'shape': out_shape,
                'dtype': params['dtype'],
                'value': None}
...
...
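The new infer-shape rule can be exercised in isolation. Below is a hypothetical standalone re-implementation (plain ints instead of the `{'value': ...}` dicts the primitive actually receives): a gather-style output shape is built, and on the reduce_scatter path dimension 0 must be divisible by `split_num * 8`, after which the host-side Concat leaves dimension 0 divided by 8.

```python
def infer_lookup_shape(params_shape, indices_shape, axis, reduce_scatter_flag, split_num):
    """Standalone sketch of the infer-shape hunk above; not the MindSpore API."""
    if axis < 0:
        axis += len(params_shape)
    # Gather-style shape: indices replace the gathered axis of params.
    out = list(params_shape[:axis]) + list(indices_shape) + list(params_shape[axis + 1:])
    if reduce_scatter_flag:
        # Dimension 0 must be divisible by (split_num * 8).
        if out[0] % (split_num * 8) != 0:
            raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d."
                             % (out[0], split_num * 8))
        # After 'Concat' on host, dimension 0 shrinks by a factor of 8.
        out[0] //= 8
    return out

print(infer_lookup_shape([64, 8], [64, 32], 0, True, 8))   # [8, 32, 8]
```

With `reduce_scatter_flag=False` the function is the plain GatherV2 shape rule; with the flag set, a dimension 0 of 64 and `split_num=8` passes the `64 % 64 == 0` check and yields 8.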
tests/ut/python/parallel/test_embeddinglookup.py
...
...
@@ -64,7 +64,7 @@ def test_embeddinglookup_reducescatter_false():
 def test_embeddinglookup_reducescatter_true():
-    shape = [8, 8]
+    shape = [64, 8]
     axis = 0
     offset = 8
     reduce_scatter_flag = True
...
...
@@ -73,5 +73,5 @@ def test_embeddinglookup_reducescatter_true():
     net.set_auto_parallel()
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32)
+    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
     _executor.compile(net, x, y)
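The updated test constants follow directly from the new shape rule. Assuming `split_num == 8` in this test (the value is set outside the visible hunks, so this is an inference, not shown in the diff), the arithmetic works out as:

```python
# Shape arithmetic behind the updated test values.
# Assumption: split_num == 8 (not visible in the hunks above).
params_shape = [64, 8]        # `shape` after this commit
indices_shape = [64, 32]      # x = np.ones([64, 32])
split_num = 8

gathered = indices_shape + params_shape[1:]   # [64, 32, 8]: gather along axis 0
assert gathered[0] % (split_num * 8) == 0     # 64 % 64 == 0: passes the new check
print([gathered[0] // 8] + gathered[1:])      # [8, 32, 8] -- matches y's updated shape
```

This explains why `y` changed to `np.ones([8, 32, 8])`: the host-side Concat divides dimension 0 of the lookup output (64) by 8.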