Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
pingzhuyan
mindspore
提交
05c003ae
M
mindspore
项目概览
pingzhuyan
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
05c003ae
编写于
9月 03, 2020
作者:
Y
yao_yf
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
origin/semi_auto_parallel_reshape_parameter_has_another_user
上级
a9f4a24e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
52 addition
and
0 deletion
+52
-0
mindspore/ccsrc/frontend/parallel/step_parallel.cc
mindspore/ccsrc/frontend/parallel/step_parallel.cc
+28
-0
mindspore/ccsrc/frontend/parallel/step_parallel.h
mindspore/ccsrc/frontend/parallel/step_parallel.h
+2
-0
tests/ut/python/parallel/test_auto_parallel_reshape.py
tests/ut/python/parallel/test_auto_parallel_reshape.py
+22
-0
未找到文件。
mindspore/ccsrc/frontend/parallel/step_parallel.cc
浏览文件 @
05c003ae
...
...
@@ -1645,8 +1645,36 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
return
nullptr
;
}
std
::
shared_ptr
<
TensorLayout
>
FindParameterNextLayout
(
const
AnfNodePtr
&
node
)
{
FuncGraphManagerPtr
manager
=
node
->
func_graph
()
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
AnfNodeIndexSet
node_set
=
manager
->
node_users
()[
node
];
for
(
auto
&
node_pair
:
node_set
)
{
CNodePtr
use_apply
=
node_pair
.
first
->
cast
<
CNodePtr
>
();
if
(
use_apply
==
nullptr
||
!
IsValueNode
<
Primitive
>
(
use_apply
->
input
(
0
)))
{
continue
;
}
ValueNodePtr
prim_anf_node
=
use_apply
->
input
(
0
)
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
prim_anf_node
);
PrimitivePtr
node_prim
=
prim_anf_node
->
value
()
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
node_prim
);
if
((
node_prim
->
name
()
==
DEPEND
&&
node_pair
.
second
!=
1
)
||
node_prim
->
name
()
==
RESHAPE
)
{
continue
;
}
if
(
IsParallelCareNode
(
use_apply
)
&&
use_apply
->
has_user_data
<
OperatorInfo
>
())
{
auto
layout
=
GetInputLayoutFromCNode
(
node_pair
);
return
std
::
make_shared
<
TensorLayout
>
(
layout
);
}
}
return
nullptr
;
}
std
::
shared_ptr
<
TensorLayout
>
CreateParameterLayout
(
const
AnfNodePtr
&
node
)
{
// Create DataParallel tensor layout for parameter(support WideDeep).
auto
next_layout
=
FindParameterNextLayout
(
node
);
if
(
next_layout
!=
nullptr
)
{
return
next_layout
;
}
CheckGlobalDeviceManager
();
int32_t
dev_num
=
SizeToInt
(
g_device_manager
->
GetDeviceListByStageId
(
0
).
size
());
TensorLayout
input_tensor_layout
;
...
...
mindspore/ccsrc/frontend/parallel/step_parallel.h
浏览文件 @
05c003ae
...
...
@@ -156,6 +156,8 @@ using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeI
RefKeyPair
CNodeWithRefKeys
(
const
AnfNodePtr
&
cnode
);
std
::
shared_ptr
<
TensorLayout
>
FindParameterNextLayout
(
const
AnfNodePtr
&
node
);
ParameterUsersInfo
FindParameterUsers
(
const
AnfNodePtr
&
node
,
bool
(
*
IsCareNode
)(
const
CNodePtr
&
));
}
// namespace parallel
}
// namespace mindspore
...
...
tests/ut/python/parallel/test_auto_parallel_reshape.py
浏览文件 @
05c003ae
...
...
@@ -292,3 +292,25 @@ def test_reshape_auto_6():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
)
net
.
set_auto_parallel
()
_executor
.
compile
(
net
,
x
,
y
)
def
test_reshape_auto_7
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
reshape
=
P
.
Reshape
()
self
.
mul
=
P
.
Mul
().
set_strategy
(((
1
,
2
,
4
),
(
2
,
4
)))
self
.
mul_weight
=
Parameter
(
Tensor
(
np
.
ones
([
128
,
96
]),
dtype
=
ms
.
float32
),
name
=
"weight"
)
def
construct
(
self
,
x
):
weight
=
self
.
reshape
(
self
.
mul_weight
,
(
1
,
128
,
96
))
out
=
self
.
mul
(
weight
,
self
.
mul_weight
)
return
out
size
=
8
context
.
set_auto_parallel_context
(
device_num
=
size
,
global_rank
=
0
)
x
=
Tensor
(
np
.
ones
([
128
,
28
]),
dtype
=
ms
.
float32
)
net
=
GradWrap
(
NetWithLoss
(
Net
()))
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
)
net
.
set_auto_parallel
()
_executor
.
compile
(
net
,
x
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录