Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
87040483
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
87040483
编写于
4月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!58 fix two cast bug in auto parallel
Merge pull request !58 from lichen/fix_two_cast_bug_in_auto_parallel
上级
52166a85
2da38ad4
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
51 addition
and
11 deletion
+51
-11
mindspore/ccsrc/parallel/step_auto_parallel.cc
mindspore/ccsrc/parallel/step_auto_parallel.cc
+4
-2
mindspore/ccsrc/parallel/step_parallel.cc
mindspore/ccsrc/parallel/step_parallel.cc
+18
-5
tests/ut/python/parallel/test_element_wise_function.py
tests/ut/python/parallel/test_element_wise_function.py
+29
-4
未找到文件。
mindspore/ccsrc/parallel/step_auto_parallel.cc
浏览文件 @
87040483
...
...
@@ -346,6 +346,8 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
}
OperatorInfoPtr
CreateTheOperatorInfo
(
const
PrimitivePtr
&
prim
,
const
CNodePtr
&
cnode
)
{
MS_EXCEPTION_IF_NULL
(
prim
);
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
attrs
=
prim
->
attrs
();
std
::
vector
<
Shapes
>
shape_list
=
ExtractShape
(
cnode
);
if
(
shape_list
.
empty
())
{
...
...
@@ -381,8 +383,8 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info
->
set_outputs_dtype
(
cnode
->
Type
());
operator_info
->
set_cnode
(
cnode
);
// If no strategy has been configured for this operator, then candidate strategies are generated for
// auto-strategy searching
if
(
!
StrategyFound
(
attrs
))
{
// auto-strategy searching
m if this primitive is Cast, we ignore the user-specified strategy
if
(
!
StrategyFound
(
attrs
)
||
prim
->
name
()
==
CAST
)
{
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
// BatchParallelInfo operator
operator_info
->
ComputeBatchSplitFlagList
();
...
...
mindspore/ccsrc/parallel/step_parallel.cc
浏览文件 @
87040483
...
...
@@ -371,7 +371,6 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
if
(
prim
==
nullptr
)
{
return
false
;
}
auto
attrs
=
prim
->
attrs
();
if
(
IsInBlackList
(
prim
))
{
MS_LOG
(
INFO
)
<<
"Parallel don't care node: "
<<
prim
->
name
();
return
false
;
...
...
@@ -380,11 +379,9 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
if
(
prim
->
name
()
==
GET_NEXT
)
{
return
true
;
}
if
((
prim
->
name
()
==
CAST
))
{
if
((
!
attrs
.
count
(
STRATEGY
))
&&
(
cnode
->
operator_info
()
==
nullptr
))
{
if
((
prim
->
name
()
==
CAST
)
&&
(
cnode
->
operator_info
()
==
nullptr
))
{
return
false
;
}
}
return
cnode
->
in_forward_flag
();
}
...
...
@@ -654,6 +651,14 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) {
LossNodeInfo
node_info
;
// return -> cast
auto
pre_cnode
=
pre_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
pre_cnode
);
auto
pre_prim
=
GetValueNode
<
PrimitivePtr
>
(
pre_cnode
->
input
(
0
));
if
(
pre_prim
->
name
()
==
CAST
&&
pre_cnode
->
operator_info
()
==
nullptr
)
{
pre_node
=
pre_cnode
->
input
(
1
);
}
// return -> loss
if
(
pre_node
==
loss_node
)
{
node_info
.
has_tuple_getitem
=
false
;
...
...
@@ -1948,6 +1953,14 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) {
MS_EXCEPTION_IF_NULL
(
current_value
);
PrimitivePtr
current_prim
=
current_value
->
value
()
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
current_prim
);
// return -> cast
if
(
current_prim
->
name
()
==
CAST
&&
pre_cnode
->
operator_info
()
==
nullptr
)
{
pre_cnode
=
pre_cnode
->
input
(
1
)
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
pre_cnode
);
current_prim
=
GetValueNode
<
PrimitivePtr
>
(
pre_cnode
->
input
(
0
));
}
// notice: the GetNext op has not input
if
(
INVALID_LOSS_OPS
.
find
(
current_prim
->
name
())
!=
INVALID_LOSS_OPS
.
end
())
{
MS_LOG
(
INFO
)
<<
"The loss is: "
<<
current_prim
->
name
();
...
...
tests/ut/python/parallel/test_element_wise_function.py
浏览文件 @
87040483
...
...
@@ -192,7 +192,6 @@ def test_cast_before_mirror():
net
=
GradWrap
(
NetWithLoss
(
Net
(
strategy1
)))
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
)
x
=
Tensor
(
np
.
ones
([
128
,
32
]),
dtype
=
ms
.
float32
)
y
=
Tensor
(
np
.
ones
([
32
,
64
]),
dtype
=
ms
.
float32
)
b
=
Tensor
(
np
.
ones
([
64
,
64
]),
dtype
=
ms
.
float16
)
...
...
@@ -217,7 +216,6 @@ def test_cast_before_mirror1():
net
=
GradWrap
(
NetWithLoss
(
Net
(
strategy1
)))
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
)
x
=
Tensor
(
np
.
ones
([
128
,
32
]),
dtype
=
ms
.
float16
)
y
=
Tensor
(
np
.
ones
([
32
,
64
]),
dtype
=
ms
.
float16
)
b
=
Tensor
(
np
.
ones
([
64
,
64
]),
dtype
=
ms
.
float32
)
...
...
@@ -242,7 +240,6 @@ def test_cast_before_mirror2():
net
=
GradWrap
(
NetWithLoss
(
Net
(
strategy1
)))
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
)
x
=
Tensor
(
np
.
ones
([
128
,
32
]),
dtype
=
ms
.
float16
)
y
=
Tensor
(
np
.
ones
([
32
,
64
]),
dtype
=
ms
.
float16
)
b
=
Tensor
(
np
.
ones
([
64
,
64
]),
dtype
=
ms
.
float32
)
...
...
@@ -267,8 +264,36 @@ def test_cast_before_mirror3():
net
=
GradWrap
(
NetWithLoss
(
Net
(
strategy1
)))
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
)
x
=
Tensor
(
np
.
ones
([
128
,
32
]),
dtype
=
ms
.
float16
)
y
=
Tensor
(
np
.
ones
([
32
,
64
]),
dtype
=
ms
.
float16
)
b
=
Tensor
(
np
.
ones
([
64
,
64
]),
dtype
=
ms
.
float32
)
_executor
.
compile
(
net
,
x
,
y
,
b
)
def
test_mul_two_cast
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
strategy1
,
strategy2
,
strategy3
):
super
().
__init__
()
self
.
mul
=
P
.
Mul
().
set_strategy
(
strategy1
)
self
.
mul2
=
P
.
Mul
().
set_strategy
(
strategy2
)
self
.
cast
=
P
.
Cast
().
set_strategy
(
strategy3
)
self
.
cast2
=
P
.
Cast
().
set_strategy
(
strategy3
)
def
construct
(
self
,
x
,
y
,
b
):
out
=
self
.
mul
(
x
,
y
)
out
=
self
.
mul2
(
out
,
b
)
out
=
self
.
cast
(
out
,
ms
.
int32
)
out
=
self
.
cast2
(
out
,
ms
.
bool_
)
return
out
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
2
,
2
),
(
2
,
2
))
strategy2
=
((
8
,
1
),
(
8
,
1
))
strategy3
=
((
8
,
1
),
)
net
=
GradWrap
(
Net
(
strategy1
,
strategy2
,
strategy3
))
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
)
x
=
Tensor
(
np
.
ones
([
128
,
32
]),
dtype
=
ms
.
float32
)
y
=
Tensor
(
np
.
ones
([
128
,
32
]),
dtype
=
ms
.
float32
)
b
=
Tensor
(
np
.
ones
([
128
,
32
]),
dtype
=
ms
.
float32
)
_executor
.
compile
(
net
,
x
,
y
,
b
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录