magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 1cfb52bc
Authored Jun 01, 2020 by Xiaoda Zhang
Parent: 5c4731b7

add the reshape part of the embeddinglookup backward operator

Showing 4 changed files with 39 additions and 7 deletions (+39 -7)
mindspore/ccsrc/parallel/ops_info/ops_utils.h   +1 -1
mindspore/ops/operations/_grad_ops.py           +32 -0
mindspore/ops/operations/array_ops.py           +1 -1
tests/ut/python/parallel/test_gather_v2.py      +5 -5
mindspore/ccsrc/parallel/ops_info/ops_utils.h

@@ -76,7 +76,7 @@ constexpr char DEPEND[] = "depend";
 constexpr char BATCH_PARALLEL[] = "BatchParallel";
 constexpr char ACTIVATION_TYPE[] = "activation_type";
-constexpr char TARGET[] = "target";
+constexpr char TARGET[] = "primitive_target";
 constexpr char CPU[] = "CPU";
 constexpr char TRANSPOSE_A[] = "transpose_a";
 constexpr char TRANSPOSE_B[] = "transpose_b";
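This hunk renames the string bound to the TARGET attribute key from "target" to "primitive_target", matching the attribute name the Python primitives set in the hunks below. As a rough illustration (not part of this commit), a primitive can be pinned to host execution by attaching that attribute; add_prim_attr returns the primitive, so calls chain, as the test changes in this commit also show:

```python
# Illustrative sketch: pin a primitive to the host via the attribute key
# that this commit standardizes on ('primitive_target').
from mindspore.ops import operations as P

gatherv2 = P.GatherV2().add_prim_attr("primitive_target", "CPU")
```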
mindspore/ops/operations/_grad_ops.py

@@ -21,6 +21,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 from ..._checkparam import Validator as validator, Rel
 from .._utils import get_concat_offset
 from ...common import dtype as mstype
+from .. import functional as F


 class AbsGrad(PrimitiveWithInfer):

@@ -1121,6 +1122,37 @@ class MirrorPadGrad(PrimitiveWithInfer):
                 'value': None}


+class EmbeddingLookupCommGrad(PrimitiveWithInfer):
+    """
+    Performs the gradient for the communication part of the EmbeddingLookup operator.
+
+    This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
+    this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
+    """
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
+        self.add_prim_attr('primitive_target', 'CPU')
+
+    def __infer__(self, dy, split_num):
+        """
+        This primitive is implemented by three steps:
+            1) Split 'dy' along dimension 0 into 'split_num' parts.
+            2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
+            3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
+               along dimension 0.
+
+        The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
+        """
+        dy_shape = tuple(dy['shape'])
+        split_num_value = split_num['value']
+        validator.check_value_type("split_num_value", split_num_value, [int], self.name)
+        dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
+        return {'shape': dy_shape_all,
+                'dtype': dy['dtype'],
+                'value': None}
+
+
 class RefToEmbed(Primitive):
     r"""
     Make a key from Ref.
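The docstring fixes the gather group at eight ranks (HostAllGather((0, 1, 2, 3, 4, 5, 6, 7))), which is why __infer__ multiplies the first dimension by 8. A minimal NumPy sketch (illustrative only; it assumes those 8 ranks and that each rank contributes one identically shaped part) reproduces the shape rule:

```python
import numpy as np

def comm_grad_shape_sketch(dy, split_num=2, ranks=8):
    parts = np.split(dy, split_num, axis=0)           # 1) split 'dy' along dim 0
    gathered = [np.concatenate([p] * ranks, axis=0)   # 2) emulate HostAllGather:
                for p in parts]                       #    every rank contributes one part
    return np.concatenate(gathered, axis=0)           # 3) Concat the parts along dim 0

dy = np.ones((4, 16), dtype=np.float32)
out = comm_grad_shape_sketch(dy)
assert out.shape[0] == dy.shape[0] * 8                # shape(output)[0] == shape(dy)[0] * 8
```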
mindspore/ops/operations/array_ops.py

@@ -614,7 +614,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
         self.__setattr_flag__ = True
         self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
                                 outputs=['output'])
-        self.add_prim_attr('target', 'CPU')
+        self.add_prim_attr('primitive_target', 'CPU')

     def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
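The EmbeddingLookup change is the same one-word rename: the primitive is still pinned to the host, but under the 'primitive_target' key the C++ side now expects. A hedged sketch (assuming Primitive exposes its registered attributes through the attrs property, as MindSpore of this era does) to confirm the new key:

```python
from mindspore.ops import operations as P

lookup = P.EmbeddingLookup()
# After this commit the host pin lives under 'primitive_target', not 'target'.
print(lookup.attrs.get('primitive_target'))  # expected: 'CPU'
```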
tests/ut/python/parallel/test_gather_v2.py

@@ -45,11 +45,11 @@ class GradWrap(nn.Cell):
 class Net(nn.Cell):
-    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
+    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
         super().__init__()
         if shape is None:
             shape = [64, 64]
-        self.gatherv2 = P.GatherV2().set_strategy(strategy1)
+        self.gatherv2 = P.GatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
         self.mul = P.Mul().set_strategy(strategy2)
         self.index = Tensor(np.ones(shape), dtype=ms.int32)
         self.axis = axis

@@ -188,7 +188,7 @@ def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()

     x = Tensor(np.ones([64, 64]), dtype=ms.float32)

@@ -200,7 +200,7 @@ def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()

     x = Tensor(np.ones([64, 64]), dtype=ms.float32)

@@ -212,7 +212,7 @@ def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()

     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
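In these tests, each inner strategy tuple gives the slice count per dimension of the corresponding input, so the new target="" parameter is the only behavioral addition: passing "CPU" routes GatherV2 through the host path exercised by the new primitive. For reference, a small helper (illustrative only, not a MindSpore API) shows the per-device shard shapes these strategies imply:

```python
def shard_shape(shape, strategy):
    """Per-device shard shape: each strategy factor is the slice count for that dim."""
    assert all(d % s == 0 for d, s in zip(shape, strategy))
    return [d // s for d, s in zip(shape, strategy)]

print(shard_shape([64, 64], (8, 1)))   # test_gatherv2_cpu0 params shard -> [8, 64]
print(shard_shape([64, 64], (16, 1)))  # test_gatherv2_cpu1 params shard -> [4, 64]
print(shard_shape([64, 64], (1, 8)))   # test_gatherv2_cpu2 params shard -> [64, 8]
```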
登录