Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
755f3814
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
755f3814
编写于
9月 02, 2020
作者:
Y
yao_yf
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix auto parallel reshape strategy set when it is first operator
上级
03093778
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
67 additions
and
11 deletions
+67
-11
mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
.../ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
+20
-11
tests/ut/python/parallel/test_auto_parallel_reshape.py
tests/ut/python/parallel/test_auto_parallel_reshape.py
+47
-0
未找到文件。
mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
浏览文件 @
755f3814
...
...
@@ -1565,24 +1565,33 @@ Status CostGraph::InitSelectedStrategy() {
auto
next_iter
=
std
::
find_if
(
out_edges
.
begin
(),
out_edges
.
end
(),
[
&
](
std
::
shared_ptr
<
Edge
>
edge
)
{
return
edge
->
next_operator
()
->
name
()
==
reshape_info
->
next_operator_name
();
});
if
(
pre_iter
!=
in_edges
.
end
())
{
bool
reshape_is_first_op
=
reshape_info
->
pre_operator_name
()
==
reshape_info
->
name
();
if
(
reshape_is_first_op
)
{
reshape_info
->
InitSelectedStrategy
(
reshape_info
->
selected_strategy
());
}
if
(
pre_iter
!=
in_edges
.
end
()
||
reshape_is_first_op
)
{
MS_LOG
(
DEBUG
)
<<
"Set reshape input layout by "
<<
reshape_info
->
pre_operator_name
();
int32_t
pre_index
=
reshape_info
->
pre_operator_index
();
TensorInfo
pre_info
;
if
(
ops_
[
i
]
->
name
()
==
(
*
pre_iter
)
->
prev_operator
()
->
name
())
{
pre_info
=
(
*
pre_iter
)
->
prev_operator
()
->
inputs_tensor_info
()[
pre_index
];
std
::
shared_ptr
<
OperatorInfo
>
pre_op_info
;
if
(
reshape_is_first_op
)
{
pre_op_info
=
reshape_info
;
pre_info
=
pre_op_info
->
inputs_tensor_info
()[
pre_index
];
}
else
{
pre_info
=
(
*
pre_iter
)
->
prev_operator
()
->
outputs_tensor_info
()[
pre_index
];
pre_op_info
=
(
*
pre_iter
)
->
prev_operator
();
pre_info
=
pre_op_info
->
outputs_tensor_info
()[
pre_index
];
}
reshape_info
->
SetInputLayout
(
pre_info
.
tensor_layout
());
Dimensions
stra
=
pre_info
.
InferStrategy
();
if
(
stra
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Infer strategy by tensor_info failed"
;
if
(
pre_iter
!=
in_edges
.
end
())
{
Dimensions
stra
=
pre_info
.
InferStrategy
();
if
(
stra
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Infer strategy by tensor_info failed"
;
}
Strategys
stra_inputs
=
{
stra
};
StrategyPtr
reshape_stra
=
std
::
make_shared
<
Strategy
>
((
*
pre_iter
)
->
prev_operator
()
->
strategy
()
->
GetInputStage
(),
stra_inputs
);
reshape_info
->
set_strategy
(
reshape_stra
);
}
Strategys
stra_inputs
=
{
stra
};
StrategyPtr
reshape_stra
=
std
::
make_shared
<
Strategy
>
((
*
pre_iter
)
->
prev_operator
()
->
strategy
()
->
GetInputStage
(),
stra_inputs
);
reshape_info
->
set_strategy
(
reshape_stra
);
}
if
(
next_iter
!=
out_edges
.
end
())
{
MS_LOG
(
DEBUG
)
<<
"Set reshape output layout by "
<<
reshape_info
->
next_operator_name
();
...
...
tests/ut/python/parallel/test_auto_parallel_reshape.py
浏览文件 @
755f3814
...
...
@@ -245,3 +245,50 @@ def test_reshape_auto_5():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
)
net
.
set_auto_parallel
()
_executor
.
compile
(
net
,
x
,
y
)
def test_reshape_auto_6():
    """Auto-parallel compile test where Reshape reads a Parameter directly.

    NOTE(review): reconstructed from a token-exploded commit diff. Per the
    commit message, this presumably exercises strategy inference when the
    reshape is the first operator in the graph — confirm against the repo.
    """

    class NetWithLoss6(nn.Cell):
        # Wraps a network and applies a virtual loss to its prediction.
        def __init__(self, network):
            super(NetWithLoss6, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap6(nn.Cell):
        # Wraps a network so that construct returns its gradients.
        def __init__(self, network):
            super(GradWrap6, self).__init__()
            self.network = network

        def construct(self, x, y):
            return grad_all(self.network)(x, y)

    class Net(nn.Cell):
        # Reshapes a Parameter (not an activation) so the reshape has no
        # preceding operator, then mixes the result into two arithmetic paths.
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32),
                                    name="weight")

        def construct(self, x, y):
            out1 = x + self.wide_w
            # Reshape consumes the Parameter directly: (4, 1024, 1) -> (4, 1024).
            w = self.reshape(self.wide_w, (4, 1024))
            out1 = self.reduce_mean(out1, 1)
            out1 = out1 - w
            out2 = self.mul(y, w)
            out = out1 + out2
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
    net = GradWrap6(NetWithLoss6(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x, y)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录