Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5b3327d1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5b3327d1
编写于
4月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!746 reducescatter backforward operator
Merge pull request !746 from lirongzhen1/bp_reducescatter
上级
36d9327c
0b464888
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
52 additions
and
3 deletions
+52
-3
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
...re/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+0
-1
mindspore/ops/_grad/grad_comm_ops.py
mindspore/ops/_grad/grad_comm_ops.py
+19
-1
tests/ut/python/communication/test_comm.py
tests/ut/python/communication/test_comm.py
+33
-1
未找到文件。
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
浏览文件 @
5b3327d1
...
...
@@ -29,7 +29,6 @@
namespace
mindspore
{
namespace
parallel
{
// Get the target node's weight for sorting.
double
GetWeights
(
const
Graph
::
NodeType
&
node
)
{
const
OperatorRec
&
op
=
node
.
apply
;
...
...
mindspore/ops/_grad/grad_comm_ops.py
浏览文件 @
5b3327d1
...
...
@@ -67,11 +67,29 @@ def get_bprop_broad_cast(self):
@bprop_getters.register(AllGather)
def get_bprop_all_gather(self):
    """Generate bprop for AllGather.

    The gradient of an all-gather is a reduce-scatter with SUM over the
    same communication group: each rank gets back the summed slice of the
    gradient corresponding to its own contribution.

    Args:
        self: the AllGather primitive being differentiated; its ``group``
            and ``instance_name`` attributes are read here.

    Returns:
        A ``bprop(x, out, dout)`` closure returning a one-element tuple
        with the input gradient.
    """
    # Single grad op; the original code also built an identical, unused
    # ReduceScatter instance (dead local removed).
    all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
    if self.instance_name:
        # Tag the grad op so it can be traced back to its forward op.
        instance_name = "grad" + self.instance_name
        all_gather_grad.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        dx = all_gather_grad(dout)
        return (dx,)

    return bprop
@bprop_getters.register(ReduceScatter)
def get_bprop_reduce_scatter(self):
    """Generate bprop for ReduceScatter"""
    # The gradient of a reduce-scatter(SUM) is an all-gather over the same
    # communication group: every rank recovers the full gradient tensor.
    reduce_scatter_grad = AllGather(self.group)
    if self.instance_name:
        # Tag the grad op so it can be traced back to its forward op.
        instance_name = "grad" + self.instance_name
        reduce_scatter_grad.set_prim_instance_name(instance_name)

    # Only SUM is supported: reject other reduce ops when the bprop is built,
    # before any graph using it is compiled.
    if self.op != ReduceOp.SUM:
        raise RuntimeError("The reducescatter bprop only support ReduceOp.SUM until now.")

    def bprop(x, out, dout):
        dx = reduce_scatter_grad(dout)
        return (dx,)
    # NOTE(review): this excerpt is truncated here; the getter presumably
    # ends with `return bprop` beyond the visible lines — confirm in the
    # full source file.
...
...
tests/ut/python/communication/test_comm.py
浏览文件 @
5b3327d1
...
...
@@ -14,7 +14,7 @@
""" test Communicate """
import
numpy
as
np
from
mindspore.ops.operations.comm_ops
import
AllReduce
,
AllGather
,
_AlltoAll
,
ReduceOp
from
mindspore.ops.operations.comm_ops
import
AllReduce
,
AllGather
,
_AlltoAll
,
ReduceOp
,
ReduceScatter
from
mindspore.ops.operations.comm_ops
import
Broadcast
from
mindspore.communication.management
import
HCCL_WORLD_COMM_GROUP
,
NCCL_WORLD_COMM_GROUP
,
GlobalComm
,
init
from
mindspore.communication._comm_helper
import
Backend
...
...
@@ -78,6 +78,19 @@ class AllGatherNet(nn.Cell):
x
=
self
.
allgather
(
x
)
return
self
.
relu
(
x
)
class ReduceScatterNet(nn.Cell):
    """Small test network: Dense -> ReduceScatter(op) -> ReLU."""

    def __init__(self, input_channel, out_channel, op):
        super(ReduceScatterNet, self).__init__()
        # Fully-connected layer whose output is reduce-scattered across the
        # communication group before the activation.
        self.dense = Dense(input_channel, out_channel)
        self.reducescatter = ReduceScatter(op)
        self.relu = ReLU()

    def construct(self, x):
        dense_out = self.dense(x)
        scattered = self.reducescatter(dense_out)
        return self.relu(scattered)
class
AlltoAllNet
(
nn
.
Cell
):
"""AlltoAllNet definition"""
def
__init__
(
self
,
input_channel
,
out_channel
):
...
...
@@ -126,6 +139,25 @@ def test_allgather():
network
=
TrainOneStepCell
(
network
,
optimizer
)
_executor
.
compile
(
network
,
input_tensor
,
label_tensor
)
def run_reducescatter(op):
    """Build and compile a one-step training graph around ReduceScatterNet.

    `op` is the reduce operation (e.g. ReduceOp.SUM) forwarded to the
    ReduceScatter primitive inside the network.
    """
    context.set_context(mode=context.GRAPH_MODE)
    # Fixed 2x2 input batch and 2x1 labels, matching Dense(2, 1) below.
    features = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
    labels = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))

    net = ReduceScatterNet(2, 1, op)
    loss = nn.SoftmaxCrossEntropyWithLogits()
    trainable = filter(lambda p: p.requires_grad, net.get_parameters())
    opt = Momentum(trainable, learning_rate=0.1, momentum=0.9)

    # Wrap into a single-step trainable cell and compile the graph only
    # (no actual execution / communication happens in this UT).
    net = WithLossCell(net, loss)
    net = TrainOneStepCell(net, opt)
    _executor.compile(net, features, labels)
def test_reducescatter():
    """Compile the ReduceScatter training graph with ReduceOp.SUM."""
    # Graph mode is also set inside run_reducescatter; setting it here keeps
    # the test self-contained if the helper ever changes.
    context.set_context(mode=context.GRAPH_MODE)
    run_reducescatter(ReduceOp.SUM)
def
test_broadcast
():
"""test_broadcast"""
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录