Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
36a62576
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
36a62576
编写于
4月 26, 2020
作者:
Y
yangzhenzhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support forward graph
上级
00191223
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
168 addition
and
90 deletion
+168
-90
mindspore/ccsrc/parallel/step_parallel.cc
mindspore/ccsrc/parallel/step_parallel.cc
+84
-89
mindspore/ccsrc/parallel/step_parallel.h
mindspore/ccsrc/parallel/step_parallel.h
+2
-1
tests/ut/python/parallel/test_forward_graph.py
tests/ut/python/parallel/test_forward_graph.py
+82
-0
未找到文件。
mindspore/ccsrc/parallel/step_parallel.cc
浏览文件 @
36a62576
...
...
@@ -345,7 +345,6 @@ bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
!
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
continue
;
}
...
...
@@ -903,9 +902,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
}
}
void
BackwardCommunication
(
const
OperatorInfoPtr
&
distribute_operator
,
const
CNodePtr
&
node
,
bool
is_loss_node
)
{
void
BackwardCommunication
(
const
OperatorInfoPtr
&
distribute_operator
,
const
CNodePtr
&
node
,
const
std
::
vector
<
std
::
pair
<
CNodePtr
,
CNodePtr
>>
&
sens_loss_pairs
)
{
MS_EXCEPTION_IF_NULL
(
distribute_operator
);
MS_EXCEPTION_IF_NULL
(
node
);
bool
is_loss_cnode
=
std
::
any_of
(
sens_loss_pairs
.
begin
(),
sens_loss_pairs
.
end
(),
[
node
](
const
std
::
pair
<
CNodePtr
,
CNodePtr
>
&
element
)
{
return
element
.
second
==
node
;
});
MirrorOps
mirror_ops
=
distribute_operator
->
mirror_ops
();
VirtualDivOp
virtual_div_op
=
distribute_operator
->
virtual_div_op
();
// insert mirror op
...
...
@@ -914,7 +919,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
InsertMirrorOps
(
mirror_ops
,
node
);
}
// insert virtual div op
if
(
!
virtual_div_op
.
empty
()
&&
is_loss_node
)
{
if
(
!
virtual_div_op
.
empty
()
&&
is_loss_
c
node
)
{
MS_LOG
(
INFO
)
<<
"insert virtual div op for "
<<
distribute_operator
->
name
();
InsertVirtualDivOp
(
virtual_div_op
,
node
);
}
...
...
@@ -986,10 +991,6 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
Dimensions
dim
;
if
(
elements
[
index
]
->
isa
<
ValueSequeue
>
())
{
ValueTuplePtr
value_tuple
=
elements
[
index
]
->
cast
<
ValueTuplePtr
>
();
if
(
value_tuple
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failure:value_tuple is nullptr"
;
}
std
::
vector
<
ValuePtr
>
value_vector
=
value_tuple
->
value
();
(
void
)
std
::
transform
(
value_vector
.
begin
(),
value_vector
.
end
(),
std
::
back_inserter
(
dim
),
[](
const
ValuePtr
&
value
)
{
return
static_cast
<
int32_t
>
(
GetValue
<
int
>
(
value
));
});
...
...
@@ -1013,7 +1014,6 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
BaseShapePtr
base_shape_ptr
=
node
->
Shape
();
if
(
node
->
isa
<
CNode
>
())
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
PrimitivePtr
prim
=
GetValueNode
<
PrimitivePtr
>
(
cnode
->
input
(
0
));
MS_EXCEPTION_IF_NULL
(
prim
);
...
...
@@ -1190,7 +1190,7 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
continue
;
}
CNodePtr
graph_cnode_inp0
=
graph_cnode
->
input
(
0
)
->
cast
<
CNodePtr
>
();
if
(
(
graph_cnode_inp0
==
nullptr
)
||
!
IsValueNode
<
FuncGraph
>
(
graph_cnode_inp0
->
input
(
1
)))
{
if
(
!
IsValueNode
<
FuncGraph
>
(
graph_cnode_inp0
->
input
(
1
)))
{
continue
;
}
FuncGraphPtr
graph_sub
=
GetValueNode
<
FuncGraphPtr
>
(
graph_cnode_inp0
->
input
(
1
));
...
...
@@ -1692,14 +1692,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
return
pre_cnode
;
}
TensorLayouts
GetLossNodeGradOutputLayout
(
const
CNodePtr
&
cnode
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
TensorLayouts
GetLossNodeGradOutputLayout
(
const
CNodePtr
&
loss_cnode
)
{
TensorLayouts
ret
;
if
(
!
IsValueNode
<
FuncGraph
>
(
cnode
->
input
(
1
)))
{
MS_LOG
(
EXCEPTION
)
<<
"Sens can't find the corresponding graph."
;
}
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
cnode
->
input
(
1
));
auto
loss_cnode
=
FindLossCNode
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
loss_cnode
);
AnfNodePtr
node
=
loss_cnode
->
cast
<
AnfNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
node
);
...
...
@@ -1735,16 +1729,16 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
return
ret
;
}
void
SplitSens
(
const
Anf
NodePtr
&
grad_sens_node
,
const
TensorLayout
&
loss_grad_layout
)
{
void
SplitSens
(
const
C
NodePtr
&
grad_sens_node
,
const
TensorLayout
&
loss_grad_layout
)
{
MS_EXCEPTION_IF_NULL
(
grad_sens_node
);
auto
cnode
=
grad_sens_node
->
cast
<
CNodePtr
>
()
;
MS_EXCEPTION_IF_NULL
(
cnode
);
AnfNodePtr
sens_tensor_node
=
c
node
->
input
(
1
);
if
(
grad_sens_node
->
size
()
<=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"The size of grad sens node is smaller than 2"
;
}
AnfNodePtr
sens_tensor_node
=
grad_sens_
node
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
sens_tensor_node
);
Shapes
sens_shapes
=
GetNodeShape
(
sens_tensor_node
);
if
(
sens_shapes
.
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"
SplitSens:
GetNodeShape for sens_tensor_node, output size is not 1"
;
MS_LOG
(
EXCEPTION
)
<<
"GetNodeShape for sens_tensor_node, output size is not 1"
;
}
// If the shape of sens tensor is [] or [1], no need to split it.
Shape
sens_shape
=
sens_shapes
[
0
];
...
...
@@ -1780,14 +1774,14 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
sens_tensor_param
->
set_tensor_layout
(
std
::
make_shared
<
TensorLayout
>
(
loss_grad_layout
));
return
;
}
MS_LOG
(
EXCEPTION
)
<<
"
SplitSens: t
he type of sens node is not Tensor or Parameter, it is unsupported now."
;
MS_LOG
(
EXCEPTION
)
<<
"
T
he type of sens node is not Tensor or Parameter, it is unsupported now."
;
}
// Use _GetTensorSlice operator to split the sens tensor
FuncGraphPtr
func_graph
=
c
node
->
func_graph
();
// only cnode can get the graph
FuncGraphPtr
func_graph
=
grad_sens_
node
->
func_graph
();
// only cnode can get the graph
MS_EXCEPTION_IF_NULL
(
func_graph
);
Operator
op
=
CreateGetTensorSliceOp
(
loss_grad_layout
);
InsertGetTensorSliceOp
(
op
,
c
node
,
func_graph
,
1
,
SPLIT_SENS
);
InsertGetTensorSliceOp
(
op
,
grad_sens_
node
,
func_graph
,
1
,
SPLIT_SENS
);
}
void
InsertForwardOps
(
const
OperatorInfoPtr
&
distribute_operator
,
const
CNodePtr
&
cnode
)
{
...
...
@@ -1853,7 +1847,6 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
((
cnode
->
size
()
<
2
)
||
!
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
continue
;
}
...
...
@@ -1870,55 +1863,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
return
graph_set
;
}
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
void
StepSplitSens
(
const
AnfNodePtr
&
node
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
return
;
}
// cnode(sens)-->cnode(tuple_getitem)
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
AnfNodePtr
expect_tuple_getitem
=
cnode
->
input
(
0
);
MS_EXCEPTION_IF_NULL
(
expect_tuple_getitem
);
if
(
!
expect_tuple_getitem
->
isa
<
CNode
>
())
{
return
;
}
auto
expect_tuple_getitem_cnode
=
expect_tuple_getitem
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
expect_tuple_getitem_cnode
);
if
(
!
IsValueNode
<
Primitive
>
(
expect_tuple_getitem_cnode
->
input
(
0
)))
{
return
;
}
auto
expect_tuple_getitem_prim
=
GetValueNode
<
PrimitivePtr
>
(
expect_tuple_getitem_cnode
->
input
(
0
));
if
(
expect_tuple_getitem_prim
->
name
()
!=
TUPLE_GETITEM
)
{
return
;
}
// cnode(sens)-->cnode(tuple_getitem)-->cnode
AnfNodePtr
expect_anonymous
=
expect_tuple_getitem_cnode
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
expect_anonymous
);
if
(
!
expect_anonymous
->
isa
<
CNode
>
())
{
return
;
}
// cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
auto
expect_anonymous_cnode
=
expect_anonymous
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
expect_anonymous_cnode
);
AnfNodePtr
expect_j
=
expect_anonymous_cnode
->
input
(
0
);
MS_EXCEPTION_IF_NULL
(
expect_j
);
if
(
!
expect_j
->
isa
<
CNode
>
())
{
return
;
}
auto
expect_j_cnode
=
expect_j
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
expect_j_cnode
);
if
(
!
IsValueNode
<
Primitive
>
(
expect_j_cnode
->
input
(
0
)))
{
return
;
}
auto
expect_j_prim
=
GetValueNode
<
PrimitivePtr
>
(
expect_j_cnode
->
input
(
0
));
if
(
expect_j_prim
->
name
()
==
J
)
{
auto
loss_grad_layout
=
GetLossNodeGradOutputLayout
(
expect_j_cnode
);
if
(
!
loss_grad_layout
.
empty
())
{
SplitSens
(
node
,
loss_grad_layout
[
0
]);
}
void
StepSplitSens
(
const
std
::
pair
<
CNodePtr
,
CNodePtr
>
&
sens_loss_pair
)
{
CNodePtr
sens_node
=
sens_loss_pair
.
first
;
CNodePtr
loss_node
=
sens_loss_pair
.
second
;
auto
loss_grad_layout
=
GetLossNodeGradOutputLayout
(
loss_node
);
if
(
!
loss_grad_layout
.
empty
())
{
SplitSens
(
sens_node
,
loss_grad_layout
[
0
]);
}
}
...
...
@@ -1937,26 +1887,77 @@ std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) {
return
loss_node
;
}
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
std
::
vector
<
std
::
pair
<
CNodePtr
,
CNodePtr
>>
GetSensLossPairs
(
const
FuncGraphPtr
&
root
)
{
MS_EXCEPTION_IF_NULL
(
root
);
std
::
vector
<
std
::
pair
<
CNodePtr
,
CNodePtr
>>
sens_loss_pairs
;
for
(
auto
&
node
:
root
->
nodes
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
continue
;
}
// cnode(sens)-->cnode(tuple_getitem)
auto
sens_cnode
=
node
->
cast
<
CNodePtr
>
();
AnfNodePtr
expect_tuple_getitem
=
sens_cnode
->
input
(
0
);
MS_EXCEPTION_IF_NULL
(
expect_tuple_getitem
);
if
(
!
expect_tuple_getitem
->
isa
<
CNode
>
())
{
continue
;
}
auto
expect_tuple_getitem_cnode
=
expect_tuple_getitem
->
cast
<
CNodePtr
>
();
if
(
!
IsSomePrimitive
(
expect_tuple_getitem_cnode
,
TUPLE_GETITEM
))
{
continue
;
}
// cnode(sens)-->cnode(tuple_getitem)-->cnode
AnfNodePtr
expect_anonymous
=
expect_tuple_getitem_cnode
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
expect_anonymous
);
if
(
!
expect_anonymous
->
isa
<
CNode
>
())
{
continue
;
}
// cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
auto
expect_anonymous_cnode
=
expect_anonymous
->
cast
<
CNodePtr
>
();
AnfNodePtr
expect_j
=
expect_anonymous_cnode
->
input
(
0
);
MS_EXCEPTION_IF_NULL
(
expect_j
);
if
(
!
expect_j
->
isa
<
CNode
>
())
{
continue
;
}
auto
expect_j_cnode
=
expect_j
->
cast
<
CNodePtr
>
();
if
(
!
IsSomePrimitive
(
expect_j_cnode
,
J
))
{
continue
;
}
if
(
!
IsValueNode
<
FuncGraph
>
(
expect_j_cnode
->
input
(
1
)))
{
MS_LOG
(
EXCEPTION
)
<<
"Sens can't find the corresponding graph."
;
}
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
expect_j_cnode
->
input
(
1
));
auto
loss_cnode
=
FindLossCNode
(
func_graph
);
std
::
pair
<
CNodePtr
,
CNodePtr
>
sens_loss_pair
=
std
::
make_pair
(
sens_cnode
,
loss_cnode
);
sens_loss_pairs
.
push_back
(
sens_loss_pair
);
}
return
sens_loss_pairs
;
}
void
ParallelCommunication
(
const
FuncGraphPtr
&
root
,
const
std
::
vector
<
AnfNodePtr
>
&
all_nodes
,
const
FuncGraphManagerPtr
&
manager
)
{
MS_EXCEPTION_IF_NULL
(
root
);
MS_EXCEPTION_IF_NULL
(
manager
);
TensorRedistribution
tensor_redistribution
;
AnfNodePtr
grad_sens_node
=
nullptr
;
std
::
vector
<
CNodePtr
>
loss_cnode
=
FindLossCNodeFromRoot
(
root
);
std
::
vector
<
std
::
pair
<
CNodePtr
,
CNodePtr
>>
sens_loss_pairs
=
GetSensLossPairs
(
root
);
bool
has_backward
=
!
sens_loss_pairs
.
empty
();
// split sens must before inserting the operators.
for
(
auto
&
node
:
all_node
s
)
{
for
(
auto
&
pair
:
sens_loss_pair
s
)
{
// If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
// If the type of sens node is not Tensor, it is unsupported now, do nothing default.
StepSplitSens
(
node
);
StepSplitSens
(
pair
);
}
for
(
auto
&
node
:
all_nodes
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
node
->
isa
<
CNode
>
())
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
!
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
continue
;
}
...
...
@@ -1965,11 +1966,6 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
continue
;
}
bool
is_loss_cnode
=
false
;
auto
iter
=
std
::
find
(
loss_cnode
.
begin
(),
loss_cnode
.
end
(),
cnode
);
if
(
iter
!=
loss_cnode
.
end
())
{
is_loss_cnode
=
true
;
}
// insert forward ops
InsertForwardOps
(
distribute_operator
,
cnode
);
...
...
@@ -1977,7 +1973,9 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
StepRedistribution
(
cnode
,
distribute_operator
,
cnode
,
tensor_redistribution
,
cnode
);
// insert backward ops
BackwardCommunication
(
distribute_operator
,
cnode
,
is_loss_cnode
);
if
(
has_backward
)
{
BackwardCommunication
(
distribute_operator
,
cnode
,
sens_loss_pairs
);
}
// StepReplace
StepReplace
(
distribute_operator
,
cnode
);
...
...
@@ -2099,7 +2097,6 @@ void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
!
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
continue
;
}
...
...
@@ -2117,7 +2114,6 @@ void SetForwardFlag(const AnfNodeSet &all_nodes) {
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
!
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
continue
;
}
...
...
@@ -2146,7 +2142,6 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
root_node_id
=
node
->
UniqueIdThroughCopy
();
if
(
loss_cnode_id
==
root_node_id
)
{
root_forward_nodes
=
DeepLinkedGraphSearch
(
cnode
);
...
...
mindspore/ccsrc/parallel/step_parallel.h
浏览文件 @
36a62576
...
...
@@ -82,7 +82,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
void
InsertMirrorOps
(
const
MirrorOps
&
mirror_ops
,
const
CNodePtr
&
node
);
void
BackwardCommunication
(
const
OperatorInfoPtr
&
distribute_operator
,
const
CNodePtr
&
node
,
bool
is_loss_node
);
void
BackwardCommunication
(
const
OperatorInfoPtr
&
distribute_operator
,
const
CNodePtr
&
node
,
const
std
::
vector
<
std
::
pair
<
CNodePtr
,
CNodePtr
>>
&
sens_loss_pairs
);
// Generate and init parallel operator
OperatorInfoPtr
OperatorInstance
(
const
PrimitivePtr
&
prim
,
const
PrimitiveAttrs
&
attrs
,
...
...
tests/ut/python/parallel/test_forward_graph.py
0 → 100644
浏览文件 @
36a62576
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.nn
import
Cell
from
mindspore.ops
import
operations
as
P
from
mindspore.common.api
import
_executor
class
Net
(
Cell
):
def
__init__
(
self
,
mul_weight
,
strategy1
=
None
,
strategy2
=
None
):
super
().
__init__
()
self
.
mul
=
P
.
Mul
().
set_strategy
(
strategy1
)
self
.
neg
=
P
.
Neg
().
set_strategy
(
strategy2
)
self
.
mul_weight
=
Parameter
(
mul_weight
,
"w1"
)
def
construct
(
self
,
x
,
b
):
out
=
self
.
mul
(
x
,
self
.
mul_weight
)
out
=
self
.
neg
(
out
)
return
out
,
b
_x
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
_w1
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
_b
=
Tensor
(
np
.
ones
([
128
,
64
,
32
]),
dtype
=
ms
.
float32
)
def
compile
(
net
):
_executor
.
compile
(
net
,
_x
,
_b
)
context
.
reset_auto_parallel_context
()
def
test_forward_graph_data_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
16
,
1
,
1
),
(
16
,
1
,
1
))
strategy2
=
((
16
,
1
,
1
),
)
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_forward_graph_model_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
1
,
1
,
16
),
(
1
,
1
,
16
))
strategy2
=
((
1
,
1
,
16
),
)
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_forward_graph_hybrid_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
4
),
(
2
,
2
,
4
))
strategy2
=
((
2
,
2
,
4
),
)
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
def
test_forward_graph_auto_parallel
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
net
=
Net
(
_w1
)
compile
(
net
)
def
test_forward_graph_repeat_calc
():
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
16
,
global_rank
=
0
)
strategy1
=
((
2
,
2
,
4
),
(
2
,
2
,
4
))
strategy2
=
((
1
,
2
,
2
),
)
net
=
Net
(
_w1
,
strategy1
,
strategy2
)
compile
(
net
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录