magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit ef71ae94
Authored Apr 27, 2020 by mindspore-ci-bot; committed via Gitee, Apr 27, 2020.
!698 [Auto parallel] Support multi-subgraphs in auto-parallel
Merge pull request !698 from Xiaoda/support-wide-deep-in-auto-parallel
Parents: e537a708, e2274156

10 changed files with 298 additions and 26 deletions (+298 −26)
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h        +1   −0
mindspore/ccsrc/parallel/costmodel_context.cc                   +2   −0
mindspore/ccsrc/parallel/costmodel_context.h                    +5   −0
mindspore/ccsrc/parallel/step_auto_parallel.cc                  +80  −24
mindspore/ccsrc/parallel/step_auto_parallel.h                   +3   −1
mindspore/ccsrc/pipeline/action.cc                              +5   −1
mindspore/ccsrc/pipeline/init.cc                                +2   −0
mindspore/parallel/_cost_model_context.py                       +29  −0
tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py +101 −0
tests/ut/python/parallel/test_auto_parallel_two_bn.py           +70  −0
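Taken together, these changes thread a single new flag, multi_subgraphs, from the Python cost-model context down to the auto-parallel strategy search. A minimal usage sketch, assembled from the new unit test in this commit (every name below appears in the diff):

    import numpy as np
    import mindspore as ms
    from mindspore import context, Tensor
    from mindspore.parallel import _cost_model_context as cost_model_context

    # Mark the ANF graph as containing multiple subgraphs *before* compiling,
    # so the pipeline keeps the cloned subgraphs and auto-parallel constructs
    # cost-graph nodes by UniqueIdThroughCopy.
    cost_model_context.set_cost_model_context(multi_subgraphs=True)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")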
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
@@ -44,6 +44,7 @@ namespace parallel {
 #define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16
 #define DEFAULT_FULLY_USE_DEVICES true
 #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
+#define DEFAULT_IS_MULTI_SUBGRAPHS false
 class CostGraph;
 using CostGraphPtr = std::shared_ptr<CostGraph>;
mindspore/ccsrc/parallel/costmodel_context.cc
@@ -46,6 +46,7 @@ void CostModelContext::ResetCostModel() {
   costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
   costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
   costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
+  is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
   costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
   costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
   costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
@@ -84,6 +85,7 @@ void CostModelContext::set_costmodel_communi_const(double cm_communi_const) {
 void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; }
+void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; }
 void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) {
   costmodel_allreduce_fusion_algorithm_ = algorithm;
 }
mindspore/ccsrc/parallel/costmodel_context.h
@@ -67,6 +67,9 @@ class CostModelContext {
   void set_costmodel_communi_bias(double);
   double costmodel_communi_bias() const { return costmodel_communi_bias_; }

+  void set_multi_subgraphs(bool);
+  bool is_multi_subgraphs() const { return is_multi_subgraphs_; }
+
   void set_costmodel_allreduce_fusion_algorithm(int32_t);
   int32_t costmodel_allreduce_fusion_algorithm() const { return costmodel_allreduce_fusion_algorithm_; }
@@ -138,6 +141,8 @@ class CostModelContext {
   // COST_MODEL_COMMUNI_BIAS
   double costmodel_communi_bias_;

+  bool is_multi_subgraphs_;
+
   int32_t costmodel_allreduce_fusion_algorithm_;
   int32_t costmodel_allreduce_fusion_times_;
mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -426,13 +426,13 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   return operator_info;
 }
-Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
+// Using CNode's UniqueIds to construct nodes
+Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
   MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
   entire_costgraph = std::make_shared<CostGraph>();
   entire_costgraph->SetDeviceMemoryAndCostParameter();
-  bool new_operator = true, first_operator = true;
-  std::string first_operator_cnode;
-  size_t current_op_index = 0;
+  // The map from CNode's UniqueId to its operatorInfo
+  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
   // Step 1
   for (auto &node : all_nodes) {
@@ -449,12 +449,8 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
     MS_EXCEPTION_IF_NULL(prim);
-    // When visiting the second subgraph, use the corresponding operatorInfo which already created
-    bool modify_new_operator =
-      (new_operator) && (!first_operator) && (cnode->UniqueId() == first_operator_cnode);
-    if (modify_new_operator) {
-      new_operator = false;
-    }
-    if (new_operator) {
+    auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
+    if (search_cnode == from_cnode_to_info.end()) {
       auto operator_info = CreateTheOperatorInfo(prim, cnode);
       if (operator_info == nullptr) {
         return FAILED;
@@ -465,14 +461,67 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
       entire_costgraph->AddOperator(operator_info);
       (void)cnode->set_operator_info(operator_info);
-      if (first_operator) {
-        first_operator_cnode = cnode->UniqueId();
-        first_operator = false;
-      }
+      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
+      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
       // Needed by rec_parser
       entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
+    } else {
+      // Two CNODEs' UniqueIds should not be equal
+      MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
+                        << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                        << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
+    }
+  }
+  MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
+  return SUCCESS;
+}
+
+// Using CNode's UniqueIdThroughCopys to construct nodes
+Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
+  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
+  entire_costgraph = std::make_shared<CostGraph>();
+  entire_costgraph->SetDeviceMemoryAndCostParameter();
+  // The map from CNode's UniqueIdThroughCopy to its operatorInfo
+  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+  for (auto &node : all_nodes) {
+    // NOTE: we only care about splittable Primitive operators
+    auto cnode = node->cast<CNodePtr>();
+    bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
+    if (bool_result) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+    if (!IsAutoParallelCareNode(cnode)) {
+      continue;
+    }
+    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+    // Find the operatorInfo if it exists
+    auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
+    if (search_cnode == from_cnode_to_info.end()) {
+      // In this case, the corresponding OperatorInfo is not created, create the new one.
+      auto operator_info = CreateTheOperatorInfo(prim, cnode);
+      if (operator_info == nullptr) {
+        return FAILED;
+      }
+      // Needed by rec_parser
+      operator_info->set_type(prim->name());
+      std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
+      entire_costgraph->AddOperator(operator_info);
+      (void)cnode->set_operator_info(operator_info);
+      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
+      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
+      // Needed by rec_parser
+      entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
     } else {
-      auto current_op_ptr = entire_costgraph->FindOperatorByIndex(current_op_index);
+      auto current_op_ptr = search_cnode->second;
       if (current_op_ptr == nullptr) {
         MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
       } else {
@@ -484,14 +533,12 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
                           << " does not match the Prim: " << prim->name();
       }
       (void)cnode->set_operator_info(current_op_ptr);
-      current_op_index++;
+      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                   << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
     }
-    }
   }
-  if ((!new_operator) && (current_op_index != entire_costgraph->GetOperators().size())) {
-    MS_LOG(EXCEPTION) << "The second subgraph's operator number: " << current_op_index
-                      << " does not match the first ones: " << entire_costgraph->GetOperators().size();
-  }
   MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
   return SUCCESS;
@@ -844,11 +891,20 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   // OUTPUT: the determined strategy for each operator.
   // Step 1
-  if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) {
-    MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
-                 << entire_costgraph->GetOperators().size() << " operators.";
+  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
+    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
+      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
+                   << entire_costgraph->GetOperators().size() << " operators.";
+    } else {
+      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    }
   } else {
-    MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
+      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
+                   << entire_costgraph->GetOperators().size() << " operators.";
+    } else {
+      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    }
   }
   // Step 2
@@ -916,7 +972,7 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
 }
 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
-  if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) {
+  if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
     MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
                  << entire_costgraph->GetOperators().size() << " operators.";
   } else {
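The heart of the new ConstructCostGraphNodesByUniqueIdTC is that OperatorInfo creation is keyed on UniqueIdThroughCopy, which a CNode shares with its clones, so the second subgraph silently reuses the OperatorInfo built for the first; the old index-based walk (current_op_index plus the post-loop operator-count check) becomes unnecessary. A minimal Python sketch of that dedup pattern, with hypothetical node and factory names purely for illustration:

    class FakeNode:
        """Hypothetical stand-in for a CNode: a node and its clone share id_through_copy."""
        def __init__(self, unique_id, id_through_copy):
            self.unique_id = unique_id
            self.id_through_copy = id_through_copy
            self.operator_info = None

    def construct_nodes_by_unique_id_tc(all_nodes, create_info):
        from_cnode_to_info = {}              # UniqueIdThroughCopy -> operator info
        cost_graph = []
        for node in all_nodes:
            info = from_cnode_to_info.get(node.id_through_copy)
            if info is None:                 # first copy: create and register once
                info = create_info(node)
                cost_graph.append(info)
                from_cnode_to_info[node.id_through_copy] = info
            node.operator_info = info        # every copy points at the shared info
        return cost_graph

    # Two subgraphs cloned from one original: four nodes, two shared infos.
    nodes = [FakeNode("a1", "tc1"), FakeNode("b1", "tc2"),
             FakeNode("a2", "tc1"), FakeNode("b2", "tc2")]
    graph = construct_nodes_by_unique_id_tc(nodes, lambda n: {"op": n.id_through_copy})
    assert len(graph) == 2
    assert nodes[0].operator_info is nodes[2].operator_info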
mindspore/ccsrc/parallel/step_auto_parallel.h
@@ -43,7 +43,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
 std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
-Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
+Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
+Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
 void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);
mindspore/ccsrc/pipeline/action.cc
@@ -24,6 +24,7 @@
 #include <functional>
 #include "ir/func_graph_cloner.h"
+#include "parallel/costmodel_context.h"
 #include "pipeline/pass.h"
 #include "pipeline/parse/parse_base.h"
 #include "pipeline/parse/data_converter.h"
@@ -341,7 +342,10 @@ static std::vector<ActionItem> CommonPipeline() {
   // Resolve the python func
   actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
-  actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
+  auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
+  if (!multi_graphs) {
+    actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
+  }
   actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
   // Evaluate type and shape, and specialize
   actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
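The second hunk is the other half of the feature: combine_like_graphs would fold the structurally identical subgraphs into one before auto-parallel ever runs, so the pass is now registered only when multi_subgraphs is not set. With the flag on, both cloned subgraphs survive to the strategy-search step that the step_auto_parallel.cc changes handle.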
mindspore/ccsrc/pipeline/init.cc
@@ -222,6 +222,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set the parameter cost_model_communi_bias of the DP algorithm.")
     .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias,
          "Get the parameter cost_model_communi_bias of the DP algorithm.")
+    .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
+    .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
     .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
          "Set the parameter gradient AllReduce fusion algorithm.")
     .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
mindspore/parallel/_cost_model_context.py
@@ -214,6 +214,31 @@ class _CostModelContext:
             raise ValueError("Context handle is none in context!!!")
         return self._context_handle.get_costmodel_communi_bias()

+    def set_multi_subgraphs(self, multi_subgraph):
+        """
+        Set the flag of ANF graph containing multiple subgraphs.
+
+        Args:
+            multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
+
+        Raises:
+            ValueError: If context handle is none.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        self._context_handle.set_multi_subgraphs(multi_subgraph)
+
+    def get_multi_subgraphs(self):
+        """
+        Get the flag of ANF graph containing multiple subgraphs.
+
+        Raises:
+            ValueError: If context handle is none.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        return self._context_handle.get_multi_subgraphs()
+
     def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
         """
         Set costmodel allreduce fusion algorithm.
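The two wrappers forward straight to the pybind11 bindings registered in init.cc above. A small round-trip sketch through this module's singleton accessor (the cost_model_context() helper that the func maps further down in this file already use; the `is True` check assumes pybind11's usual bool conversion):

    from mindspore.parallel._cost_model_context import cost_model_context

    ctx = cost_model_context()      # module-level singleton wrapper
    ctx.set_multi_subgraphs(True)   # -> CostModelContext::set_multi_subgraphs
    assert ctx.get_multi_subgraphs() is True  # -> CostModelContext::is_multi_subgraphs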
@@ -427,6 +452,7 @@ set_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
+    "multi_subgraphs": cost_model_context().set_multi_subgraphs,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
     "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
@@ -447,6 +473,7 @@ get_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
+    "multi_subgraphs": cost_model_context().get_multi_subgraphs(),
     "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
     "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
@@ -461,6 +488,7 @@ get_cost_model_context_func_map = {
 @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
                  costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
+                 multi_subgraphs=bool, costmodel_allreduce_fusion_algorithm=int,
                  costmodel_allreduce_fusion_times=int, costmodel_allreduce_fusion_tail_percent=float,
                  costmodel_allreduce_fusion_tail_time=float, costmodel_allreduce_fusion_allreduce_inherent_time=float,
@@ -481,6 +509,7 @@ def set_cost_model_context(**kwargs):
         costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
+        multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
         costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
             0: bypass allreduce fusion;
             1: only use backward computation time to group allreduce;
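One detail worth noting in the get map: unlike every neighboring entry, the "multi_subgraphs" value is written with trailing parentheses, cost_model_context().get_multi_subgraphs(), so the map stores the flag's value captured at module import time rather than the bound getter.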
tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py (new file mode 100644)

import numpy as np
from mindspore import context
import mindspore as ms
import mindspore.nn as nn
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.common.api import _executor
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
from mindspore.parallel._utils import _reset_op_id as reset_op_id


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.mul = P.Mul()
        self.relu = P.ReLU()
        self.wd = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide")
        self.wt = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="l")

    def construct(self, x):
        out = self.mul(x, self.wd)
        out = self.mul(out, self.wt)
        out = self.relu(out)
        return out


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.sum = P.ReduceSum()
        self.mean = P.ReduceMean()
        self.net = network

    def construct(self, x):
        predict = self.net(x)
        loss1 = self.sum(predict, -1)
        loss2 = self.mean(predict, -1)
        return loss1, loss2


class IthOutputCell(nn.Cell):
    def __init__(self, network, output_index):
        super(IthOutputCell, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, x):
        predict = self.network(x)[self.output_index]
        return predict


class TrainStepWarp(nn.Cell):
    def __init__(self, network, sens=1000.0):
        super(TrainStepWarp, self).__init__()
        self.network = network
        self.network.set_train()
        self.trainable_params = network.trainable_params()
        weights_w = []
        weights_d = []
        for params in self.trainable_params:
            weights_w.append(params)
            weights_d.append(params)
        self.weights_w = ParameterTuple(weights_w)
        self.weights_d = ParameterTuple(weights_d)
        self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, l1=1e-8, l2=1e-8, initial_accum=1.0)
        self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
        self.hyper_map = C.HyperMap()
        self.grad_w = C.GradOperation('grad_w', get_by_list=True, sens_param=True)
        self.grad_d = C.GradOperation('grad_d', get_by_list=True, sens_param=True)
        self.sens = sens
        self.loss_net_w = IthOutputCell(network, output_index=0)
        self.loss_net_d = IthOutputCell(network, output_index=1)

    def construct(self, x):
        weights_w = self.weights_w
        weights_d = self.weights_d
        loss_w, loss_d = self.network(x)
        sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
        sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
        grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
        grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
        return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))


def test_double_subgraphs():
    cost_model_context.set_cost_model_context(multi_subgraphs=True)
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = TrainStepWarp(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
    reset_op_id()
    _executor.compile(net, x, phase='train')
    strategies = _executor._get_strategy(net)
    expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]],
                           'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]],
                           'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
                           'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
                           'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
    assert strategies == expected_strategies
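The test mirrors the Wide&Deep training pattern named in the source branch (support-wide-deep-in-auto-parallel): one network produces two losses, each driven by its own optimizer (FTRL for the wide part, Adam for the deep part), so the backward pass contains two cloned subgraphs. With multi_subgraphs=True, every operator in both subgraphs is expected to settle on the same strategy, [[8, 1, 1, 1]], i.e. an 8-way split of the batch dimension across the 8 devices.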
tests/ut/python/parallel/test_auto_parallel_two_bn.py (new file mode 100644)

import numpy as np
from mindspore import context
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.common.api import _executor
from tests.ut.python.ops.test_math_ops import VirtualLoss
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as reset_op_id
import re


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class Blockcell(nn.Cell):
    def __init__(self):
        super(Blockcell, self).__init__()
        self.bn = nn.BatchNorm2d(64, momentum=0.9)

    def construct(self, x):
        out = self.bn(x)
        return out


def getBlock():
    return Blockcell()


def test_two_bn():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.block1 = getBlock()
            self.block2 = getBlock()
            self.relu = P.ReLU()
            self.add = P.TensorAdd()
            self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32)

        def construct(self, x):
            out = self.block1(x)
            out = self.relu(out)
            out = self.add(out, self.bias)
            out = self.block2(out)
            return out

    net = NetWithLoss(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
    _executor.compile(net, x, phase='train')
    strategies = _executor._get_strategy(net)
    assert len(strategies) == 4
    for (k, v) in strategies.items():
        if re.search('BatchNorm-op', k) is not None:
            assert v == [[8, 1], [1], [1], [1], [1]]
        elif re.search('TensorAdd-op', k) is not None:
            assert v == [[8, 1], [8, 1]]
        elif re.search('ReLU-op', k) is not None:
            assert v == [[8, 1]]