magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 4834a6b3
Authored Jul 29, 2020 by mindspore-ci-bot; committed via Gitee on Jul 29, 2020.
!3574 Rename AnfNode::user_data related functions to follow naming rule

Merge pull request !3574 from hewei/rename_user_data_func

Parents: e4a7ca7f, 4eb81d79
Showing 10 changed files with 52 additions and 52 deletions.
mindspore/ccsrc/debug/anf_ir_dump.cc                                     +1 −1
mindspore/ccsrc/debug/draw.cc                                            +1 −1
mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc   +2 −2
mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc     +1 −1
mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc        +2 −2
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc                  +12 −12
mindspore/ccsrc/frontend/parallel/step_parallel.cc                       +24 −24
mindspore/core/ir/anf.h                                                  +6 −6
tests/ut/cpp/parallel/step_auto_parallel_test.cc                         +1 −1
tests/ut/cpp/parallel/step_parallel_test.cc                              +2 −2
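The rename is mechanical throughout: GetUserData becomes user_data, SetUserData becomes set_user_data, and HasUserData becomes has_user_data, bringing these accessors in line with the snake_case style the surrounding AnfNode interface already uses (set_abstract, set_scope, and so on). A schematic call-site sketch, assuming `node` is any node exposing the AnfNode user-data API:

// Before the rename:
auto info = node->GetUserData<OperatorInfo>();     // read
node->SetUserData<OperatorInfo>(info);             // write
bool present = node->HasUserData<OperatorInfo>();  // query

// After the rename (behavior unchanged; only the names follow the accessor rule):
auto info2 = node->user_data<OperatorInfo>();
node->set_user_data<OperatorInfo>(info2);
bool present2 = node->has_user_data<OperatorInfo>();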
mindspore/ccsrc/debug/anf_ir_dump.cc

@@ -267,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
     return;
   }
-  auto operator_info = node->GetUserData<parallel::OperatorInfo>();
+  auto operator_info = node->user_data<parallel::OperatorInfo>();
   if (operator_info == nullptr) {
     return;
   }
mindspore/ccsrc/debug/draw.cc

@@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
   if (graph_obj == nullptr || node == nullptr) {
     return;
   }
-  auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>();
+  auto distributed_operation_info = node->user_data<parallel::OperatorInfo>();
   if (distributed_operation_info != nullptr) {
     auto strategyPtr = distributed_operation_info->strategy();
     if (strategyPtr != nullptr) {
mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc

@@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t
     if (node_prim->name() == DEPEND && node_pair.second != 1) {
       continue;
     }
-    if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+    if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
       (void)cnode_set.emplace(cnode);
     } else {
       auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);

@@ -98,7 +98,7 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi
     return cnode_dist;
   }
-  auto operator_info = cnode->GetUserData<OperatorInfo>();
+  auto operator_info = cnode->user_data<OperatorInfo>();
   MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
                 << " operator_info: " << (operator_info != nullptr);
mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc

@@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) {
   }
   auto para_ptr = node_ptr->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(para_ptr);
-  auto layout_ptr = para_ptr->GetUserData<TensorLayout>();
+  auto layout_ptr = para_ptr->user_data<TensorLayout>();
   if (layout_ptr == nullptr) {
     MS_LOG(ERROR) << "layout_ptr is nullptr!";
     return FAILED;
mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc

@@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
   for (auto para : graph_params) {
     std::string name = std::static_pointer_cast<Parameter>(para)->name();
-    auto tensor_layout = para->GetUserData<parallel::TensorLayout>();
+    auto tensor_layout = para->user_data<parallel::TensorLayout>();
     if (tensor_layout == nullptr) {
       MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
     } else {

@@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(cnode);
-      auto distributed_operation_info = cnode->GetUserData<OperatorInfo>();
+      auto distributed_operation_info = cnode->user_data<OperatorInfo>();
       if (distributed_operation_info != nullptr) {
         auto strategyPtr = distributed_operation_info->strategy();
         if (strategyPtr != nullptr) {
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc

@@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
       std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
       entire_costgraph->AddOperator(operator_info);
-      cnode->SetUserData<OperatorInfo>(operator_info);
+      cnode->set_user_data<OperatorInfo>(operator_info);
       MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                    << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                    << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();

@@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
       std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
       entire_costgraph->AddOperator(operator_info);
-      cnode->SetUserData<OperatorInfo>(operator_info);
+      cnode->set_user_data<OperatorInfo>(operator_info);
       MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                    << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                    << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();

@@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
         MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
                           << " does not match the Prim: " << prim->name();
       }
-      cnode->SetUserData<OperatorInfo>(current_op_ptr);
+      cnode->set_user_data<OperatorInfo>(current_op_ptr);
       MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                    << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                    << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();

@@ -549,7 +549,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
     size_t edge_count = 0;
-    auto node_op_info = cnode->GetUserData<OperatorInfo>();
+    auto node_op_info = cnode->user_data<OperatorInfo>();
     for (size_t i = 1; i < inputs.size(); ++i) {
       auto prev_cnode = inputs[i]->cast<CNodePtr>();

@@ -565,7 +565,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
         (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
       while (bool_result) {
         if (IsAutoParallelCareNode(prev_cnode)) {
-          auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>();
+          auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
           std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
           // If the edge between these two operators already has been added, then the edge will not be added again.
           if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {

@@ -751,7 +751,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
       auto target_cnode = target.first->cast<CNodePtr>();
       auto input_index = target.second;
       (void)target_without_duplicate.insert(std::to_string(input_index) +
-                                            target_cnode->GetUserData<OperatorInfo>()->name());
+                                            target_cnode->user_data<OperatorInfo>()->name());
     }
     if (target_without_duplicate.size() <= 1) {
       continue;

@@ -831,7 +831,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
       auto target_cnode = target.first->cast<CNodePtr>();
       auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
       auto input_index = target.second;
-      auto target_op_info = target_cnode->GetUserData<OperatorInfo>();
+      auto target_op_info = target_cnode->user_data<OperatorInfo>();
       std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
       // If the edge between these two operators already has been added, then the edge will not be added again.

@@ -862,7 +862,7 @@ bool FindReshape(const CNodePtr &cnode) {
   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
     return false;
   }
-  if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
+  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
     return false;
   }
   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();

@@ -884,7 +884,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
   if (!IsValueNode<Primitive>(cnode->input(0))) {
     return false;
   }
-  auto node_op_info = cnode->GetUserData<OperatorInfo>();
+  auto node_op_info = cnode->user_data<OperatorInfo>();
   if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
     *pre_operator_info = node_op_info;
     *out_index = 0;

@@ -900,7 +900,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
       MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
     }
     CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
-    auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>();
+    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
     if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
       *pre_operator_info = pre_op_info;
       return true;

@@ -941,7 +941,7 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
     if (node_prim->name() == DEPEND && node_pair.second != 1) {
       continue;
     }
-    auto op_info = use_apply->GetUserData<OperatorInfo>();
+    auto op_info = use_apply->user_data<OperatorInfo>();
     if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
       MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
       *next_operator_info = op_info;

@@ -970,7 +970,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
     int32_t out_index = 0;
     OperatorInfoPtr pre_operator_info;
     std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
-    auto operator_info = cnode->GetUserData<OperatorInfo>();
+    auto operator_info = cnode->user_data<OperatorInfo>();
     if (pre_node->isa<Parameter>()) {
       auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
       reshape_info->SetCostForReshapeWithParameter();
mindspore/ccsrc/frontend/parallel/step_parallel.cc

@@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
   if (!IsParallelCareNode(node)) {
     return nullptr;
   }
-  OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
+  OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
   if (distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
   }

@@ -409,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
   if (prim->name() == GET_NEXT) {
     return true;
   }
-  if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) {
+  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
     return false;
   }

@@ -446,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
     if (node_prim->name() == DEPEND && node_pair.second != 1) {
       continue;
     }
-    if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) {
+    if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
       Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
                      pre_node);
     } else {

@@ -459,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
 void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(next_node);
-  OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>();
+  OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
   MS_EXCEPTION_IF_NULL(op_info);
   // If the shape of tensor is [] or [1], no need to split it.

@@ -584,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
 void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
   // step1:get graph manager distribute_operator
-  OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
+  OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
   if (distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
   }

@@ -622,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
       (void)prim->SetAttrs(attrs);
     }
     if (index == replace_op.size() - 1) {
-      replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>());
+      replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
     }
     replace_node->set_in_forward_flag(true);
     replace_input[0]->set_scope(scope);

@@ -702,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
   auto pre_cnode = pre_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(pre_cnode);
   auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
-  if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
+  if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
     pre_node = pre_cnode->input(1);
   }

@@ -1198,7 +1198,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
     if (node_prim->name() == DEPEND && node_pair.second != 1) {
       continue;
     }
-    if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+    if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
       return node_pair;
     } else if (FindParallelCareNode(node_pair.first).first != nullptr) {
       return FindParallelCareNode(node_pair.first);

@@ -1248,7 +1248,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
   CNodePtr cnode = res.first->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>();
+  OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
   if (distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
   }

@@ -1271,7 +1271,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
-  parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
+  parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
 }

 void CoverSliceShape(const FuncGraphPtr &root) {

@@ -1359,7 +1359,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
     if (found_be_cloned_parameter) {
       // set the shape and tensor layout for cloned parameter
-      cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>());
+      cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
       MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
       MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
       auto cloned_abstract = cloned_parameter_node->abstract()->Clone();

@@ -1454,7 +1454,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
     (*operator_).set_outputs_dtype(cnode->Type());
     (*operator_).set_cnode(cnode);
     if (prim->name() == RESHAPE) {
-      cnode->SetUserData<OperatorInfo>(operator_);
+      cnode->set_user_data<OperatorInfo>(operator_);
       continue;
     }
     // load strategy checkpoint

@@ -1489,7 +1489,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
       if (operator_->Init(strategyPtr) == FAILED) {
         MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
       }
-      cnode->SetUserData<OperatorInfo>(operator_);
+      cnode->set_user_data<OperatorInfo>(operator_);
     } else {
       MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
     }

@@ -1532,13 +1532,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
     if (node_prim->name() == DEPEND && node_pair.second != 1) {
       continue;
     }
-    if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) {
+    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
       MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
       auto layout = GetInputLayoutFromCNode(node_pair);
       return std::make_shared<TensorLayout>(layout);
     }
     MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << "  " << IsParallelCareNode(use_apply)
-                  << "   " << use_apply->HasUserData<OperatorInfo>();
+                  << "   " << use_apply->has_user_data<OperatorInfo>();
     auto layout_ptr = FindNextLayout(use_apply);
     if (layout_ptr) {

@@ -1570,7 +1570,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
   if (!IsValueNode<Primitive>(cnode->input(0))) {
     return nullptr;
   }
-  if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
     auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
     if (!layout_ptr) {
       MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";

@@ -1614,7 +1614,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
   if (!IsValueNode<Primitive>(cnode->input(0))) {
     return nullptr;
   }
-  if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
     auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
     if (!layout_ptr) {
       MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";

@@ -1654,12 +1654,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-    if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
+    if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
      continue;
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
     MS_EXCEPTION_IF_NULL(prim);
-    OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
+    OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
     if (operator_info == nullptr) {
       MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
     }

@@ -1704,7 +1704,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
   auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
   // return -> cast
-  if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
+  if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
     pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(pre_cnode);
     current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));

@@ -1761,7 +1761,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
     return ret;
   }
-  OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>();
+  OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
   MS_EXCEPTION_IF_NULL(operator_info);
   TensorInfo loss_grad_tensor_info;
   size_t op_output_size = operator_info->outputs_tensor_info().size();

@@ -1799,7 +1799,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
     if (sens_tensor_node->isa<Parameter>()) {
       auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
       MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
-      sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
+      sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
     }
     MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
     return;

@@ -1824,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
     cloned_abstract->set_shape(parallel_shape);
     sens_tensor_node->set_abstract(cloned_abstract);
     auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
-    sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
+    sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
     return;
   }
   MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";

@@ -2131,7 +2131,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
     MS_EXCEPTION_IF_NULL(prim);
-    OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
+    OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
     if (operator_info) {
       if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
         continue;
mindspore/core/ir/anf.h

@@ -158,29 +158,29 @@ class AnfNode : public Base {
   size_t seen_{0};

   template <typename T>
-  void SetUserData(const std::string &key, const std::shared_ptr<T> &value) {
+  void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
     user_data_.set<T>(key, value);
   }

   template <typename T>
-  void SetUserData(const std::shared_ptr<T> &value) {
+  void set_user_data(const std::shared_ptr<T> &value) {
     user_data_.set<T>(T::key, value);
   }

   template <typename T>
-  std::shared_ptr<T> GetUserData(const std::string &key) const {
+  std::shared_ptr<T> user_data(const std::string &key) const {
     return user_data_.get<T>(key);
   }

   template <typename T>
-  std::shared_ptr<T> GetUserData() const {
+  std::shared_ptr<T> user_data() const {
     return user_data_.get<T>(T::key);
   }

-  bool HasUserData(const std::string &key) const { return user_data_.has(key); }
+  bool has_user_data(const std::string &key) const { return user_data_.has(key); }

   template <typename T>
-  bool HasUserData() const {
+  bool has_user_data() const {
     return user_data_.has(T::key);
   }
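In anf.h, the keyless overloads resolve the storage key from the payload type itself — user_data_.set<T>(T::key, value) and user_data_.get<T>(T::key) — so any type attached this way is expected to expose a static key member. A self-contained sketch of that pattern; MiniNode and OperatorInfoStub are hypothetical stand-ins, not MindSpore's actual UserData container:

#include <iostream>
#include <map>
#include <memory>
#include <string>

struct OperatorInfoStub {  // hypothetical payload; the real code attaches parallel::OperatorInfo
  static constexpr const char *key = "operator_info";  // type-level key, mirroring T::key in anf.h
  std::string name = "MatMulInfo00";
};

class MiniNode {  // hypothetical stand-in for AnfNode's user-data surface
 public:
  template <typename T>
  void set_user_data(const std::shared_ptr<T> &value) {  // was SetUserData
    data_[T::key] = value;
  }
  template <typename T>
  std::shared_ptr<T> user_data() const {  // was GetUserData
    auto it = data_.find(T::key);
    return it == data_.end() ? nullptr : std::static_pointer_cast<T>(it->second);
  }
  template <typename T>
  bool has_user_data() const {  // was HasUserData
    return data_.count(T::key) > 0;
  }

 private:
  std::map<std::string, std::shared_ptr<void>> data_;  // keyed, type-erased storage
};

int main() {
  MiniNode node;
  node.set_user_data<OperatorInfoStub>(std::make_shared<OperatorInfoStub>());
  if (node.has_user_data<OperatorInfoStub>()) {
    std::cout << node.user_data<OperatorInfoStub>()->name << std::endl;  // prints "MatMulInfo00"
  }
  return 0;
}

The set/get/has trio here deliberately matches the renamed AnfNode methods, so the test snippets below (set_user_data followed by user_data) read the same against either implementation.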
tests/ut/cpp/parallel/step_auto_parallel_test.cc

@@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) {
   StrategyPtr strategyPtr;
   std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
-  node->SetUserData<OperatorInfo>(matmul_info);
+  node->set_user_data<OperatorInfo>(matmul_info);
   std::string name_expect = "MatMulInfo00";
   std::string name_test = matmul_info->name();
   ASSERT_EQ(name_expect, name_test);
tests/ut/cpp/parallel/step_parallel_test.cc

@@ -522,8 +522,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) {
   std::vector<Shapes> shape = {inputs_shape, outputs_shape};
   OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
   matmul_info->Init(strategyPtr);
-  node->SetUserData<OperatorInfo>(matmul_info);
-  OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>();
+  node->set_user_data<OperatorInfo>(matmul_info);
+  OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>();
   TensorLayout tensorlayout_e;
   std::vector<int32_t> array = {64, 64};
   TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);