Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ec5363ad
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ec5363ad
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!1593 Fix the bug that there is only return node in the forward graph
Merge pull request !1593 from yangzhenzhang/reshape-optimized
上级
29eacb0f
1413f520
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
68 addition
and
20 deletion
+68
-20
mindspore/ccsrc/parallel/step_parallel.cc
mindspore/ccsrc/parallel/step_parallel.cc
+14
-18
mindspore/ccsrc/parallel/step_parallel.h
mindspore/ccsrc/parallel/step_parallel.h
+0
-2
tests/ut/python/parallel/test_reshape_optimized.py
tests/ut/python/parallel/test_reshape_optimized.py
+54
-0
未找到文件。
mindspore/ccsrc/parallel/step_parallel.cc
浏览文件 @
ec5363ad
...
...
@@ -1683,7 +1683,10 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL
(
pre_node
);
auto
pre_cnode
=
pre_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
pre_cnode
);
if
(
pre_cnode
==
nullptr
)
{
return
nullptr
;
}
auto
current_prim
=
GetValueNode
<
PrimitivePtr
>
(
pre_cnode
->
input
(
0
));
// return -> cast
if
(
current_prim
->
name
()
==
CAST
&&
pre_cnode
->
operator_info
()
==
nullptr
)
{
...
...
@@ -1907,21 +1910,6 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
}
}
std
::
vector
<
CNodePtr
>
FindLossCNodeFromRoot
(
const
FuncGraphPtr
&
root
)
{
MS_EXCEPTION_IF_NULL
(
root
);
AnfNodePtr
root_return_node
=
root
->
get_return
();
MS_EXCEPTION_IF_NULL
(
root_return_node
);
std
::
vector
<
CNodePtr
>
loss_node
;
const
auto
&
all_nodes
=
root
->
nodes
();
std
::
set
<
FuncGraphPtr
>
graph_set
=
FindForwardGraphByRootNodes
(
all_nodes
);
if
(
graph_set
.
empty
())
{
loss_node
.
push_back
(
FindLossCNode
(
root
));
}
(
void
)
std
::
transform
(
graph_set
.
begin
(),
graph_set
.
end
(),
std
::
back_inserter
(
loss_node
),
[](
const
FuncGraphPtr
&
graph
)
{
return
FindLossCNode
(
graph
);
});
return
loss_node
;
}
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
std
::
vector
<
std
::
pair
<
CNodePtr
,
CNodePtr
>>
GetSensLossPairs
(
const
FuncGraphPtr
&
root
)
{
MS_EXCEPTION_IF_NULL
(
root
);
...
...
@@ -1968,6 +1956,10 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &
}
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
expect_j_cnode
->
input
(
1
));
auto
loss_cnode
=
FindLossCNode
(
func_graph
);
if
(
loss_cnode
==
nullptr
)
{
MS_LOG
(
WARNING
)
<<
"Can not find the loss cnode"
;
continue
;
}
std
::
pair
<
CNodePtr
,
CNodePtr
>
sens_loss_pair
=
std
::
make_pair
(
sens_cnode
,
loss_cnode
);
sens_loss_pairs
.
push_back
(
sens_loss_pair
);
}
...
...
@@ -2158,10 +2150,14 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
std
::
vector
<
AnfNodePtr
>
FindRootForwardCNode
(
const
FuncGraphPtr
&
graph
,
const
AnfNodeSet
&
all_nodes
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
root_forward_nodes
;
auto
loss_cnode
=
FindLossCNode
(
graph
);
MS_EXCEPTION_IF_NULL
(
loss_cnode
);
if
(
loss_cnode
==
nullptr
)
{
MS_LOG
(
WARNING
)
<<
"Can not find the loss cnode"
;
return
root_forward_nodes
;
}
auto
loss_cnode_id
=
loss_cnode
->
UniqueIdThroughCopy
();
std
::
vector
<
AnfNodePtr
>
root_forward_nodes
;
for
(
auto
&
node
:
all_nodes
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
!
node
->
isa
<
CNode
>
())
{
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/parallel/step_parallel.h
浏览文件 @
ec5363ad
...
...
@@ -144,8 +144,6 @@ bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optim
int32_t
GetTupleGetItemIndex
(
const
CNodePtr
&
cnode
);
std
::
vector
<
CNodePtr
>
FindLossCNodeFromRoot
(
const
FuncGraphPtr
&
root
);
Status
ParallelInit
();
std
::
vector
<
std
::
string
>
ExtractInputsTensorName
(
const
CNodePtr
&
node
);
...
...
This diff is collapsed.
Click to expand it.
tests/ut/python/parallel/test_reshape_optimized.py
0 → 100644
浏览文件 @
ec5363ad
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.common.api
import
_executor
from
mindspore.nn
import
Cell
,
TrainOneStepCell
,
Momentum
from
mindspore.ops
import
operations
as
P
class
Net
(
Cell
):
def
__init__
(
self
,
mul_weight
):
super
().
__init__
()
self
.
reshape1
=
P
.
Reshape
()
self
.
reshape2
=
P
.
Reshape
()
self
.
mul_weight
=
Parameter
(
mul_weight
,
"w1"
)
def
construct
(
self
,
x
,
b
):
out
=
self
.
reshape1
(
self
.
mul_weight
,
(
128
,
64
,
32
))
out
=
self
.
reshape2
(
out
,
(
128
,
64
,
32
))
return
out
_x
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
_w1
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
_b
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
def
compile_net
(
net
):
context
.
set_context
(
save_graphs
=
True
)
optimizer
=
Momentum
(
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
train_net
=
TrainOneStepCell
(
net
,
optimizer
)
train_net
.
set_auto_parallel
()
_executor
.
compile
(
train_net
,
_x
,
_b
)
context
.
reset_auto_parallel_context
()
def
test_reshape_optimized
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
net
=
Net
(
_w1
)
compile_net
(
net
)
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部