magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit f0bf438a
Authored May 05, 2020 by yao_yf

reshape strategy search

Parent 08d86c48
Showing 9 changed files with 520 additions and 61 deletions (+520 -61).
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc   +59  -3
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h    +5   -5
mindspore/ccsrc/parallel/ops_info/operator_info.h           +1   -0
mindspore/ccsrc/parallel/ops_info/reshape_info.cc           +40  -32
mindspore/ccsrc/parallel/ops_info/reshape_info.h            +15  -0
mindspore/ccsrc/parallel/step_auto_parallel.cc              +209 -2
mindspore/ccsrc/parallel/step_auto_parallel.h               +2   -0
tests/ut/cpp/parallel/ops_info/reshape_test.cc              +0   -17
tests/ut/python/parallel/test_auto_parallel_reshape.py      +189 -2
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc

@@ -13,9 +13,6 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "parallel/auto_parallel/graph_costmodel.h"
#include <algorithm>
#include <cstdlib>
#include <iterator>

@@ -24,6 +21,10 @@
#include <utility>
#include <vector>
#include "parallel/auto_parallel/graph_costmodel.h"
#include "parallel/ops_info/reshape_info.h"
#include "parallel/step_auto_parallel.h"

namespace mindspore {
namespace parallel {
CostGraphPtr entire_costgraph = nullptr;

@@ -40,6 +41,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
constexpr char RESHAPEINFO[] = "ReshapeInfo";

void CostGraph::SetDeviceMemoryAndCostParameter() {
  MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());

@@ -182,6 +184,20 @@ bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
  return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
}

void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
  std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
  curr_edges.push_back(edge);
  edges_[{u_node, v_node}] = curr_edges;

  std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
  curr_out_edges.push_back(edge);
  out_edges_[u_node] = curr_out_edges;

  std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
  curr_in_edges.push_back(edge);
  in_edges_[v_node] = curr_in_edges;
}

bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
  for (auto &edge_pair : edges_) {
    auto edges = edge_pair.second;

@@ -1338,11 +1354,51 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
Status CostGraph::InitSelectedStrategy() {
  for (auto &op : ops_) {
    MS_EXCEPTION_IF_NULL(op);
    if (op->name().find(RESHAPEINFO) != std::string::npos) {
      continue;
    }
    auto result = op->InitSelectedStrategy(op->selected_strategy());
    if (result != SUCCESS) {
      return result;
    }
  }
  // reshape init should be applied after the init of its previous node and next node.
  for (size_t i = 0; i < ops_.size(); ++i) {
    if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
      auto in_edges = GetOriginalPrevEdges(ops_[i]);
      auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) {
        return edge->prev_operator()->name() == reshape_info->pre_operator_name();
      });
      auto out_edges = GetOriginalNextEdges(ops_[i]);
      auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
        return edge->next_operator()->name() == reshape_info->next_operator_name();
      });
      if (pre_iter != in_edges.end()) {
        MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
        int32_t pre_index = reshape_info->pre_operator_index();
        Dimensions stra;
        TensorInfo pre_info;
        if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
          pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
        } else {
          pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
        }
        reshape_info->SetInputLayout(pre_info.tensor_layout());
        InferStraByTensorInfo(pre_info, &stra);
        std::vector<Dimensions> stra_inputs = {stra};
        StrategyPtr reshape_stra =
          std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
        reshape_info->set_strategy(reshape_stra);
      }
      if (next_iter != out_edges.end()) {
        MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
        int32_t next_index = reshape_info->next_operator_index();
        reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
      }
      return reshape_info->Init(nullptr);
    }
  }
  return SUCCESS;
}
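The InitSelectedStrategy change above defers every Reshape operator to a second pass, so the layouts of its already-initialized previous and next operators are available when the Reshape is finally initialized. Below is a minimal, self-contained sketch of that two-pass ordering only; the Op struct, IsReshape helper, and operator names are illustrative assumptions, not MindSpore APIs.

#include <iostream>
#include <string>
#include <vector>

struct Op {
  std::string name;
  bool initialized = false;
};

bool IsReshape(const Op &op) { return op.name.find("ReshapeInfo") != std::string::npos; }

void InitAll(std::vector<Op> *ops) {
  // pass 1: initialize every operator except Reshape
  for (auto &op : *ops) {
    if (IsReshape(op)) {
      continue;
    }
    op.initialized = true;
  }
  // pass 2: initialize Reshape once its neighbours have been handled
  for (auto &op : *ops) {
    if (IsReshape(op)) {
      op.initialized = true;
    }
  }
}

int main() {
  std::vector<Op> ops = {{"MatMulInfo"}, {"ReshapeInfo"}, {"ReLUInfo"}};
  InitAll(&ops);
  for (const auto &op : ops) {
    std::cout << op.name << (op.initialized ? " initialized" : " pending") << std::endl;
  }
  return 0;
}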
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h

@@ -87,11 +87,9 @@ class CostGraph {
  void RemoveOperator(const OperatorInfoPtr &op);
  bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
  // the edge is in the form: u --> v
  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
    std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
    curr_edges.push_back(edge);
    edges_[{u_node, v_node}] = curr_edges;
  }
  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
  std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }
  std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; }
  // An edge is uniquely identified by its name, and its output index and input index.
  bool IsEdgeInCostGraph(const std::string &, size_t, size_t);

@@ -219,6 +217,8 @@ class CostGraph {
  std::vector<OperatorInfoPtr> ops_;
  std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
  std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
  std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
  std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
};
}  // namespace parallel
}  // namespace mindspore
mindspore/ccsrc/parallel/ops_info/operator_info.h

@@ -111,6 +111,7 @@ class OperatorInfo {
  Shape dev_matrix_shape() const { return dev_matrix_shape_; }
  std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
  std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
  std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }
  RankList global_device_list() const { return global_device_list_; }
mindspore/ccsrc/parallel/ops_info/reshape_info.cc

@@ -22,6 +22,7 @@
#include "parallel/device_manager.h"
#include "parallel/device_matrix.h"
#include "parallel/step_parallel.h"
#include "parallel/auto_parallel/graph_costmodel.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"

@@ -46,26 +47,6 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
    }
    return FAILED;
  }
  std::vector<Dimensions> stra = strategy->GetInputDim();
  for (size_t i = 0; i < strategy_size; ++i) {
    Shape sub_strategy = stra.at(i);
    size_t strategy_len = sub_strategy.size();
    bool flag = false;
    for (size_t j = 0; j < strategy_len; ++j) {
      int32_t strategy_value = sub_strategy.at(j);
      if (strategy_value > 1) {
        if (flag) {
          if (is_auto_parallel_) {
            MS_LOG(DEBUG) << name_ << ": Only support batch parallel strategy.";
          } else {
            MS_LOG(ERROR) << name_ << ": Only support batch parallel strategy.";
          }
          return FAILED;
        }
        flag = true;
      }
    }
  }
  return SUCCESS;
}

@@ -402,6 +383,41 @@ Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr
  return SUCCESS;
}

void ReshapeInfo::SetCostForReshapeWithParameter() {
  size_t success = 0;
  for (auto &sp : sp_vector_) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
      PrintStrategy(sp);
    }
  }
}

void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  int32_t stage_id = strategy->GetInputStage();
  double computation_cost =
    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ +
    COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
  // Breaking ties for preferring data parallelization
  BreakingTiesForPerferringDataParallel(strategy, result);
  // refine communication cost calculation for practice
  RefineForPracticalCost(result, false);
  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  strategy_cost_.emplace_back(swc);
}

Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GetAttrs failed.";

@@ -414,22 +430,14 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
  }
  is_auto_parallel_ = true;
  Shape input0_split;
  input0_split.emplace_back(1);
  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 1, 0);
  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
  // this strategy is used only when the input node is a parameter;
  // in other cases, the input node's output_layout is used as the input_layout.
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
    return FAILED;
  }
  size_t success = 0;
  for (auto &sp : sp_vector) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
      PrintStrategy(sp);
    }
  }
  return SUCCESS;
}
}  // namespace parallel
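SetCostForReshape above blends the communication costs the same way the other operators' cost models do: the partial-parameter communication cost is the parameter-free cost plus COST_MODEL_GAMMA times the remainder. A minimal, self-contained sketch of just that arithmetic follows; the gamma value 0.1 is an assumed placeholder, not the cost model's actual setting.

#include <iostream>

constexpr double kCostModelGamma = 0.1;  // assumed placeholder; configurable in the real cost model

double CommWithPartialParameter(double communication_cost, double communication_without_parameter) {
  // same formula as SetCostForReshape above:
  // without_parameter + gamma * (total - without_parameter)
  return communication_without_parameter +
         kCostModelGamma * (communication_cost - communication_without_parameter);
}

int main() {
  // e.g. total communication cost 80.0, of which 50.0 is independent of parameters -> prints 53
  std::cout << CommWithPartialParameter(80.0, 50.0) << std::endl;
  return 0;
}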
mindspore/ccsrc/parallel/ops_info/reshape_info.h

@@ -50,9 +50,19 @@ class ReshapeInfo : public OperatorInfo {
    output_layout_ = output_layout;
    output_layout_set_flag_ = true;
  }
  void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy);
  void SetCostForReshapeWithParameter();
  void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; }
  void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
  void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; }
  void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; }
  Status InitForCostModel(const StrategyPtr &strategy) override;
  Status GenerateStrategies(int32_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  std::string pre_operator_name() const { return pre_operator_name_; }
  std::string next_operator_name() const { return next_operator_name_; }
  int32_t pre_operator_index() const { return pre_operator_index_; }
  int32_t next_operator_index() const { return next_operator_index_; }

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;

@@ -73,12 +83,17 @@ class ReshapeInfo : public OperatorInfo {
  Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout);
  int32_t dev_num_;
  int32_t pre_operator_index_;
  int32_t next_operator_index_;
  std::vector<int32_t> parameter_input_v_;
  std::vector<StrategyPtr> sp_vector_;
  Dimensions input_strategy_;
  TensorLayout input_layout_;
  TensorLayout output_layout_;
  bool input_layout_set_flag_;
  bool output_layout_set_flag_;
  std::string pre_operator_name_;
  std::string next_operator_name_;
};
}  // namespace parallel
}  // namespace mindspore
mindspore/ccsrc/parallel/step_auto_parallel.cc

@@ -39,6 +39,7 @@
#include "parallel/auto_parallel/rec_core/rec_partition.h"
#include "parallel/context.h"
#include "parallel/ops_info/tmp_identity_info.h"
#include "parallel/ops_info/reshape_info.h"
#include "parallel/step_parallel.h"
#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "pipeline/parse/python_adapter.h"

@@ -608,7 +609,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
      EdgePtr edge_ptr;
      MS_LOG(INFO) << "Creating edge: " << edge_name;
      bool follow_strategy = ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name());
      bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
                             (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
      if (follow_strategy) {
        // Redistribution is not allowed on the edge.
        // Elementwise operators have the same strategy as their previous operators.

@@ -893,6 +895,209 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  }
}

bool FindReshape(const CNodePtr &cnode) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
    return false;
  }
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  OperatorInfoPtr operator_info = cnode->operator_info();
  if (operator_info == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
  }
  if (prim->name() != RESHAPE) {
    return false;
  }
  return true;
}

// find the previous node, then obtain its strategy_cost_ vector to get its layout vector.
bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
  // if the previous node is a parameter, handle it outside.
  if (node->isa<Parameter>()) {
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
    *pre_operator_info = cnode->operator_info();
    *out_index = 0;
    return true;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == TUPLE_GETITEM) {
    *out_index = GetTupleGetItemIndex(cnode);
    // find tuple_get_item's previous node
    auto pre_node = cnode->input(1);
    if (!pre_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
    if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
      *pre_operator_info = pre_cnode->operator_info();
      return true;
    }
    return false;
  }
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
      continue;
    }
    return true;
  }
  MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
  return false;
}

// find the next node, then obtain its strategy_cost_ vector to get its layout vector.
// if reshape's output connects to several primitives, return the first layout found.
bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
      *next_operator_info = use_apply->operator_info();
      *in_index = node_pair.second - 1;
      return true;
    }
    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " "
                  << IsParallelCareNode(use_apply) << " " << (use_apply->operator_info() != nullptr);
    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
      return true;
    }
  }
  return false;
}

void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) {
  Shape shape = pre_out_tensor_info.shape();
  Shape slice_shape = pre_out_tensor_info.slice_shape();
  for (size_t i = 0; i < shape.size(); ++i) {
    if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) {
      MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator";
    }
    int32_t dim = (int32_t)(shape[i] / slice_shape[i]);
    (*stra).push_back(dim);
  }
}

void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!FindReshape(cnode)) {
      continue;
    }
    MS_ASSERT(cnode->inputs().size() == 3);
    // get the previous node's strategy_cost_
    auto pre_node = cnode->input(1);
    int32_t out_index = 0;
    OperatorInfoPtr pre_operator_info;
    std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
    if (pre_node->isa<Parameter>()) {
      OperatorInfoPtr operator_info = cnode->operator_info();
      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
      reshape_info->SetCostForReshapeWithParameter();
      pre_operator_info = reshape_info;
      pre_stra_costs = reshape_info->strategy_cost();
    } else {
      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
      }
      pre_stra_costs = pre_operator_info->strategy_cost();
    }
    // get the next node's strategy_cost_
    int32_t in_index = 0;
    OperatorInfoPtr next_operator_info;
    std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
    if (!find_next_node) {
      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
    }
    // set input_layout and output_layout for reshape.
    // init reshape and set the cost for each input_layout and output_layout pair.
    OperatorInfoPtr operator_info = cnode->operator_info();
    auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
    reshape_info->set_pre_operator_name(pre_operator_info->name());
    reshape_info->set_pre_operator_index(out_index);
    if (find_next_node) {
      next_stra_costs = next_operator_info->strategy_cost();
      reshape_info->set_next_operator_name(next_operator_info->name());
      reshape_info->set_next_operator_index(in_index);
    }
    for (auto pre_stra_cost : pre_stra_costs) {
      std::vector<TensorInfo> pre_out_tensor_infos;
      if (pre_node->isa<Parameter>()) {
        pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
      } else {
        pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
      }
      if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
        MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
      }
      TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
      TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
      reshape_info->SetInputLayout(pre_out_tensor_layout);
      // infer pre_node's output strategy from its output_layout.
      Dimensions stra;
      InferStraByTensorInfo(pre_out_tensor_info, &stra);
      std::vector<Dimensions> stra_inputs = {stra};
      StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
      if (next_stra_costs.empty()) {
        if (reshape_info->Init(nullptr) == FAILED) {
          MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
        }
        // set the cost for each input_layout and output_layout pair.
        reshape_info->SetCostForReshape(reshape_stra);
        continue;
      }
      for (auto next_stra_cost : next_stra_costs) {
        std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
        if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
          MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
        }
        TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
        TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
        reshape_info->SetOutputLayout(next_in_tensor_layout);
        if (reshape_info->Init(nullptr) == FAILED) {
          MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
        }
        // set the cost for each input_layout and output_layout pair.
        reshape_info->SetCostForReshape(reshape_stra);
      }
    }
  }
}

Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
  // Step 1: Traverse the ANF graph, and create NODEs for costgraph:

@@ -930,7 +1135,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  }
  // the reshape operator needs the next node's input_layout as its output_layout,
  // and the previous node's output_layout as its input_layout.
  ReshapeCostCompute(all_nodes);
  // Step 2
  ConstructCostGraphEdges(all_nodes);
  MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
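InferStraByTensorInfo above derives a parallel strategy from the previous operator's output tensor: each strategy dimension is the tensor's full dimension divided by its per-device slice dimension. A minimal, self-contained sketch of the same arithmetic, with plain std::vector standing in for the MindSpore Shape and Dimensions types:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Each strategy dimension is the full dimension divided by the per-device slice dimension.
std::vector<int32_t> InferStrategy(const std::vector<int64_t> &shape,
                                   const std::vector<int64_t> &slice_shape) {
  std::vector<int32_t> strategy;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (slice_shape[i] == 0 || shape[i] % slice_shape[i] != 0) {
      throw std::runtime_error("slice_shape is wrong in reshape operator");
    }
    strategy.push_back(static_cast<int32_t>(shape[i] / slice_shape[i]));
  }
  return strategy;
}

int main() {
  // A (64, 28) output split across 8 devices along the first axis has slice shape (8, 28),
  // so the derived strategy is (8, 1).
  for (int32_t dim : InferStrategy({64, 28}, {8, 28})) {
    std::cout << dim << " ";
  }
  std::cout << std::endl;
  return 0;
}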
mindspore/ccsrc/parallel/step_auto_parallel.h

@@ -51,6 +51,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);
void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes);
void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra);
Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
tests/ut/cpp/parallel/ops_info/reshape_test.cc

@@ -219,22 +219,5 @@ TEST_F(TestReshapeInfo, CheckStrategy3) {
  Status ret = reshape->Init(strategy);
  ASSERT_EQ(ret, SUCCESS);
}

TEST_F(TestReshapeInfo, AutoStrategy1) {
  ASSERT_EQ(reshape->GenerateStrategies(0), Status::SUCCESS);
  std::vector<std::shared_ptr<StrategyWithCost>> sc = reshape->GetStrategyCost();

  Shapes splittable_inputs = {{1, 0, 0, 0}};
  std::vector<StrategyPtr> sp_vector;
  Shapes inputs_shape = {{32, 512, 7, 7}};
  GenerateStrategiesForIndependentInputs(0, inputs_shape, splittable_inputs, &sp_vector);
  ASSERT_EQ(sc.size(), sp_vector.size());
  for (auto stra : sp_vector) {
    auto stra0 = stra->GetInputDim()[0];
    ASSERT_EQ(stra0[1], 1);
    ASSERT_EQ(stra0[2], 1);
    ASSERT_EQ(stra0[3], 1);
  }
}
}  // namespace parallel
}  // namespace mindspore
tests/ut/python/parallel/test_auto_parallel_reshape.py

@@ -65,6 +65,193 @@ def test_reshape_matmul():
    net.set_auto_parallel()
    _executor.compile(net, x)


def test_reshape_auto_1():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_2():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            out = self.reshape(out, (128, 32))
            out = out + self.add_weight
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_3():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.matmul(out, self.matmul_weight)
            out = self.reshape(out, (8, 8, 8, 8))
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_4():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28 * 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            w = self.reshape(self.matmul_weight, (28, 64))
            out = self.matmul(out, w)
            return out

if __name__ == '__main__':
    test_reshape_matmul()
\ No newline at end of file

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_5():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            return C.grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_sum = P.ReduceSum()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024 * 8, 64]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            mask = self.reshape(y, (4, 1024 * 8, 1))
            w_id = self.relu(x)
            wx = self.mul(w_id, mask)
            wide_out = self.reshape(self.reduce_sum(wx, 1), (-1, 1))
            deep_id = x + self.wide_w
            vx = self.mul(deep_id, mask)
            deep_in = self.reshape(vx, (-1, 1024 * 8 * 64))
            out = wide_out + deep_in
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x, y)


def test_reshape_auto_6():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            return C.grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            out1 = x + self.wide_w
            w = self.reshape(self.wide_w, (4, 1024))
            out1 = self.reduce_mean(out1, 1)
            out1 = out1 - w
            out2 = self.mul(y, w)
            out = out1 + out2
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x, y)