Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
20afadb4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
20afadb4
编写于
6月 03, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 03, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1822 fix codex split big functions
Merge pull request !1822 from fary86/codex_big_functions
上级
a193d097
abfaf159
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
234 addition
and
193 deletion
+234
-193
mindspore/ccsrc/debug/trace.cc
mindspore/ccsrc/debug/trace.cc
+49
-42
mindspore/ccsrc/operator/composite/do_signature.cc
mindspore/ccsrc/operator/composite/do_signature.cc
+61
-66
mindspore/ccsrc/pipeline/parse/data_converter.cc
mindspore/ccsrc/pipeline/parse/data_converter.cc
+30
-25
mindspore/ccsrc/pipeline/pipeline.cc
mindspore/ccsrc/pipeline/pipeline.cc
+1
-1
mindspore/ccsrc/pipeline/pipeline.h
mindspore/ccsrc/pipeline/pipeline.h
+1
-1
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
+84
-58
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
+8
-0
未找到文件。
mindspore/ccsrc/debug/trace.cc
浏览文件 @
20afadb4
...
...
@@ -124,6 +124,8 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
void
ExportOneFuncGraph
(
std
::
ofstream
&
ofs
,
const
FuncGraphPtr
&
func_graph
);
void
OutputCNodes
(
std
::
ofstream
&
ofs
,
const
std
::
vector
<
AnfNodePtr
>
&
nodes
,
const
FuncGraphPtr
&
func_graph
);
void
OutputCNode
(
std
::
ofstream
&
ofs
,
const
CNodePtr
&
cnode
,
const
FuncGraphPtr
&
func_graph
,
int
*
idx
,
std
::
map
<
AnfNodePtr
,
int
>
*
const
apply_map
);
private:
std
::
string
GetNodeType
(
const
AnfNodePtr
&
nd
)
override
;
...
...
@@ -169,7 +171,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
}
auto
abs
=
ret
->
abstract
();
if
(
abs
==
nullptr
)
{
return
nullptr
;
return
"Undefined"
;
}
auto
dtype
=
abs
->
BuildType
();
auto
shape
=
abs
->
BuildShape
();
...
...
@@ -247,6 +249,51 @@ AnalysisContextPtr AnalyzedFuncGraphExporter::ProcessFuncGraphCall(const CNodePt
return
ctx
;
}
void
AnalyzedFuncGraphExporter
::
OutputCNode
(
std
::
ofstream
&
ofs
,
const
CNodePtr
&
cnode
,
const
FuncGraphPtr
&
func_graph
,
int
*
idx
,
std
::
map
<
AnfNodePtr
,
int
>
*
const
apply_map
)
{
auto
&
inputs
=
cnode
->
inputs
();
std
::
string
op_text
=
GetAnfNodeText
(
func_graph
,
inputs
[
0
],
*
apply_map
);
// non-return node
if
(
cnode
!=
func_graph
->
get_return
())
{
int
apply_idx
=
(
*
idx
)
++
;
(
*
apply_map
)[
cnode
]
=
apply_idx
;
std
::
string
type_info
=
GetNodeType
(
cnode
);
if
(
type_info
==
"Undefined"
)
{
ofs
<<
" %"
<<
apply_idx
<<
" = "
<<
op_text
<<
"("
;
}
else
{
ofs
<<
" %"
<<
apply_idx
<<
" : "
<<
type_info
<<
" = "
<<
op_text
<<
"("
;
}
}
else
{
ofs
<<
" "
<<
op_text
<<
"("
;
}
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
++
i
)
{
if
(
i
!=
1
)
{
ofs
<<
", "
;
}
AnfNodePtr
arg
=
inputs
[
i
];
ofs
<<
GetAnfNodeText
(
func_graph
,
arg
,
*
apply_map
);
}
ofs
<<
")"
;
// process function graph call
auto
ctx
=
ProcessFuncGraphCall
(
cnode
);
// output comment
OutputStatementComment
(
ofs
,
cnode
);
if
(
ctx
!=
nullptr
)
{
ofs
<<
" @ctx.addr="
<<
ctx
.
get
();
}
ofs
<<
"
\n
"
;
if
(
label_manage
::
GetGlobalTraceLabelType
()
==
label_manage
::
TraceLabelType
::
kWithUniqueId
)
{
ofs
<<
trace
::
GetDebugInfo
(
cnode
->
debug_info
(),
" # "
,
kSourceLineTipDiscard
)
<<
"#"
<<
label_manage
::
Label
(
cnode
->
debug_info
())
<<
"
\n
"
;
}
else
{
ofs
<<
trace
::
GetDebugInfo
(
cnode
->
debug_info
(),
" # "
,
kSourceLineTipDiscard
)
<<
"
\n
"
;
}
}
void
AnalyzedFuncGraphExporter
::
OutputCNodes
(
std
::
ofstream
&
ofs
,
const
std
::
vector
<
AnfNodePtr
>
&
nodes
,
const
FuncGraphPtr
&
func_graph
)
{
if
(
func_graph
==
nullptr
)
{
...
...
@@ -267,47 +314,7 @@ void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vect
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
&
inputs
=
cnode
->
inputs
();
std
::
string
op_text
=
GetAnfNodeText
(
func_graph
,
inputs
[
0
],
apply_map
);
// non-return node
if
(
node
!=
func_graph
->
get_return
())
{
int
apply_idx
=
idx
++
;
apply_map
[
node
]
=
apply_idx
;
std
::
string
type_info
=
GetNodeType
(
node
);
if
(
type_info
==
"Undefined"
)
{
ofs
<<
" %"
<<
apply_idx
<<
" = "
<<
op_text
<<
"("
;
}
else
{
ofs
<<
" %"
<<
apply_idx
<<
" : "
<<
type_info
<<
" = "
<<
op_text
<<
"("
;
}
}
else
{
ofs
<<
" "
<<
op_text
<<
"("
;
}
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
++
i
)
{
if
(
i
!=
1
)
{
ofs
<<
", "
;
}
AnfNodePtr
arg
=
inputs
[
i
];
ofs
<<
GetAnfNodeText
(
func_graph
,
arg
,
apply_map
);
}
ofs
<<
")"
;
// process function graph call
auto
ctx
=
ProcessFuncGraphCall
(
cnode
);
// output comment
OutputStatementComment
(
ofs
,
cnode
);
if
(
ctx
!=
nullptr
)
{
ofs
<<
" @ctx.addr="
<<
ctx
.
get
();
}
ofs
<<
"
\n
"
;
if
(
label_manage
::
GetGlobalTraceLabelType
()
==
label_manage
::
TraceLabelType
::
kWithUniqueId
)
{
ofs
<<
trace
::
GetDebugInfo
(
cnode
->
debug_info
(),
" # "
,
kSourceLineTipDiscard
)
<<
"#"
<<
label_manage
::
Label
(
cnode
->
debug_info
())
<<
"
\n
"
;
}
else
{
ofs
<<
trace
::
GetDebugInfo
(
cnode
->
debug_info
(),
" # "
,
kSourceLineTipDiscard
)
<<
"
\n
"
;
}
OutputCNode
(
ofs
,
cnode
,
func_graph
,
&
idx
,
&
apply_map
);
}
}
...
...
mindspore/ccsrc/operator/composite/do_signature.cc
浏览文件 @
20afadb4
...
...
@@ -76,44 +76,56 @@ bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_num
return
true
;
}
void
s
etMaxType
(
TypeId
*
max_type_id
,
TypeId
*
max_type
,
size_t
*
max_type_number
,
const
TypeId
type_id
,
const
TypeId
type
,
void
S
etMaxType
(
TypeId
*
max_type_id
,
TypeId
*
max_type
,
size_t
*
max_type_number
,
const
TypeId
type_id
,
const
TypeId
type
,
const
size_t
type_number
)
{
*
max_type_id
=
type_id
;
*
max_type
=
type
;
*
max_type_number
=
type_number
;
}
TypeId
GetMaxTypeId
(
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
std
::
vector
<
size_t
>
indexs
,
const
std
::
set
<
size_t
>
&
write_indexs
)
{
bool
GetTensorOrScalarTypeInfo
(
AbstractBasePtr
arg_value
,
bool
is_write
,
TypeId
*
arg_type_id
,
TypeId
*
arg_type
=
nullptr
)
{
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
if
(
is_write
)
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref_origin
();
}
else
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
}
}
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
())
{
auto
tensor
=
arg_value
->
cast
<
abstract
::
AbstractTensorPtr
>
();
auto
tensor_type
=
tensor
->
element
()
->
BuildType
();
MS_EXCEPTION_IF_NULL
(
tensor_type
);
*
arg_type_id
=
tensor_type
->
type_id
();
if
(
arg_type
!=
nullptr
)
{
*
arg_type
=
kObjectTypeTensorType
;
}
return
true
;
}
if
(
arg_value
->
isa
<
abstract
::
AbstractScalar
>
())
{
auto
scalar
=
arg_value
->
cast
<
abstract
::
AbstractScalarPtr
>
();
auto
scalar_type
=
scalar
->
BuildType
();
MS_EXCEPTION_IF_NULL
(
scalar_type
);
*
arg_type_id
=
scalar_type
->
type_id
();
if
(
arg_type
!=
nullptr
)
{
*
arg_type
=
kObjectTypeNumber
;
}
return
true
;
}
return
false
;
}
TypeId
GetMaxTypeId
(
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
std
::
vector
<
size_t
>
indices
,
const
std
::
set
<
size_t
>
&
write_indices
)
{
TypeId
max_type_id
=
kTypeUnknown
;
TypeId
max_type
=
kTypeUnknown
;
size_t
max_type_number
=
0
;
bool
has_int8
=
false
;
for
(
const
auto
&
index
:
ind
ex
s
)
{
for
(
const
auto
&
index
:
ind
ice
s
)
{
TypeId
arg_type_id
=
kTypeUnknown
;
TypeId
arg_type
=
kTypeUnknown
;
AbstractBasePtr
arg_value
=
args_spec_list
[
index
];
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
auto
is_write
=
(
write_indexs
.
find
(
index
)
!=
write_indexs
.
end
());
if
(
is_write
)
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref_origin
();
}
else
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
}
}
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
())
{
auto
tensor
=
arg_value
->
cast
<
abstract
::
AbstractTensorPtr
>
();
auto
tensor_type
=
tensor
->
element
()
->
BuildType
();
MS_EXCEPTION_IF_NULL
(
tensor_type
);
arg_type_id
=
tensor_type
->
type_id
();
arg_type
=
kObjectTypeTensorType
;
}
else
if
(
arg_value
->
isa
<
abstract
::
AbstractScalar
>
())
{
auto
scalar
=
arg_value
->
cast
<
abstract
::
AbstractScalarPtr
>
();
auto
scalar_type
=
scalar
->
BuildType
();
MS_EXCEPTION_IF_NULL
(
scalar_type
);
arg_type_id
=
scalar_type
->
type_id
();
arg_type
=
kObjectTypeNumber
;
}
else
{
auto
is_write
=
(
write_indices
.
find
(
index
)
!=
write_indices
.
end
());
if
(
!
GetTensorOrScalarTypeInfo
(
args_spec_list
[
index
],
is_write
,
&
arg_type_id
,
&
arg_type
))
{
continue
;
}
auto
it
=
type_map
.
find
(
arg_type_id
);
...
...
@@ -124,22 +136,22 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
has_int8
=
true
;
}
if
(
max_type_id
==
kTypeUnknown
)
{
s
etMaxType
(
&
max_type_id
,
&
max_type
,
&
max_type_number
,
arg_type_id
,
arg_type
,
it
->
second
);
S
etMaxType
(
&
max_type_id
,
&
max_type
,
&
max_type_number
,
arg_type_id
,
arg_type
,
it
->
second
);
continue
;
}
if
(
max_type
==
arg_type
)
{
if
(
it
->
second
>
max_type_number
)
{
s
etMaxType
(
&
max_type_id
,
&
max_type
,
&
max_type_number
,
arg_type_id
,
arg_type
,
it
->
second
);
S
etMaxType
(
&
max_type_id
,
&
max_type
,
&
max_type_number
,
arg_type_id
,
arg_type
,
it
->
second
);
}
}
else
{
if
(
arg_type
==
kObjectTypeTensorType
)
{
if
(
CompareTensorScalarType
(
arg_type_id
,
it
->
second
,
max_type_id
,
max_type_number
))
{
s
etMaxType
(
&
max_type_id
,
&
max_type
,
&
max_type_number
,
arg_type_id
,
arg_type
,
it
->
second
);
S
etMaxType
(
&
max_type_id
,
&
max_type
,
&
max_type_number
,
arg_type_id
,
arg_type
,
it
->
second
);
}
}
else
{
if
(
!
CompareTensorScalarType
(
max_type_id
,
max_type_number
,
arg_type_id
,
it
->
second
))
{
s
etMaxType
(
&
max_type_id
,
&
max_type
,
&
max_type_number
,
arg_type_id
,
arg_type
,
it
->
second
);
S
etMaxType
(
&
max_type_id
,
&
max_type
,
&
max_type_number
,
arg_type_id
,
arg_type
,
it
->
second
);
}
}
}
...
...
@@ -154,28 +166,28 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
// Get the largest type of index in the same SignatureEnumDType of arguments.
std
::
map
<
SignatureEnumDType
,
TypeId
>
GetMaxDtype
(
const
std
::
vector
<
SignatureEnumDType
>
&
dtypes
,
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
const
std
::
set
<
size_t
>
&
write_ind
ex
s
)
{
const
std
::
set
<
size_t
>
&
write_ind
ice
s
)
{
// record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
type_ind
ex
s
;
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
type_ind
ice
s
;
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
auto
it
=
type_ind
ex
s
.
find
(
dtypes
[
i
]);
if
(
it
==
type_ind
ex
s
.
end
())
{
(
void
)
type_ind
ex
s
.
insert
(
std
::
make_pair
(
dtypes
[
i
],
std
::
vector
<
size_t
>
{
i
}));
auto
it
=
type_ind
ice
s
.
find
(
dtypes
[
i
]);
if
(
it
==
type_ind
ice
s
.
end
())
{
(
void
)
type_ind
ice
s
.
insert
(
std
::
make_pair
(
dtypes
[
i
],
std
::
vector
<
size_t
>
{
i
}));
}
else
{
it
->
second
.
push_back
(
i
);
}
}
std
::
map
<
SignatureEnumDType
,
TypeId
>
dst_type
;
for
(
auto
it
=
type_ind
exs
.
begin
();
it
!=
type_index
s
.
end
();
(
void
)
++
it
)
{
for
(
auto
it
=
type_ind
ices
.
begin
();
it
!=
type_indice
s
.
end
();
(
void
)
++
it
)
{
auto
type
=
it
->
first
;
auto
ind
ex
s
=
it
->
second
;
auto
ind
ice
s
=
it
->
second
;
// If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it.
if
(
ind
ex
s
.
size
()
<
2
)
{
if
(
ind
ice
s
.
size
()
<
2
)
{
continue
;
}
bool
has_tensor
=
false
;
for
(
const
auto
&
index
:
ind
ex
s
)
{
for
(
const
auto
&
index
:
ind
ice
s
)
{
AbstractBasePtr
arg_value
=
args_spec_list
[
index
];
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
...
...
@@ -189,7 +201,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
kTypeUnknown
));
continue
;
}
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
GetMaxTypeId
(
args_spec_list
,
ind
exs
,
write_index
s
)));
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
GetMaxTypeId
(
args_spec_list
,
ind
ices
,
write_indice
s
)));
}
return
dst_type
;
}
...
...
@@ -204,7 +216,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
void
DoAutoCast
(
const
std
::
string
&
func_name
,
const
std
::
vector
<
Signature
>
&
signature
,
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
const
FuncGraphPtr
&
graph
,
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
,
const
std
::
set
<
size_t
>
&
write_ind
ex
s
)
{
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
,
const
std
::
set
<
size_t
>
&
write_ind
ice
s
)
{
std
::
vector
<
SignatureEnumDType
>
dtypes
;
(
void
)
std
::
transform
(
signature
.
begin
(),
signature
.
end
(),
std
::
back_inserter
(
dtypes
),
[](
const
Signature
&
sig
)
{
return
sig
.
dtype
;
});
...
...
@@ -213,36 +225,19 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
return
;
}
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std
::
map
<
SignatureEnumDType
,
TypeId
>
dst_type
=
GetMaxDtype
(
dtypes
,
args_spec_list
,
write_ind
ex
s
);
std
::
map
<
SignatureEnumDType
,
TypeId
>
dst_type
=
GetMaxDtype
(
dtypes
,
args_spec_list
,
write_ind
ice
s
);
// Identify which arg requires auto cast
for
(
size_t
i
=
0
;
i
<
args_spec_list
.
size
();
++
i
)
{
auto
it
=
dst_type
.
find
(
dtypes
[
i
]);
if
(
it
==
dst_type
.
end
()
||
it
->
second
==
kTypeUnknown
)
{
continue
;
}
auto
rw_it
=
write_ind
ex
s
.
find
(
i
);
auto
is_write
=
(
rw_it
!=
write_ind
ex
s
.
end
());
auto
rw_it
=
write_ind
ice
s
.
find
(
i
);
auto
is_write
=
(
rw_it
!=
write_ind
ice
s
.
end
());
AbstractBasePtr
arg_value
=
args_spec_list
[
i
];
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
if
(
is_write
)
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref_origin
();
}
else
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
}
}
TypeId
arg_type_id
=
kTypeUnknown
;
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
())
{
auto
tensor
=
arg_value
->
cast
<
abstract
::
AbstractTensorPtr
>
();
auto
tensor_type
=
tensor
->
element
()
->
BuildType
();
MS_EXCEPTION_IF_NULL
(
tensor_type
);
arg_type_id
=
tensor_type
->
type_id
();
}
else
if
(
arg_value
->
isa
<
abstract
::
AbstractScalar
>
())
{
auto
scalar
=
arg_value
->
cast
<
abstract
::
AbstractScalarPtr
>
();
auto
scalar_type
=
scalar
->
BuildType
();
MS_EXCEPTION_IF_NULL
(
scalar_type
);
arg_type_id
=
scalar_type
->
type_id
();
}
AbstractBasePtr
arg_value
=
args_spec_list
[
i
];
(
void
)
GetTensorOrScalarTypeInfo
(
arg_value
,
is_write
,
&
arg_type_id
);
auto
it_map
=
type_map
.
find
(
arg_type_id
);
if
(
it_map
==
type_map
.
end
())
{
continue
;
...
...
@@ -279,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
}
std
::
vector
<
AnfNodePtr
>
op_inputs
;
std
::
set
<
size_t
>
write_ind
ex
s
;
std
::
set
<
size_t
>
write_ind
ice
s
;
op_inputs
.
push_back
(
NewValueNode
(
function
));
// Assume, the write input of op is always the first input. We check if any write op,
// and add cast op on other inputs to keep the same type with assigned parameter.
...
...
@@ -303,7 +298,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
param
=
func_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimGetRefValue
),
param
});
}
else
if
(
sig
==
SignatureEnumRW
::
kRWWrite
)
{
param
=
func_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimGetRefOrigin
),
param
});
write_ind
ex
s
.
insert
(
i
);
write_ind
ice
s
.
insert
(
i
);
}
// If sig is SignatureEnumRW::kRWRef, not do anything.
}
else
if
(
sig
==
SignatureEnumRW
::
kRWWrite
&&
type
->
type_id
()
!=
kObjectTypeRefKey
)
{
...
...
@@ -313,7 +308,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
// process default
ProcessDefault
(
func_name
,
args_spec_list
,
signature
,
has_var
,
&
op_inputs
);
DoAutoCast
(
func_name
,
signature
,
args_spec_list
,
func_graph
,
&
op_inputs
,
write_ind
ex
s
);
DoAutoCast
(
func_name
,
signature
,
args_spec_list
,
func_graph
,
&
op_inputs
,
write_ind
ice
s
);
return
func_graph
->
NewCNode
(
op_inputs
);
}
}
// namespace
...
...
mindspore/ccsrc/pipeline/parse/data_converter.cc
浏览文件 @
20afadb4
...
...
@@ -238,6 +238,31 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) {
return
bprop_graph
;
}
bool
ConvertCellObjToFuncGraph
(
py
::
object
obj
,
ValuePtr
*
const
data
)
{
FuncGraphPtr
func_graph
=
ConvertToFuncGraph
(
obj
);
if
(
func_graph
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Parse resolve function error."
;
return
false
;
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if
(
py
::
hasattr
(
obj
,
"bprop"
))
{
FuncGraphPtr
bprop_graph
=
nullptr
;
bool
enable_bprop_debug
=
py
::
cast
<
bool
>
(
py
::
getattr
(
obj
,
"bprop_debug"
));
if
(
enable_bprop_debug
)
{
bprop_graph
=
ConvertToBpropCut
(
obj
);
}
else
{
bprop_graph
=
ConvertToFuncGraph
(
obj
,
PYTHON_MOD_GET_BPROP_METHOD
);
}
if
(
bprop_graph
!=
nullptr
)
{
(
void
)
func_graph
->
transforms
().
insert
(
std
::
make_pair
(
"bprop"
,
FuncGraphTransform
(
bprop_graph
)));
(
void
)
bprop_graph
->
transforms
().
insert
(
std
::
make_pair
(
"primal"
,
FuncGraphTransform
(
func_graph
)));
func_graph
->
set_flags
(
FUNC_GRAPH_FLAG_DEFER_INLINE
,
true
);
}
}
*
data
=
func_graph
;
return
true
;
}
bool
ConvertOtherObj
(
py
::
object
obj
,
ValuePtr
*
const
data
)
{
auto
obj_type
=
data_converter
::
GetObjType
(
obj
);
MS_LOG
(
DEBUG
)
<<
"Converting the object("
<<
((
std
::
string
)
py
::
str
(
obj
))
<<
") detail type: "
<<
obj_type
<<
" "
;
...
...
@@ -262,32 +287,12 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
// Create the namespace for common class instance
// When the obj is Cell, default parse the 'construct'
if
(
data_converter
::
IsCellInstance
(
obj
))
{
FuncGraphPtr
func_graph
=
ConvertToFuncGraph
(
obj
);
if
(
func_graph
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Parse resolve function error."
;
return
false
;
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if
(
py
::
hasattr
(
obj
,
"bprop"
))
{
FuncGraphPtr
bprop_graph
=
nullptr
;
bool
enable_bprop_debug
=
py
::
cast
<
bool
>
(
py
::
getattr
(
obj
,
"bprop_debug"
));
if
(
enable_bprop_debug
)
{
bprop_graph
=
ConvertToBpropCut
(
obj
);
}
else
{
bprop_graph
=
ConvertToFuncGraph
(
obj
,
PYTHON_MOD_GET_BPROP_METHOD
);
}
if
(
bprop_graph
!=
nullptr
)
{
(
void
)
func_graph
->
transforms
().
insert
(
std
::
make_pair
(
"bprop"
,
FuncGraphTransform
(
bprop_graph
)));
(
void
)
bprop_graph
->
transforms
().
insert
(
std
::
make_pair
(
"primal"
,
FuncGraphTransform
(
func_graph
)));
func_graph
->
set_flags
(
FUNC_GRAPH_FLAG_DEFER_INLINE
,
true
);
}
}
*
data
=
func_graph
;
}
else
{
py
::
module
mod
=
python_adapter
::
GetPyModule
(
PYTHON_MOD_PARSE_MODULE
);
py
::
object
namespace_var
=
python_adapter
::
CallPyModFn
(
mod
,
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
,
obj
);
*
data
=
std
::
make_shared
<
NameSpace
>
(
RESOLVE_NAMESPACE_NAME_CLASS_MEMBER
,
namespace_var
);
return
ConvertCellObjToFuncGraph
(
obj
,
data
);
}
py
::
module
mod
=
python_adapter
::
GetPyModule
(
PYTHON_MOD_PARSE_MODULE
);
py
::
object
namespace_var
=
python_adapter
::
CallPyModFn
(
mod
,
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
,
obj
);
*
data
=
std
::
make_shared
<
NameSpace
>
(
RESOLVE_NAMESPACE_NAME_CLASS_MEMBER
,
namespace_var
);
return
true
;
}
MS_LOG
(
ERROR
)
<<
"Resolve type is invalid "
<<
((
std
::
string
)
py
::
str
(
obj
));
...
...
mindspore/ccsrc/pipeline/pipeline.cc
浏览文件 @
20afadb4
...
...
@@ -608,7 +608,7 @@ void Pipeline::Run() {
MS_LOG
(
INFO
)
<<
"End"
;
}
void
ProcessVmArgInner
(
const
py
::
tuple
&
args
,
const
ResourcePtr
&
res
,
VectorRef
*
arg_list
)
{
void
ProcessVmArgInner
(
const
py
::
tuple
&
args
,
const
ResourcePtr
&
res
,
VectorRef
*
const
arg_list
)
{
std
::
size_t
size
=
args
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
size
;
i
++
)
{
...
...
mindspore/ccsrc/pipeline/pipeline.h
浏览文件 @
20afadb4
...
...
@@ -139,7 +139,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
const
std
::
vector
<
TypePtr
>
&
types
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
shapes
,
const
std
::
vector
<
int64_t
>
&
input_indexes
,
bool
need_run
);
void
ProcessVmArgInner
(
const
py
::
tuple
&
args
,
const
ResourcePtr
&
res
,
VectorRef
*
arg_list
);
void
ProcessVmArgInner
(
const
py
::
tuple
&
args
,
const
ResourcePtr
&
res
,
VectorRef
*
const
arg_list
);
}
// namespace pipeline
}
// namespace mindspore
...
...
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
浏览文件 @
20afadb4
...
...
@@ -464,6 +464,85 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
return
ExecuteMultipleEvaluators
(
evaluators
,
out_conf
,
args_conf_list
);
}
void
AnalysisEngine
::
SetUndeterminedFlag
(
const
EvaluatorPtr
&
evaluator
)
{
auto
fg_eval
=
evaluator
->
cast
<
FuncGraphEvaluatorPtr
>
();
if
(
fg_eval
==
nullptr
)
{
return
;
}
auto
fg
=
fg_eval
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
fg
);
auto
undetermined_fgs
=
fg
->
recursive_graphs
();
if
(
undetermined_fgs
)
{
auto
fg_parent
=
fg
->
parent
();
MS_EXCEPTION_IF_NULL
(
fg_parent
);
fg_parent
->
set_flags
(
kFuncGraphFlagUndetermined
,
true
);
MS_LOG
(
DEBUG
)
<<
"Set graph undetermined: "
<<
fg_parent
->
ToString
();
}
}
EvaluatorPtr
AnalysisEngine
::
HandleNestedRecursion
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
EvaluatorPtr
&
eval
,
const
AbstractBasePtrList
&
args_spec_list
,
const
EvalTraceRevIter
&
it
,
bool
*
continue_flag
)
{
*
continue_flag
=
false
;
// Find latest entry function to handle nested recursion.
EvaluatorPtr
latest_entry
=
eval
;
auto
latest_entry_iter
=
eval_trace_
.
rbegin
();
for
(
auto
r_it
=
eval_trace_
.
rbegin
();
*
r_it
!=
*
it
;)
{
auto
it_temp
=
std
::
find
(
evaluators
.
begin
(),
evaluators
.
end
(),
r_it
->
first
);
if
(
it_temp
!=
evaluators
.
end
())
{
latest_entry
=
*
it_temp
;
latest_entry_iter
=
r_it
;
break
;
}
latest_entry_iter
=
++
r_it
;
}
if
(
latest_entry
!=
eval
)
{
MS_LOG
(
DEBUG
)
<<
"Continue Evaluator "
<<
eval
->
ToString
();
*
continue_flag
=
true
;
return
latest_entry
;
}
bool
has_undetermined
=
false
;
// Check whether sub loop has untraced undetermined evaluator.
std
::
set
<
std
::
pair
<
EvaluatorPtr
,
AbstractBasePtrList
>>
undetermined_evals
;
for
(
auto
r_it
=
eval_trace_
.
rbegin
();
r_it
!=
latest_entry_iter
;
r_it
++
)
{
undetermined_evals
.
insert
(
*
r_it
);
}
MS_LOG
(
DEBUG
)
<<
"undetermined_evals size(): "
<<
undetermined_evals
.
size
();
for
(
auto
u_eval
:
undetermined_evals
)
{
MS_LOG
(
DEBUG
)
<<
u_eval
.
first
->
ToString
()
<<
" check undetermined."
;
if
(
!
undetermined_evals
.
count
(
std
::
make_pair
(
multi_poss_
[
u_eval
.
first
],
args_spec_list
)))
{
MS_LOG
(
DEBUG
)
<<
u_eval
.
first
->
ToString
()
<<
" has undetermined."
;
has_undetermined
=
true
;
break
;
}
}
if
(
has_undetermined
==
false
)
{
MS_LOG
(
DEBUG
)
<<
eval
->
ToString
()
<<
" has no undetermined."
;
*
continue_flag
=
true
;
return
latest_entry
;
}
return
latest_entry
;
}
EvalResultPtr
AnalysisEngine
::
ProcessEvalResults
(
const
AbstractBasePtrList
&
out_specs
)
{
if
(
out_specs
.
size
()
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"There is an endless loop for evaluator."
;
}
if
(
out_specs
.
size
()
==
1
)
{
MS_EXCEPTION_IF_NULL
(
out_specs
[
0
]);
// If only one result derived, then broaden it to avoid wrong constant propagation.
return
std
::
make_shared
<
EvalResult
>
(
out_specs
[
0
]
->
Broaden
(),
std
::
make_shared
<
AttrValueMap
>
());
}
auto
joined_spec
=
AbstractJoin
(
out_specs
);
MS_EXCEPTION_IF_NULL
(
joined_spec
);
MS_LOG
(
DEBUG
)
<<
"Multiple evaluators joined: "
<<
joined_spec
->
ToString
();
return
std
::
make_shared
<
EvalResult
>
(
joined_spec
,
std
::
make_shared
<
AttrValueMap
>
());
}
EvalResultPtr
AnalysisEngine
::
ExecuteMultipleEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
)
{
...
...
@@ -479,18 +558,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
return
conf
->
GetEvaluatedValue
()
->
abstract
();
});
for
(
auto
eval
:
evaluators
)
{
auto
fg_eval
=
eval
->
cast
<
FuncGraphEvaluatorPtr
>
();
if
(
fg_eval
)
{
auto
fg
=
fg_eval
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
fg
);
auto
undetermined_fgs
=
fg
->
recursive_graphs
();
if
(
undetermined_fgs
)
{
auto
fg_parent
=
fg
->
parent
();
MS_EXCEPTION_IF_NULL
(
fg_parent
);
fg_parent
->
set_flags
(
kFuncGraphFlagUndetermined
,
true
);
MS_LOG
(
DEBUG
)
<<
"Set graph undetermined: "
<<
fg_parent
->
ToString
();
}
}
SetUndeterminedFlag
(
eval
);
auto
current_inf
=
std
::
make_pair
(
eval
,
args_spec_list
);
MS_LOG
(
DEBUG
)
<<
"Check Evaluator "
<<
eval
->
ToString
();
...
...
@@ -510,40 +578,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
multi_poss_
.
clear
();
}
}
else
if
(
it
!=
eval_trace_
.
rbegin
())
{
// Find latest entry function to handle nested recursion.
EvaluatorPtr
latest_entry
=
eval
;
auto
latest_entry_iter
=
eval_trace_
.
rbegin
();
for
(
auto
r_it
=
eval_trace_
.
rbegin
();
*
r_it
!=
*
it
;)
{
auto
it_temp
=
std
::
find
(
evaluators
.
begin
(),
evaluators
.
end
(),
r_it
->
first
);
if
(
it_temp
!=
evaluators
.
end
())
{
latest_entry
=
*
it_temp
;
latest_entry_iter
=
r_it
;
break
;
}
latest_entry_iter
=
++
r_it
;
}
if
(
latest_entry
!=
eval
)
{
MS_LOG
(
DEBUG
)
<<
"Continue Evaluator "
<<
eval
->
ToString
();
continue
;
}
bool
has_undetermined
=
false
;
// Check whether sub loop has untraced undetermined evaluator.
std
::
set
<
std
::
pair
<
EvaluatorPtr
,
AbstractBasePtrList
>>
undetermined_evals
;
for
(
auto
r_it
=
eval_trace_
.
rbegin
();
r_it
!=
latest_entry_iter
;
r_it
++
)
{
undetermined_evals
.
insert
(
*
r_it
);
}
MS_LOG
(
DEBUG
)
<<
"undetermined_evals size(): "
<<
undetermined_evals
.
size
();
for
(
auto
u_eval
:
undetermined_evals
)
{
MS_LOG
(
DEBUG
)
<<
u_eval
.
first
->
ToString
()
<<
" check undetermined."
;
if
(
!
undetermined_evals
.
count
(
std
::
make_pair
(
multi_poss_
[
u_eval
.
first
],
args_spec_list
)))
{
MS_LOG
(
DEBUG
)
<<
u_eval
.
first
->
ToString
()
<<
" has undetermined."
;
has_undetermined
=
true
;
break
;
}
}
if
(
has_undetermined
==
false
)
{
MS_LOG
(
DEBUG
)
<<
eval
->
ToString
()
<<
" has no undetermined."
;
bool
continue_flag
=
false
;
auto
latest_entry
=
HandleNestedRecursion
(
evaluators
,
eval
,
args_spec_list
,
it
,
&
continue_flag
);
if
(
continue_flag
)
{
continue
;
}
...
...
@@ -558,19 +595,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
}
}
}
if
(
out_specs
.
size
()
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"There is an endless loop for evaluator."
;
}
if
(
out_specs
.
size
()
==
1
)
{
MS_EXCEPTION_IF_NULL
(
out_specs
[
0
]);
// If only one result derived, then broaden it to avoid wrong constant propagation.
return
std
::
make_shared
<
EvalResult
>
(
out_specs
[
0
]
->
Broaden
(),
std
::
make_shared
<
AttrValueMap
>
());
}
auto
joined_spec
=
AbstractJoin
(
out_specs
);
MS_EXCEPTION_IF_NULL
(
joined_spec
);
MS_LOG
(
DEBUG
)
<<
"Multiple evaluators joined: "
<<
joined_spec
->
ToString
();
return
std
::
make_shared
<
EvalResult
>
(
joined_spec
,
std
::
make_shared
<
AttrValueMap
>
());
return
ProcessEvalResults
(
out_specs
);
}
EvalResultPtr
AnfNodeConfig
::
GetEvaluatedValue
()
{
...
...
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
浏览文件 @
20afadb4
...
...
@@ -172,6 +172,8 @@ struct AnalysisResult {
AnalysisContextPtr
context
;
};
using
EvalTraceRevIter
=
std
::
list
<
std
::
pair
<
EvaluatorPtr
,
AbstractBasePtrList
>>::
reverse_iterator
;
class
AnalysisEngine
:
public
std
::
enable_shared_from_this
<
AnalysisEngine
>
{
public:
AnalysisEngine
(
const
PrimEvaluatorMap
&
prim_evaluator_map
,
const
FuncGraphManagerPtr
&
func_graph_manager
)
...
...
@@ -222,6 +224,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
std
::
unordered_map
<
PrimitivePyPtr
,
EvaluatorPtr
>
prim_py_evaluators_
;
private:
void
SetUndeterminedFlag
(
const
EvaluatorPtr
&
evaluator
);
EvaluatorPtr
HandleNestedRecursion
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
EvaluatorPtr
&
eval
,
const
AbstractBasePtrList
&
args_spec_list
,
const
EvalTraceRevIter
&
it
,
bool
*
continue_flag
);
EvalResultPtr
ProcessEvalResults
(
const
AbstractBasePtrList
&
out_specs
);
const
PrimEvaluatorMap
&
prim_constructors_
;
FuncGraphManagerPtr
func_graph_manager_
;
std
::
unordered_map
<
AbstractFunctionPtr
,
EvaluatorPtr
>
constructors_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录