Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
414184c1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
414184c1
编写于
8月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5367 Check the parameter's split strategies if it has multiple users
Merge pull request !5367 from yangzhenzhang/check-parameter-split
上级
5c0b0c49
fbda03bb
变更
5
显示空白变更内容
内联
并排
Showing
5 changed files
with
251 additions
and
147 deletions
+251
-147
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+4
-99
mindspore/ccsrc/frontend/parallel/step_parallel.cc
mindspore/ccsrc/frontend/parallel/step_parallel.cc
+146
-0
mindspore/ccsrc/frontend/parallel/step_parallel.h
mindspore/ccsrc/frontend/parallel/step_parallel.h
+7
-0
tests/ut/python/parallel/test_auto_parallel_reshape.py
tests/ut/python/parallel/test_auto_parallel_reshape.py
+0
-48
tests/ut/python/parallel/test_parameter_multi_users.py
tests/ut/python/parallel/test_parameter_multi_users.py
+94
-0
未找到文件。
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
浏览文件 @
414184c1
...
...
@@ -649,108 +649,13 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
MS_LOG
(
INFO
)
<<
"Constructing edges for cost graph ends."
;
}
// Returns the given node paired with all of its RefKey value-node inputs.
// If `cnode` is a CNode and at least one input is a RefKey value node, the
// result is {cnode, refkeys}; otherwise the first element is nullptr (and the
// vector is empty), signalling "not a ref-key CNode" to the caller.
std::pair<AnfNodePtr, std::vector<AnfNodePtr>> CNodeWithRefKeys(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> refkeys;
  if (cnode->isa<CNode>()) {
    auto cnode_ptr = cnode->cast<CNodePtr>();
    auto inputs = cnode_ptr->inputs();
    // Scan every input (including input 0, the primitive) for RefKey value nodes.
    for (auto &one_input : inputs) {
      if (IsValueNode<RefKey>(one_input)) {
        refkeys.push_back(one_input);
      }
    }
    if (refkeys.size() >= 1) {
      return std::make_pair(cnode, refkeys);
    }
  }
  // Not a CNode, or a CNode without RefKey inputs.
  return {nullptr, refkeys};
}
void
AugmentCostGraph
(
const
std
::
vector
<
AnfNodePtr
>
&
all_nodes
)
{
// Step 3
for
(
auto
&
node
:
all_nodes
)
{
auto
cnode_with_refkeys
=
CNodeWithRefKeys
(
node
);
if
((
!
node
->
isa
<
Parameter
>
())
&&
(
cnode_with_refkeys
.
first
==
nullptr
))
{
continue
;
}
std
::
string
parameter_name
;
AnfNodePtr
target_parameter
=
nullptr
;
AnfNodeIndexSet
target_set
;
if
(
cnode_with_refkeys
.
first
!=
nullptr
)
{
// Dealing with the RefKey case
auto
refkeys
=
cnode_with_refkeys
.
second
;
auto
cnode
=
cnode_with_refkeys
.
first
;
auto
cnode_ptr
=
cnode
->
cast
<
CNodePtr
>
();
if
(
cnode_ptr
==
nullptr
||
!
IsValueNode
<
Primitive
>
(
cnode_ptr
->
input
(
0
)))
{
continue
;
}
if
(
!
IsAutoParallelCareNode
(
cnode_ptr
))
{
continue
;
}
if
(
refkeys
.
size
()
>
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"CNode: "
<<
cnode
->
fullname_with_scope
()
<<
" 's inputs have more than 1 RefKeys."
;
}
MS_EXCEPTION_IF_NULL
(
cnode
->
func_graph
());
auto
cnode_func_graph
=
cnode
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
cnode
->
func_graph
()
->
manager
());
// Find the RefKey being used
auto
candidate_set_by_refkey
=
cnode_func_graph
->
manager
()
->
node_users
()[
refkeys
[
0
]];
for
(
auto
&
candidate
:
candidate_set_by_refkey
)
{
auto
candidate_node
=
candidate
.
first
;
auto
c
=
candidate_node
->
cast
<
CNodePtr
>
();
if
(
c
==
nullptr
||
!
IsValueNode
<
Primitive
>
(
c
->
input
(
0
)))
{
continue
;
}
if
(
!
IsAutoParallelCareNode
(
c
))
{
continue
;
}
target_set
.
add
(
candidate
);
}
// Find the corresponding Parameter being used
std
::
vector
<
AnfNodePtr
>
parameters
=
FindParameterByRefKeyNode
(
refkeys
[
0
],
cnode_func_graph
);
if
(
parameters
.
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Find parameter by ref key node failed"
;
}
parameter_name
=
parameters
[
0
]
->
cast
<
ParameterPtr
>
()
->
name
();
target_parameter
=
parameters
[
0
];
auto
candidate_set_by_para
=
cnode_func_graph
->
manager
()
->
node_users
()[
parameters
[
0
]];
for
(
auto
&
candidate
:
candidate_set_by_para
)
{
auto
candidate_node
=
candidate
.
first
;
auto
c
=
candidate_node
->
cast
<
CNodePtr
>
();
if
(
c
==
nullptr
||
!
IsValueNode
<
Primitive
>
(
c
->
input
(
0
)))
{
continue
;
}
if
(
!
IsAutoParallelCareNode
(
c
))
{
continue
;
}
(
void
)
target_set
.
insert
(
candidate
);
}
}
else
if
(
node
->
isa
<
Parameter
>
())
{
// Dealing with the Parameter case
MS_EXCEPTION_IF_NULL
(
node
->
func_graph
());
MS_EXCEPTION_IF_NULL
(
node
->
func_graph
()
->
manager
());
auto
candidate_set
=
node
->
func_graph
()
->
manager
()
->
node_users
()[
node
];
for
(
auto
&
candidate
:
candidate_set
)
{
auto
candidate_node
=
candidate
.
first
;
auto
c
=
candidate_node
->
cast
<
CNodePtr
>
();
if
(
c
==
nullptr
||
!
IsValueNode
<
Primitive
>
(
c
->
input
(
0
)))
{
continue
;
}
if
(
!
IsAutoParallelCareNode
(
c
))
{
continue
;
}
(
void
)
target_set
.
insert
(
candidate
);
}
// In this case, node is a Parameter
parameter_name
=
node
->
cast
<
ParameterPtr
>
()
->
name
();
target_parameter
=
node
;
}
ParameterUsersInfo
parameter_users_info
=
FindParameterUsers
(
node
,
IsAutoParallelCareNode
);
auto
parameter_name
=
parameter_users_info
.
first
;
auto
target_parameter
=
parameter_users_info
.
second
.
first
;
auto
target_set
=
parameter_users_info
.
second
.
second
;
if
(
target_set
.
size
()
<=
1
)
{
continue
;
}
...
...
mindspore/ccsrc/frontend/parallel/step_parallel.cc
浏览文件 @
414184c1
...
...
@@ -2499,6 +2499,149 @@ void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
}
}
// Collect every RefKey value-node input of `cnode`.
// Returns {cnode, refkeys} when at least one RefKey input exists; otherwise
// returns {nullptr, <empty vector>} so callers can test `.first` for the
// "ref-key CNode" case.
RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> ref_key_inputs;
  if (!cnode->isa<CNode>()) {
    return {nullptr, ref_key_inputs};
  }
  // Scan all inputs of the CNode for RefKey value nodes.
  for (auto &input : cnode->cast<CNodePtr>()->inputs()) {
    if (IsValueNode<RefKey>(input)) {
      ref_key_inputs.push_back(input);
    }
  }
  if (ref_key_inputs.empty()) {
    return {nullptr, ref_key_inputs};
  }
  return std::make_pair(cnode, ref_key_inputs);
}
// Gather all "care" CNode users of a Parameter node.
// A user qualifies when it is a CNode whose input 0 is a Primitive value node
// and the supplied IsCareNode predicate accepts it. The result carries the
// parameter's name, the parameter node itself, and the qualifying user set.
ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
  ParameterUsersInfo users_info;
  MS_EXCEPTION_IF_NULL(node->func_graph());
  MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  auto all_users = node->func_graph()->manager()->node_users()[node];
  for (auto &user : all_users) {
    auto user_cnode = user.first->cast<CNodePtr>();
    // Keep only primitive CNodes that the caller's predicate cares about.
    bool qualifies = (user_cnode != nullptr) && IsValueNode<Primitive>(user_cnode->input(0)) && IsCareNode(user_cnode);
    if (qualifies) {
      (void)users_info.second.second.insert(user);
    }
  }
  users_info.first = node->cast<ParameterPtr>()->name();
  users_info.second.first = node;
  return users_info;
}
// Resolves the users of a parameter that is referenced indirectly through a
// RefKey. Given the {cnode, refkeys} pair produced by CNodeWithRefKeys, this
// collects every "care" CNode that uses either the RefKey value node or the
// underlying Parameter, and returns {parameter_name, {parameter, user_set}}.
// Returns a default-constructed (empty) info when the owning CNode itself is
// not a care node.
ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
  // Dealing with the RefKey case
  ParameterUsersInfo parameter_user_info;
  auto refkeys = ref_key_pair.second;
  auto cnode = ref_key_pair.first;
  auto cnode_ptr = cnode->cast<CNodePtr>();
  // The CNode holding the RefKey must itself be a primitive care node.
  if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
    return parameter_user_info;
  }
  // Only a single RefKey input per CNode is supported.
  if (refkeys.size() > 1) {
    MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than 1 RefKeys";
  }
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  auto cnode_func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
  // Find the RefKey being used
  auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  for (auto &candidate : candidate_set_by_refkey) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    // NOTE(review): this loop uses `.add` while the loop below uses `.insert`
    // on the same AnfNodeIndexSet — presumably equivalent; verify.
    parameter_user_info.second.second.add(candidate);
  }
  // Find the corresponding Parameter being used
  std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  // Exactly one Parameter is expected per RefKey.
  if (parameters.size() != 1) {
    MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  }
  parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = parameters[0];
  // Also collect care-node users of the Parameter itself (direct uses).
  auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  for (auto &candidate : candidate_set_by_para) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    (void)parameter_user_info.second.second.insert(candidate);
  }
  return parameter_user_info;
}
// Dispatch user lookup based on what `node` is:
//  - a CNode carrying a RefKey input  -> FindRefKeyNodeUsers
//  - a Parameter node                 -> FindParameterNodeUsers
//  - anything else                    -> default-constructed (empty) info
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
  auto ref_key_pair = CNodeWithRefKeys(node);
  if (ref_key_pair.first != nullptr) {
    // The node is a CNode with a RefKey input.
    return FindRefKeyNodeUsers(ref_key_pair, IsCareNode);
  }
  if (node->isa<Parameter>()) {
    // The node itself is a Parameter.
    return FindParameterNodeUsers(node, IsCareNode);
  }
  // Neither case applies: report no users.
  return ParameterUsersInfo();
}
// Returns the slice shape that one user operator assigns to a parameter.
// `param_info` is a {user CNode, input index} pair taken from the manager's
// node-user set; the index is converted to a position in the operator's
// inputs_tensor_info() by subtracting 1 (input 0 of a CNode is the primitive).
// Raises via MS_LOG(EXCEPTION) when the index is out of range.
// NOTE(review): assumes user_input_index >= 1; an index of 0 would wrap to -1
// before the bounds check passes it through — confirm callers never pass 0.
Shape ParameterSliceShape(const std::pair<AnfNodePtr, int> &param_info) {
  auto user_cnode = param_info.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  auto user_input_index = param_info.second;
  // The OperatorInfo attached to the user CNode holds the per-input tensor layouts.
  OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);
  size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
  if (SizeToInt(input_tensor_info_size) <= user_input_index - 1) {
    MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
                      << ", but the index is " << user_input_index - 1;
  }
  TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
  MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1
                << ", the slice shape is " << ShapeToString(tensor_info.slice_shape()) << ", the origin shape is "
                << ShapeToString(tensor_info.shape());
  return tensor_info.slice_shape();
}
// Verifies that every parameter (or ref-key-referenced parameter) with more
// than one parallel-care user is sharded the same way by all of those users.
// The first user's slice shape is taken as the reference; any mismatching
// user triggers MS_LOG(EXCEPTION), aborting compilation.
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
    auto users_set = parameter_users_info.second.second;
    // A parameter with at most one user cannot have conflicting strategies.
    if (users_set.size() <= 1) {
      continue;
    }
    auto parameter_name = parameter_users_info.first;
    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
    // Use the first user's slice shape as the reference to compare against.
    // NOTE(review): pop() presumably removes that user from the (local copy of
    // the) set so the loop below only visits the remaining users — verify.
    auto first_user = users_set.pop();
    Shape first_user_slice_shape = ParameterSliceShape(first_user);
    for (auto &user : users_set) {
      Shape user_slice_shape = ParameterSliceShape(user);
      if (first_user_slice_shape != user_slice_shape) {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the split strategies are different";
      }
    }
  }
}
bool
StepParallel
(
const
FuncGraphPtr
&
root
,
const
opt
::
OptimizerPtr
&
optimizer
)
{
MS_EXCEPTION_IF_NULL
(
root
);
MS_EXCEPTION_IF_NULL
(
optimizer
);
...
...
@@ -2556,6 +2699,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
HandleForwardMakeTuple
(
all_nodes
);
// if the input or parameter has multiple users, check whether its split strategies are consistent.
CheckParameterSplit
(
all_nodes
);
// save strategy as checkpoint for multi-train
if
(
StrategyCheckpoint
::
GetInstance
().
SaveCheckPointOn
())
{
CheckpointStrategy
(
all_nodes
);
...
...
mindspore/ccsrc/frontend/parallel/step_parallel.h
浏览文件 @
414184c1
...
...
@@ -150,6 +150,13 @@ std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
// Collects the forward sub-graphs of the model rooted at `root`.
std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
// Checks whether `anf_node` is a value node holding the primitive `prim_name`.
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
// A CNode together with its RefKey value-node inputs ({nullptr, {}} if none).
using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
// {parameter name, {parameter node, set of (user CNode, input index) pairs}}.
using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>;
// Returns `cnode` paired with its RefKey inputs, or {nullptr, {}} when it has none.
RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode);
// Finds all users of a parameter (direct or via RefKey) accepted by IsCareNode.
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
}
// namespace parallel
}
// namespace mindspore
...
...
tests/ut/python/parallel/test_auto_parallel_reshape.py
浏览文件 @
414184c1
...
...
@@ -245,51 +245,3 @@ def test_reshape_auto_5():
context
.
set_auto_parallel_context
(
parallel_mode
=
"auto_parallel"
)
net
.
set_auto_parallel
()
_executor
.
compile
(
net
,
x
,
y
)
def test_reshape_auto_6():
    """Auto-parallel compile of a net whose parameter `wide_w` feeds both an
    elementwise add and a Reshape, exercising multi-user parameter handling."""
    class NetWithLoss6(nn.Cell):
        # Wraps `network` with a virtual loss so the graph has a scalar output.
        def __init__(self, network):
            super(NetWithLoss6, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap6(nn.Cell):
        # Differentiates the wrapped network w.r.t. all inputs.
        def __init__(self, network):
            super(GradWrap6, self).__init__()
            self.network = network

        def construct(self, x, y):
            return grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            # Shared parameter: used both as-is (add) and reshaped (sub/mul).
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            out1 = x + self.wide_w
            # Flatten the trailing singleton dim: (4, 1024, 1) -> (4, 1024).
            w = self.reshape(self.wide_w, (4, 1024))
            out1 = self.reduce_mean(out1, 1)
            out1 = out1 - w
            out2 = self.mul(y, w)
            out = out1 + out2
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
    net = GradWrap6(NetWithLoss6(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x, y)
tests/ut/python/parallel/test_parameter_multi_users.py
0 → 100644
浏览文件 @
414184c1
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
pytest
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.common.api
import
_executor
from
mindspore.nn
import
Cell
,
TrainOneStepCell
,
Momentum
from
mindspore.ops
import
operations
as
P
class Net(Cell):
    """Two chained Mul ops that both consume the parameter `w1`, so their
    sharding strategies for the weight must agree."""

    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().set_strategy(strategy1)
        self.mul2 = P.Mul().set_strategy(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        # w1 is used by both ops: x * w1, then (x * w1) * w1.
        first = self.mul(x, self.mul_weight)
        return self.mul2(first, self.mul_weight)
class Net2(Cell):
    """Two Mul ops that both consume the network input `x` directly, so the
    input's sharding strategy must agree between the two ops."""

    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().set_strategy(strategy1)
        self.mul2 = P.Mul().set_strategy(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        # x feeds both ops: x * w1, then x * (x * w1).
        out = self.mul(x, self.mul_weight)
        out = self.mul2(x, out)
        return out
# Shared test fixtures: input, weight, and bias tensors of identical shape.
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
def compile_net(net):
    """Wrap `net` in a Momentum-driven TrainOneStepCell, compile it with the
    module fixtures, and reset the auto-parallel context afterwards."""
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    wrapped = TrainOneStepCell(net, opt)
    wrapped.set_auto_parallel()
    _executor.compile(wrapped, _x, _b)
    context.reset_auto_parallel_context()
def test_parameter_same_split():
    """Both Mul ops shard the shared parameter identically: compile succeeds."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    shard = ((16, 1, 1), (16, 1, 1))
    compile_net(Net(_w, shard, shard))
def test_parameter_different_split():
    """Conflicting shardings of the shared parameter: compilation must raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    net = Net(_w, ((16, 1, 1), (16, 1, 1)), ((4, 4, 1), (4, 4, 1)))
    with pytest.raises(RuntimeError):
        compile_net(net)
def test_input_same_split():
    """Same split strategy on both ops: compilation should succeed."""
    # NOTE(review): this test instantiates Net (chained-weight case), not Net2
    # (shared-input case) as its name suggests — verify which was intended.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((16, 1, 1), (16, 1, 1))
    net = Net(_w, strategy1, strategy2)
    compile_net(net)
def test_input_different_split():
    """Different split strategies for the shared input `x`: must raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((4, 4, 1), (4, 4, 1))
    net = Net2(_w, strategy1, strategy2)
    # The parameter-split check surfaces as a RuntimeError at compile time.
    with pytest.raises(RuntimeError):
        compile_net(net)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录