Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5b9c145f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5b9c145f
编写于
5月 27, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 27, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1383 keep different attributes for cnode evaluation
Merge pull request !1383 from amongo/KeepPrimAttrInCNode
上级
2eedd321
1159f2b1
变更
24
展开全部
隐藏空白更改
内联
并排
Showing
24 changed file
with
800 addition
and
324 deletion
+800
-324
mindspore/ccsrc/debug/trace.cc
mindspore/ccsrc/debug/trace.cc
+3
-3
mindspore/ccsrc/ir/primitive_base.h
mindspore/ccsrc/ir/primitive_base.h
+18
-2
mindspore/ccsrc/operator/prim_structures.cc
mindspore/ccsrc/operator/prim_structures.cc
+5
-4
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
+42
-40
mindspore/ccsrc/pipeline/static_analysis/evaluator.h
mindspore/ccsrc/pipeline/static_analysis/evaluator.h
+32
-25
mindspore/ccsrc/pipeline/static_analysis/prim.cc
mindspore/ccsrc/pipeline/static_analysis/prim.cc
+69
-58
mindspore/ccsrc/pipeline/static_analysis/prim.h
mindspore/ccsrc/pipeline/static_analysis/prim.h
+9
-9
mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc
...pore/ccsrc/pipeline/static_analysis/program_specialize.cc
+60
-15
mindspore/ccsrc/pipeline/static_analysis/program_specialize.h
...spore/ccsrc/pipeline/static_analysis/program_specialize.h
+2
-1
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
+45
-46
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
+39
-17
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+1
-1
mindspore/ccsrc/utils/graph_utils.cc
mindspore/ccsrc/utils/graph_utils.cc
+27
-0
mindspore/ccsrc/utils/graph_utils.h
mindspore/ccsrc/utils/graph_utils.h
+1
-0
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+0
-9
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+0
-3
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+2
-3
tests/st/ops/ascend/test_ops_infer.py
tests/st/ops/ascend/test_ops_infer.py
+65
-0
tests/ut/cpp/operator/composite_test.cc
tests/ut/cpp/operator/composite_test.cc
+13
-13
tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc
tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc
+10
-10
tests/ut/cpp/pipeline/static_analysis/prim_test.cc
tests/ut/cpp/pipeline/static_analysis/prim_test.cc
+51
-51
tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc
...s/ut/cpp/pipeline/static_analysis/static_analysis_test.cc
+13
-13
tests/ut/python/ops/test_nn_ops.py
tests/ut/python/ops/test_nn_ops.py
+17
-1
tests/ut/python/ops/test_ops_attr_infer.py
tests/ut/python/ops/test_ops_attr_infer.py
+276
-0
未找到文件。
mindspore/ccsrc/debug/trace.cc
浏览文件 @
5b9c145f
...
...
@@ -230,11 +230,11 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
auto
ctx
=
node_cfg_
->
context
();
auto
engine
=
node_cfg_
->
engine
();
auto
cfg
=
engine
->
MakeConfig
(
node
,
ctx
);
auto
abs
=
engine
->
cache
().
GetValue
(
cfg
);
if
(
abs
==
nullptr
)
{
auto
eval_result
=
engine
->
cache
().
GetValue
(
cfg
);
if
(
eval_result
==
nullptr
||
eval_result
->
abstract
()
==
nullptr
)
{
return
"Undefined"
;
}
auto
abs
=
eval_result
->
abstract
();
auto
dtype
=
abs
->
BuildType
();
auto
shape
=
abs
->
BuildShape
();
std
::
ostringstream
oss
;
...
...
mindspore/ccsrc/ir/primitive_base.h
浏览文件 @
5b9c145f
...
...
@@ -42,7 +42,11 @@ enum PrimType {
class
Primitive
:
public
Named
{
public:
explicit
Primitive
(
const
std
::
string
&
name
,
const
bool
is_base
=
true
,
const
PrimType
prim_type
=
kPrimTypeBuiltIn
)
:
Named
(
name
),
is_base_
(
is_base
),
has_signature_
(
false
),
prim_type_
(
prim_type
)
{}
:
Named
(
name
),
is_base_
(
is_base
),
has_signature_
(
false
),
prim_type_
(
prim_type
),
record_evaluate_add_attr_
(
false
)
{}
Primitive
(
const
Primitive
&
prim
)
:
Named
(
prim
),
...
...
@@ -50,14 +54,23 @@ class Primitive : public Named {
instance_name_
(
prim
.
instance_name_
),
is_base_
(
prim
.
is_base_
),
has_signature_
(
prim
.
has_signature_
),
prim_type_
(
prim
.
prim_type_
)
{}
prim_type_
(
prim
.
prim_type_
),
record_evaluate_add_attr_
(
false
)
{}
MS_DECLARE_PARENT
(
Primitive
,
Named
);
abstract
::
AbstractBasePtr
ToPrimAbstract
(
const
AnfNodePtr
&
anf_node
);
std
::
string
ToString
()
const
override
{
return
name
();
}
void
BeginRecordAddAttr
()
{
evaluate_added_attrs_
.
clear
();
record_evaluate_add_attr_
=
true
;
}
void
EndRecordAddAttr
()
{
record_evaluate_add_attr_
=
false
;
}
Primitive
&
AddAttr
(
const
std
::
string
&
name
,
const
ValuePtr
&
attr
)
{
attrs_
[
name
]
=
attr
;
if
(
record_evaluate_add_attr_
)
{
evaluate_added_attrs_
[
name
]
=
attr
;
}
return
*
this
;
}
...
...
@@ -80,6 +93,7 @@ class Primitive : public Named {
py
::
function
hook
()
const
{
return
hook_
;
}
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
()
const
{
return
attrs_
;
}
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
evaluate_added_attrs
()
{
return
evaluate_added_attrs_
;
}
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool
HasAttr
()
const
{
return
!
attrs_
.
empty
();
}
...
...
@@ -106,6 +120,7 @@ class Primitive : public Named {
protected:
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
evaluate_added_attrs_
;
private:
std
::
string
instance_name_
;
...
...
@@ -113,6 +128,7 @@ class Primitive : public Named {
bool
is_base_
;
bool
has_signature_
;
PrimType
prim_type_
;
bool
record_evaluate_add_attr_
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
PrimitivePtr
&
p
)
{
...
...
mindspore/ccsrc/operator/prim_structures.cc
浏览文件 @
5b9c145f
...
...
@@ -377,10 +377,10 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv
}
subargs
.
push_back
(
AbstractJoin
(
l_ptr
->
elements
()));
}
AbstractBase
Ptr
engin_exc
=
engine
->
Execute
(
fn
,
subargs
);
EvalResult
Ptr
engin_exc
=
engine
->
Execute
(
fn
,
subargs
);
AbstractBasePtrList
result
;
for
(
std
::
size_t
i
=
1
;
i
<
args_spec_list
.
size
();
i
++
)
{
result
.
push_back
(
engin_exc
);
result
.
push_back
(
engin_exc
->
abstract
()
);
}
return
std
::
make_shared
<
AbstractList
>
(
result
);
}
...
...
@@ -398,8 +398,9 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
AbstractBasePtr
list_type
=
AbstractJoin
(
lst
->
elements
());
auto
result1
=
engine
->
Execute
(
fn
,
lst
->
elements
());
auto
result2
=
engine
->
Execute
(
fn
,
{
dflt
,
list_type
});
MS_EXCEPTION_IF_NULL
(
result1
);
return
result1
->
Join
(
result2
);
MS_EXCEPTION_IF_NULL
(
result1
->
abstract
());
MS_EXCEPTION_IF_NULL
(
result2
->
abstract
());
return
result1
->
abstract
()
->
Join
(
result2
->
abstract
());
}
AbstractBasePtr
InferImplTupleReversed
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
...
...
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
浏览文件 @
5b9c145f
...
...
@@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
return
sorted_nodes
;
}
AbstractBase
Ptr
BaseFuncGraphEvaluator
::
Eval
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
{
EvalResult
Ptr
BaseFuncGraphEvaluator
::
Eval
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
{
FuncGraphPtr
fg
=
GetFuncGraph
(
engine
,
args_spec_list
);
MS_EXCEPTION_IF_NULL
(
fg
);
std
::
size_t
nargs
=
fg
->
parameters
().
size
();
...
...
@@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
const
auto
&
arg
=
args_spec_list
[
i
];
const
auto
&
node
=
parameters
[
i
];
AnfNodeConfigPtr
conf
=
engine
->
MakeConfig
(
node
,
graph_context_
);
engine
->
cache
().
set_value
(
conf
,
arg
);
engine
->
cache
().
set_value
(
conf
,
std
::
make_shared
<
EvalResult
>
(
arg
,
nullptr
)
);
}
const
AnfNodePtr
&
func_node
=
fg
->
get_return
();
...
...
@@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
const
auto
&
node
=
*
it
;
AnfNodeConfigPtr
node_conf
=
engine
->
MakeConfig
(
node
,
graph_context_
);
MS_LOG
(
DEBUG
)
<<
"Analysis node begin, func graph: "
<<
fg
->
ToString
()
<<
", node_conf: "
<<
node_conf
->
ToString
();
ret_base
=
engine
->
GetEvaluatedValue
(
node_conf
);
ret_base
=
engine
->
GetEvaluatedValue
(
node_conf
)
->
abstract
()
;
MS_LOG
(
DEBUG
)
<<
"Analysis node end, func graph: "
<<
fg
->
ToString
()
<<
", node_conf: "
<<
node_conf
->
ToString
()
<<
", abstract: "
<<
ret_base
->
ToString
();
}
MS_EXCEPTION_IF_NULL
(
ret_base
);
MS_LOG
(
DEBUG
)
<<
"BaseFuncGraph "
<<
fg
->
ToString
()
<<
"
E
val end, evaluated abstract: "
<<
ret_base
->
ToString
();
return
ret_base
;
MS_LOG
(
DEBUG
)
<<
"BaseFuncGraph "
<<
fg
->
ToString
()
<<
"
e
val end, evaluated abstract: "
<<
ret_base
->
ToString
();
return
std
::
make_shared
<
EvalResult
>
(
ret_base
,
nullptr
)
;
}
AbstractBasePtrList
FuncGraphEvaluator
::
NormalizeArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
const
{
...
...
@@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
return
cloned_func_graph
;
}
AbstractBasePtr
Evaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
EvalResultPtr
Evaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
const
std
::
string
&
evaluator_name
=
ToString
();
AbstractBasePtrList
args_spec_list
;
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
MS_EXCEPTION_IF_NULL
(
conf
);
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
()
->
abstract
()
;
});
args_spec_list
=
NormalizeArgs
(
args_spec_list
);
args_spec_list
=
BroadenUndeterminedArgs
(
args_spec_list
);
...
...
@@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
auto
iter
=
cache_
->
find
(
args_spec_list
);
if
(
iter
==
cache_
->
end
())
{
MS_LOG
(
DEBUG
)
<<
evaluator_name
<<
" cache miss, call Eval()."
;
AbstractBase
Ptr
ret
=
Eval
(
engine
,
args_spec_list
);
if
(
ret
==
nullptr
)
{
EvalResult
Ptr
ret
=
Eval
(
engine
,
args_spec_list
);
if
(
ret
->
abstract
()
==
nullptr
)
{
EvalFailLogging
(
shared_from_base
<
Evaluator
>
(),
args_spec_list
,
out_conf
);
MS_LOG
(
EXCEPTION
)
<<
"Evaluator "
<<
evaluator_name
<<
" result is nullptr."
;
}
MS_EXCEPTION_IF_NULL
(
ret
);
MS_LOG
(
DEBUG
)
<<
evaluator_name
<<
" set cache. return: "
<<
ret
->
ToString
()
<<
"."
;
MS_LOG
(
DEBUG
)
<<
evaluator_name
<<
" set cache. return: "
<<
ret
->
abstract
()
->
ToString
()
<<
"."
;
(
*
cache_
)[
args_spec_list
]
=
ret
;
trace
::
TraceGraphEvalLeave
(
shared_from_base
<
Evaluator
>
());
return
ret
;
}
else
{
MS_EXCEPTION_IF_NULL
(
iter
->
second
);
MS_LOG
(
DEBUG
)
<<
evaluator_name
<<
" cache hit. return: "
<<
iter
->
second
->
ToString
()
<<
"."
;
MS_EXCEPTION_IF_NULL
(
iter
->
second
->
abstract
());
MS_LOG
(
DEBUG
)
<<
evaluator_name
<<
" cache hit. return: "
<<
iter
->
second
->
abstract
()
->
ToString
()
<<
"."
;
trace
::
TraceGraphEvalLeave
(
shared_from_base
<
Evaluator
>
());
return
iter
->
second
;
}
}
AbstractBase
Ptr
TrivialPrimEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
)
{
EvalResult
Ptr
TrivialPrimEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
)
{
AbstractBasePtrList
args_spec_list
;
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
MS_EXCEPTION_IF_NULL
(
conf
);
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
()
->
abstract
()
;
});
AbstractBase
Ptr
ret
=
EvalPrim
(
engine
,
args_spec_list
);
EvalResult
Ptr
ret
=
EvalPrim
(
engine
,
args_spec_list
);
return
ret
;
}
AbstractBase
Ptr
TransitionPrimEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
EvalResult
Ptr
TransitionPrimEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
AbstractBasePtrList
args_spec_list
;
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
MS_EXCEPTION_IF_NULL
(
conf
);
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
()
->
abstract
()
;
});
if
(
args_conf_list
.
size
()
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Size should greater than 0"
;
}
AbstractBase
Ptr
ret
=
EvalPrim
(
engine
,
args_spec_list
,
args_conf_list
[
0
],
out_conf
);
EvalResult
Ptr
ret
=
EvalPrim
(
engine
,
args_spec_list
,
args_conf_list
[
0
],
out_conf
);
// No need to cache.
return
ret
;
}
AbstractBase
Ptr
SymbolicPrimEvaluator
::
Run
(
AnalysisEnginePtr
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
)
{
AbstractBase
Ptr
ret
=
EvalPrim
(
args_conf_list
);
EvalResult
Ptr
SymbolicPrimEvaluator
::
Run
(
AnalysisEnginePtr
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
)
{
EvalResult
Ptr
ret
=
EvalPrim
(
args_conf_list
);
return
ret
;
}
AbstractBase
Ptr
TrackedEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
EvalResult
Ptr
TrackedEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
AbstractBasePtrList
args_spec_list
;
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
MS_EXCEPTION_IF_NULL
(
conf
);
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
()
->
abstract
()
;
});
AbstractBase
Ptr
ret
=
sub_evaluator_
->
Run
(
engine
,
args_conf_list
,
out_conf
);
EvalResult
Ptr
ret
=
sub_evaluator_
->
Run
(
engine
,
args_conf_list
,
out_conf
);
// Don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map_, like getattr primitive.
(
*
cache_
)[
args_spec_list
]
=
ret
;
return
ret
;
}
AbstractBase
Ptr
PartialAppEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
EvalResult
Ptr
PartialAppEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
AbstractBasePtrList
args_spec_list
;
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
MS_EXCEPTION_IF_NULL
(
conf
);
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
()
->
abstract
()
;
});
MS_EXCEPTION_IF_NULL
(
cache_
);
auto
iter
=
cache_
->
find
(
args_spec_list
);
...
...
@@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP
(
void
)
std
::
transform
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
std
::
back_inserter
(
partial_args_conf_list
),
[](
const
AbstractBasePtr
&
arg
)
->
ConfigPtr
{
return
std
::
make_shared
<
VirtualConfig
>
(
arg
);
});
AbstractBasePtr
ret
=
evaluator_
->
Run
(
engine
,
partial_args_conf_list
,
out_conf
);
EvalResultPtr
ret
=
evaluator_
->
Run
(
engine
,
partial_args_conf_list
,
out_conf
);
(
*
cache_
)[
args_spec_list
]
=
ret
;
return
ret
;
}
AbstractBase
Ptr
JEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
)
{
EvalResult
Ptr
JEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
)
{
AbstractBasePtrList
args_spec_list
;
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
MS_EXCEPTION_IF_NULL
(
conf
);
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
()
->
abstract
()
;
});
MS_EXCEPTION_IF_NULL
(
cache_
);
auto
iter
=
cache_
->
find
(
args_spec_list
);
...
...
@@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
}
// Call the original evaluator, get the result: y = f(x)
AbstractBase
Ptr
result
=
evaluator_
->
Run
(
engine
,
args_conf_list
,
nullptr
);
EvalResult
Ptr
result
=
evaluator_
->
Run
(
engine
,
args_conf_list
,
nullptr
);
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
AbstractBasePtrList
bparams
;
...
...
@@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
std
::
back_inserter
(
bparams
),
[](
const
AbstractBasePtr
&
arg_spec
)
->
AbstractBasePtr
{
return
SensitivityTransform
(
arg_spec
);
});
AbstractBasePtr
bparams_final
=
std
::
make_shared
<
AbstractTuple
>
(
bparams
);
AbstractFunctionPtr
bprop
=
std
::
make_shared
<
VirtualAbstractClosure
>
(
SensitivityTransform
(
result
),
bparams_final
);
AbstractFunctionPtr
bprop
=
std
::
make_shared
<
VirtualAbstractClosure
>
(
SensitivityTransform
(
result
->
abstract
()),
bparams_final
);
// J(f)(J(x)) return a tuple (y, bprop_f)
AbstractBasePtrList
jargs
=
{
result
,
bprop
};
AbstractBasePtrList
jargs
=
{
result
->
abstract
()
,
bprop
};
AbstractBasePtr
jtuple
=
std
::
make_shared
<
AbstractTuple
>
(
jargs
);
(
*
cache_
)[
args_spec_list
]
=
jtuple
;
return
jtuple
;
auto
infer_reuslt
=
std
::
make_shared
<
EvalResult
>
(
jtuple
,
std
::
make_shared
<
AttrValueMap
>
());
(
*
cache_
)[
args_spec_list
]
=
infer_reuslt
;
return
infer_reuslt
;
}
AbstractBase
Ptr
VirtualEvaluator
::
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
args_spec_list
)
{
EvalResult
Ptr
VirtualEvaluator
::
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
args_spec_list
)
{
if
(
args_spec_list
.
size
()
!=
args_spec_list_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Arguments mismatch, parameters no: "
<<
args_spec_list_
.
size
()
<<
", arguments no: "
<<
args_spec_list
.
size
();
...
...
@@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL
MS_EXCEPTION_IF_NULL
(
args_spec_list
[
i
]);
(
void
)
args_spec_list
[
i
]
->
Join
(
args_spec_list_
[
i
]);
}
return
output_
;
return
std
::
make_shared
<
EvalResult
>
(
output_
,
std
::
make_shared
<
AttrValueMap
>
())
;
}
}
// namespace abstract
}
// namespace mindspore
mindspore/ccsrc/pipeline/static_analysis/evaluator.h
浏览文件 @
5b9c145f
...
...
@@ -29,21 +29,28 @@
namespace
mindspore
{
namespace
abstract
{
using
EvaluatorCacheMap
=
std
::
unordered_map
<
AbstractBasePtrList
,
AbstractBase
Ptr
,
AbstractBasePtrListHasher
,
AbstractBasePtrListEqual
>
;
std
::
unordered_map
<
AbstractBasePtrList
,
EvalResult
Ptr
,
AbstractBasePtrListHasher
,
AbstractBasePtrListEqual
>
;
using
EvaluatorCacheMapPtr
=
std
::
shared_ptr
<
EvaluatorCacheMap
>
;
using
EvaluatorAttrMap
=
std
::
unordered_map
<
AbstractBasePtrList
,
AttrValueMapPtr
,
AbstractBasePtrListHasher
,
AbstractBasePtrListEqual
>
;
using
EvaluatorAttrMapPtr
=
std
::
shared_ptr
<
EvaluatorAttrMap
>
;
class
Evaluator
:
public
Base
{
public:
explicit
Evaluator
(
const
std
::
string
&
id
)
:
cache_
(
std
::
make_shared
<
EvaluatorCacheMap
>
()),
identifier_
(
id
)
{}
explicit
Evaluator
(
const
std
::
string
&
id
)
:
cache_
(
std
::
make_shared
<
EvaluatorCacheMap
>
()),
attr_cache_
(
std
::
make_shared
<
EvaluatorAttrMap
>
()),
identifier_
(
id
)
{}
~
Evaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
Evaluator
,
Base
);
// difference between Run() and Eval():
// Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
// Run() will modify cache_ member, so it cannot marked as const;
virtual
AbstractBase
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
);
virtual
EvalResult
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
);
virtual
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
=
0
;
virtual
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
=
0
;
virtual
AbstractBasePtrList
NormalizeArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
const
{
return
args_spec_list
;
}
...
...
@@ -58,9 +65,10 @@ class Evaluator : public Base {
virtual
void
set_bound_node
(
const
AnfNodePtr
&
node
)
{
bound_node_
=
AnfNodeWeakPtr
(
node
);
}
EvaluatorCacheMapPtr
&
cache
()
{
return
cache_
;
}
EvaluatorAttrMapPtr
&
attr_cache
()
{
return
attr_cache_
;
}
EvaluatorCacheMapPtr
cache_
;
EvaluatorAttrMapPtr
attr_cache_
;
std
::
string
identifier_
;
AnfNodeWeakPtr
bound_node_
;
...
...
@@ -71,7 +79,7 @@ class PrimEvaluator : public Evaluator {
explicit
PrimEvaluator
(
const
std
::
string
&
id
)
:
Evaluator
(
id
)
{}
~
PrimEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
PrimEvaluator
,
Evaluator
);
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
final
{
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
final
{
MS_LOG
(
EXCEPTION
)
<<
"Eval() should not be called, Run() method should be called"
;
}
};
...
...
@@ -81,8 +89,8 @@ class TrivialPrimEvaluator : public PrimEvaluator {
explicit
TrivialPrimEvaluator
(
const
std
::
string
&
id
)
:
PrimEvaluator
(
id
)
{}
~
TrivialPrimEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
TrivialPrimEvaluator
,
PrimEvaluator
);
AbstractBase
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
final
;
virtual
AbstractBase
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
=
0
;
EvalResult
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
final
;
virtual
EvalResult
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
=
0
;
};
class
TransitionPrimEvaluator
:
public
PrimEvaluator
{
...
...
@@ -90,10 +98,10 @@ class TransitionPrimEvaluator : public PrimEvaluator {
explicit
TransitionPrimEvaluator
(
const
std
::
string
&
id
)
:
PrimEvaluator
(
id
)
{}
~
TransitionPrimEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
TransitionPrimEvaluator
,
PrimEvaluator
);
AbstractBase
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
final
;
EvalResult
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
final
;
// Parameter in_conf0 : the first element in args_conf_list;
virtual
AbstractBase
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
in_conf0
,
const
AnfNodeConfigPtr
&
out_conf
)
=
0
;
virtual
EvalResult
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
in_conf0
,
const
AnfNodeConfigPtr
&
out_conf
)
=
0
;
};
class
SymbolicPrimEvaluator
:
public
PrimEvaluator
{
...
...
@@ -101,8 +109,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
explicit
SymbolicPrimEvaluator
(
const
std
::
string
&
id
)
:
PrimEvaluator
(
id
)
{}
~
SymbolicPrimEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
SymbolicPrimEvaluator
,
PrimEvaluator
);
AbstractBase
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
final
;
virtual
AbstractBase
Ptr
EvalPrim
(
const
ConfigPtrList
&
args_conf_list
)
=
0
;
EvalResult
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
final
;
virtual
EvalResult
Ptr
EvalPrim
(
const
ConfigPtrList
&
args_conf_list
)
=
0
;
};
// Evaluator will be stored in AnalysisEngine.constructors_
...
...
@@ -113,7 +121,7 @@ class DummyEvaluator : public Evaluator {
DummyEvaluator
()
:
Evaluator
(
"dummy"
)
{}
~
DummyEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
DummyEvaluator
,
Evaluator
);
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
return
nullptr
;
}
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
return
nullptr
;
}
};
// Wrap another evaluator to track a subset of uses.
...
...
@@ -139,11 +147,10 @@ class TrackedEvaluator : public Evaluator {
bound_node_
=
AnfNodeWeakPtr
(
node
);
}
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
MS_LOG
(
EXCEPTION
)
<<
"Eval() should not be called, Run() method should be called"
;
}
AbstractBasePtr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
override
;
EvalResultPtr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
override
;
std
::
string
ToString
()
const
override
{
return
identifier_
+
"_"
+
sub_evaluator_
->
ToString
();
}
private:
...
...
@@ -158,7 +165,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
~
BaseFuncGraphEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
BaseFuncGraphEvaluator
,
Evaluator
);
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
override
;
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
override
;
virtual
FuncGraphPtr
GetFuncGraph
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
=
0
;
...
...
@@ -238,12 +245,12 @@ class PartialAppEvaluator : public Evaluator {
}
bound_node_
=
AnfNodeWeakPtr
(
node
);
}
AbstractBasePtr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
EvalResultPtr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
MS_LOG
(
EXCEPTION
)
<<
"Should not be called, Run() method should be called"
;
}
AbstractBasePtr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
override
;
EvalResultPtr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
override
;
std
::
string
ToString
()
const
override
{
return
identifier_
+
"_"
+
evaluator_
->
ToString
();
}
private:
...
...
@@ -258,7 +265,7 @@ class VirtualEvaluator : public Evaluator {
~
VirtualEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
VirtualEvaluator
,
Evaluator
);
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
override
;
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
override
;
std
::
string
ToString
()
const
override
{
return
identifier_
;
}
private:
...
...
@@ -285,11 +292,11 @@ class JEvaluator : public Evaluator {
}
bound_node_
=
AnfNodeWeakPtr
(
node
);
}
AbstractBasePtr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
EvalResultPtr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
MS_LOG
(
EXCEPTION
)
<<
"Should not be called, Run() method should be called"
;
}
AbstractBasePtr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
override
;
EvalResultPtr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
override
;
std
::
string
ToString
()
const
override
{
return
identifier_
+
"_"
+
evaluator_
->
ToString
();
}
private:
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.cc
浏览文件 @
5b9c145f
...
...
@@ -135,13 +135,17 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
using
mindspore
::
parse
::
PyObjectWrapper
;
AbstractBasePtr
StandardPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
{
EvalResultPtr
StandardPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
{
prim_
->
BeginRecordAddAttr
();
AbstractBasePtr
abs_base
=
eval_impl_
(
engine
,
prim_
,
args
);
return
abs_base
;
prim_
->
EndRecordAddAttr
();
auto
added_attrs
=
prim_
->
evaluate_added_attrs
();
auto
infer_result
=
std
::
make_shared
<
EvalResult
>
(
abs_base
,
std
::
make_shared
<
AttrValueMap
>
(
added_attrs
));
return
infer_result
;
}
AbstractBase
Ptr
DoSignatureEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
EvalResult
Ptr
DoSignatureEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
AbstractBasePtrList
args_spec_list
;
if
(
!
prim_
->
isa
<
prim
::
DoSignaturePrimitive
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Primitive should be DoSignature, but "
<<
prim_
->
ToString
();
...
...
@@ -161,7 +165,7 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config
AnfNodePtrList
args_inputs
{
out_node_inputs
.
begin
()
+
1
,
out_node_inputs
.
end
()};
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
ref
)
->
AbstractBasePtr
{
return
ref
->
GetEvaluatedValue
();
});
[](
const
ConfigPtr
&
ref
)
->
AbstractBasePtr
{
return
ref
->
GetEvaluatedValue
()
->
abstract
()
;
});
ScopePtr
scope
=
kDefaultScope
;
if
(
out_conf
!=
nullptr
)
{
...
...
@@ -212,8 +216,8 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
return
graph_specialize_args
;
}
AbstractBase
Ptr
UnpackGraphEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
EvalResult
Ptr
UnpackGraphEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
if
(
out_conf
->
node
()
==
nullptr
||
!
out_conf
->
node
()
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Node of out_conf should be CNode"
;
}
...
...
@@ -232,7 +236,7 @@ AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const Config
AnfNodePtrList
args_inputs
{
out_node_inputs
.
begin
()
+
1
,
out_node_inputs
.
end
()};
AbstractBasePtrList
args_spec_list
;
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
ref
)
->
AbstractBasePtr
{
return
ref
->
GetEvaluatedValue
();
});
[](
const
ConfigPtr
&
ref
)
->
AbstractBasePtr
{
return
ref
->
GetEvaluatedValue
()
->
abstract
()
;
});
// get the forward graph
MS_EXCEPTION_IF_NULL
(
args_spec_list
[
0
]);
AbstractFunctionPtr
fn
=
args_spec_list
[
0
]
->
cast
<
AbstractFunctionPtr
>
();
...
...
@@ -411,7 +415,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
}
// end anonymous namespace
AbstractBase
Ptr
PythonPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
,
const
AbstractBasePtrList
&
args
)
{
EvalResult
Ptr
PythonPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
,
const
AbstractBasePtrList
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"Eval for:"
<<
prim_py_
->
ToString
();
const
auto
&
iter
=
cache_
->
find
(
args
);
...
...
@@ -425,17 +429,20 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
MS_LOG
(
EXCEPTION
)
<<
"["
<<
prim_py_
->
ToString
()
<<
"]: pyobj is empty"
;
}
auto
infer_fuc
=
pyobj
.
attr
(
"__infer__"
);
prim_py_
->
BeginRecordAddAttr
();
py
::
dict
output
=
infer_fuc
(
*
py_args
);
prim_py_
->
EndRecordAddAttr
();
auto
added_attrs
=
prim_py_
->
evaluate_added_attrs
();
MS_LOG
(
DEBUG
)
<<
"Output type is "
<<
(
std
::
string
)
py
::
str
(
output
);
auto
res_spec
=
PyInferRes2Abstract
(
prim_py_
,
output
);
MS_LOG
(
DEBUG
)
<<
"Python InferTensor result spec: "
<<
res_spec
->
ToString
()
<<
"."
;
(
*
cache_
)[
args
]
=
res_spec
;
return
res_spec
;
auto
infer_result
=
std
::
make_shared
<
EvalResult
>
(
res_spec
,
std
::
make_shared
<
AttrValueMap
>
(
added_attrs
));
(
*
cache_
)[
args
]
=
infer_result
;
return
infer_result
;
}
AbstractBase
Ptr
UniformPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
,
const
AbstractBasePtrList
&
args
)
{
EvalResult
Ptr
UniformPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
,
const
AbstractBasePtrList
&
args
)
{
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
if
(
nargs_
!=
args
.
size
())
{
MS_LOG
(
ERROR
)
<<
"UniformPrimEvaluator expect "
<<
nargs_
<<
" args, but got "
<<
args
.
size
()
<<
" inputs"
;
...
...
@@ -476,7 +483,7 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const
}
AbstractScalarPtr
abs_base
=
std
::
make_shared
<
AbstractScalar
>
(
evaluated_value
,
ret_value_type
);
return
abs_base
;
return
std
::
make_shared
<
EvalResult
>
(
abs_base
,
std
::
make_shared
<
AttrValueMap
>
())
;
}
ValuePtr
UniformPrimEvaluator
::
RunImpl
(
const
ValuePtrList
&
args
)
const
{
...
...
@@ -553,8 +560,8 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
manager
->
AddFuncGraph
(
func_graph
);
}
AbstractBase
Ptr
StaticGetterInferred
(
const
ValuePtr
&
value
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
old_conf
)
{
EvalResult
Ptr
StaticGetterInferred
(
const
ValuePtr
&
value
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
old_conf
)
{
MS_EXCEPTION_IF_NULL
(
old_conf
);
AbstractBasePtr
abs_ptr
=
ToAbstract
(
value
,
AnalysisContext
::
DummyContext
(),
old_conf
);
...
...
@@ -585,9 +592,9 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
return
eng
->
ForwardConfig
(
old_conf
,
fn_conf
);
}
AbstractBase
Ptr
GetEvaluatedValueForNameSpaceString
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
AnfNodeConfigPtr
&
out_conf
)
{
EvalResult
Ptr
GetEvaluatedValueForNameSpaceString
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
AnfNodeConfigPtr
&
out_conf
)
{
// args_spec_list: same as StaticGetter
if
(
args_spec_list
.
size
()
<
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"Size of args_spec_list is less than 2"
;
...
...
@@ -627,9 +634,9 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
return
eng
->
ForwardConfig
(
out_conf
,
fn_conf
);
}
AbstractBase
Ptr
GetEvaluatedValueForClassAttrOrMethod
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ValuePtr
&
item_v
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
out_conf
)
{
EvalResult
Ptr
GetEvaluatedValueForClassAttrOrMethod
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ValuePtr
&
item_v
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
out_conf
)
{
if
(
args_spec_list
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"args_spec_list is empty"
;
}
...
...
@@ -646,7 +653,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
AbstractBasePtr
attr
=
cls
->
GetAttribute
(
item_name
);
if
(
attr
!=
nullptr
)
{
return
attr
;
return
std
::
make_shared
<
EvalResult
>
(
attr
,
nullptr
)
;
}
ValuePtr
method
=
cls
->
GetMethod
(
item_name
);
...
...
@@ -660,9 +667,9 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
return
StaticGetterInferred
(
converted_v
,
data_conf
,
out_conf
);
}
AbstractBase
Ptr
GetEvaluatedValueForBuiltinTypeMethod
(
const
AnalysisEnginePtr
&
engine
,
const
ValuePtr
&
item_v
,
const
TypePtr
&
data_type
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
out_conf
)
{
EvalResult
Ptr
GetEvaluatedValueForBuiltinTypeMethod
(
const
AnalysisEnginePtr
&
engine
,
const
ValuePtr
&
item_v
,
const
TypePtr
&
data_type
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
out_conf
)
{
MS_EXCEPTION_IF_NULL
(
item_v
);
MS_EXCEPTION_IF_NULL
(
data_type
);
// The method maybe a Primitive or Composite
...
...
@@ -689,8 +696,8 @@ AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &e
return
StaticGetterInferred
(
converted_v
,
data_conf
,
out_conf
);
}
AbstractBase
Ptr
StaticGetter
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
out_conf
)
{
EvalResult
Ptr
StaticGetter
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
out_conf
)
{
// Inputs: namespace and its static function; or class and its member function
CheckArgsSize
(
"StaticGetter"
,
args_spec_list
,
2
);
...
...
@@ -725,7 +732,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
EmbedEvaluator
()
:
SymbolicPrimEvaluator
(
"EmbedEvaluator"
)
{}
~
EmbedEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
EmbedEvaluator
,
SymbolicPrimEvaluator
);
AbstractBase
Ptr
EvalPrim
(
const
ConfigPtrList
&
args_conf_list
)
override
{
EvalResult
Ptr
EvalPrim
(
const
ConfigPtrList
&
args_conf_list
)
override
{
// arg: free variable to be embedded
if
(
args_conf_list
.
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"EmbedEvaluator requires 1 parameter, but got "
<<
args_conf_list
.
size
();
...
...
@@ -733,11 +740,11 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
AnfNodeConfigPtr
node_conf
=
dyn_cast
<
AnfNodeConfig
>
(
args_conf_list
[
0
]);
MS_EXCEPTION_IF_NULL
(
node_conf
);
AbstractBasePtr
x
=
node_conf
->
GetEvaluatedValue
();
AbstractBasePtr
x
=
node_conf
->
GetEvaluatedValue
()
->
abstract
()
;
x
=
SensitivityTransform
(
x
);
SymbolicKeyInstancePtr
key
=
std
::
make_shared
<
SymbolicKeyInstance
>
(
node_conf
->
node
(),
x
);
AbstractScalarPtr
abs_scalar
=
std
::
make_shared
<
AbstractScalar
>
(
key
,
std
::
make_shared
<
SymbolicKeyType
>
());
return
abs_scalar
;
return
std
::
make_shared
<
EvalResult
>
(
abs_scalar
,
std
::
make_shared
<
AttrValueMap
>
())
;
}
};
...
...
@@ -762,7 +769,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
RefToEmbedEvaluator
()
:
SymbolicPrimEvaluator
(
"RefToEmbedEvaluator"
)
{}
~
RefToEmbedEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
RefToEmbedEvaluator
,
SymbolicPrimEvaluator
);
AbstractBase
Ptr
EvalPrim
(
const
ConfigPtrList
&
args_conf_list
)
override
{
EvalResult
Ptr
EvalPrim
(
const
ConfigPtrList
&
args_conf_list
)
override
{
if
(
args_conf_list
.
size
()
!=
1
)
{
MS_LOG
(
ERROR
)
<<
"Requires 1 parameter, but has: "
<<
args_conf_list
.
size
();
return
nullptr
;
...
...
@@ -773,7 +780,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG
(
ERROR
)
<<
"Conf should be AnfNodeConfig"
;
return
nullptr
;
}
AbstractBasePtr
abs
=
node_conf
->
GetEvaluatedValue
();
AbstractBasePtr
abs
=
node_conf
->
GetEvaluatedValue
()
->
abstract
()
;
AbstractRefPtr
ref_abs
=
abs
->
cast
<
AbstractRefPtr
>
();
if
(
ref_abs
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"The first parameter of RefToEmbed should be Ref."
;
...
...
@@ -791,7 +798,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
}
auto
refkey
=
key_value
->
cast
<
RefKeyPtr
>
();
if
(
refkey
==
nullptr
)
{
return
std
::
make_shared
<
AbstractScalar
>
(
type
);
return
std
::
make_shared
<
EvalResult
>
(
std
::
make_shared
<
AbstractScalar
>
(
type
),
std
::
make_shared
<
AttrValueMap
>
()
);
}
std
::
string
name
=
refkey
->
tag
();
...
...
@@ -805,7 +812,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
x
=
SensitivityTransform
(
x
);
std
::
shared_ptr
<
SymbolicKeyInstance
>
key
=
std
::
make_shared
<
SymbolicKeyInstance
>
(
node
,
x
);
std
::
shared_ptr
<
AbstractScalar
>
abs_scalar
=
std
::
make_shared
<
AbstractScalar
>
(
key
,
type
);
return
abs_scalar
;
return
std
::
make_shared
<
EvalResult
>
(
abs_scalar
,
std
::
make_shared
<
AttrValueMap
>
())
;
}
};
...
...
@@ -814,13 +821,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
GetAttrEvaluator
()
:
TransitionPrimEvaluator
(
"GetAttrEvaluator"
)
{}
~
GetAttrEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
GetAttrEvaluator
,
TransitionPrimEvaluator
);
AbstractBase
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
in_conf0
,
const
AnfNodeConfigPtr
&
out_conf
)
override
{
EvalResult
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
in_conf0
,
const
AnfNodeConfigPtr
&
out_conf
)
override
{
// Inputs: data, item
if
(
args_spec_list
.
size
()
!=
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"Expected args_spec_list size = 2, but has size:"
<<
args_spec_list
.
size
();
}
AbstractBase
Ptr
ret
=
nullptr
;
EvalResult
Ptr
ret
=
nullptr
;
if
(
bound_node
()
!=
nullptr
)
{
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceResolve
>
(
bound_node
()
->
debug_info
()));
ret
=
StaticGetter
(
engine
,
args_spec_list
,
in_conf0
,
out_conf
);
...
...
@@ -840,13 +847,13 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
ResolveEvaluator
()
:
TransitionPrimEvaluator
(
"ResolveEvaluator"
)
{}
~
ResolveEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
ResolveEvaluator
,
TransitionPrimEvaluator
);
AbstractBase
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
in_conf0
,
const
AnfNodeConfigPtr
&
out_conf
)
override
{
EvalResult
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
in_conf0
,
const
AnfNodeConfigPtr
&
out_conf
)
override
{
// Inputs: namespace, symbol
if
(
args_spec_list
.
size
()
!=
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"Expected args_spec_list size = 2, but has size:"
<<
args_spec_list
.
size
();
}
AbstractBase
Ptr
ret
=
nullptr
;
EvalResult
Ptr
ret
=
nullptr
;
if
(
bound_node
()
!=
nullptr
)
{
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceResolve
>
(
bound_node
()
->
debug_info
()));
ret
=
StaticGetter
(
engine
,
args_spec_list
,
in_conf0
,
out_conf
);
...
...
@@ -863,8 +870,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
CreateInstanceEvaluator
()
:
TransitionPrimEvaluator
(
"CreateInstanceEvaluator"
)
{}
~
CreateInstanceEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
CreateInstanceEvaluator
,
TransitionPrimEvaluator
);
AbstractBasePtr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
,
const
AnfNodeConfigPtr
&
out_conf
)
override
{
EvalResultPtr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
ConfigPtr
&
,
const
AnfNodeConfigPtr
&
out_conf
)
override
{
if
(
args_spec_list
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"'args_spec_list' should not be empty"
;
}
...
...
@@ -915,8 +922,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
}
AbstractBasePtr
ret
=
ToAbstract
(
converted_ret
,
AnalysisContext
::
DummyContext
(),
out_conf
);
(
*
cache_
)[
args_spec_list
]
=
ret
;
return
ret
;
auto
infer_result
=
std
::
make_shared
<
EvalResult
>
(
ret
,
nullptr
);
(
*
cache_
)[
args_spec_list
]
=
infer_result
;
return
infer_result
;
}
pybind11
::
tuple
GetParameters
(
const
AbstractBasePtrList
&
args_spec_list
)
const
{
...
...
@@ -942,23 +950,24 @@ class PartialEvaluator : public Evaluator {
public:
PartialEvaluator
()
:
Evaluator
(
"PartialEvaluator"
)
{}
~
PartialEvaluator
()
override
=
default
;
AbstractBase
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
=
nullptr
)
override
{
EvalResult
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
=
nullptr
)
override
{
if
(
args_conf_list
.
size
()
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Args size should be greater than 0"
;
}
MS_EXCEPTION_IF_NULL
(
out_conf
);
MS_EXCEPTION_IF_NULL
(
out_conf
->
node
());
auto
arg0_value
=
args_conf_list
[
0
]
->
GetEvaluatedValue
();
auto
arg0_value
=
args_conf_list
[
0
]
->
GetEvaluatedValue
()
->
abstract
();
AbstractBasePtrList
args_spec_list
{
arg0_value
};
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
if
(
arg0_value
->
isa
<
AbstractError
>
())
{
auto
ret
=
std
::
make_shared
<
AbstractError
>
(
arg0_value
->
GetValueTrack
()
->
cast
<
StringImmPtr
>
(),
out_conf
->
node
());
MS_LOG
(
DEBUG
)
<<
"AbstractError for node: "
<<
out_conf
->
node
()
->
DebugString
()
<<
" as func is: "
<<
arg0_value
->
ToString
();
(
*
cache_
)[
args_spec_list
]
=
ret
;
return
ret
;
auto
eval_result
=
std
::
make_shared
<
EvalResult
>
(
ret
,
std
::
make_shared
<
AttrValueMap
>
());
(
*
cache_
)[
args_spec_list
]
=
eval_result
;
return
eval_result
;
}
auto
func
=
CheckArg
<
AbstractFunction
>
(
"partial"
,
args_spec_list
,
0
);
// Sometimes, node[0] in out_conf becomes phi0;
...
...
@@ -970,8 +979,9 @@ class PartialEvaluator : public Evaluator {
}
}
(
void
)
std
::
transform
(
args_conf_list
.
begin
()
+
1
,
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
config
)
->
AbstractBasePtr
{
return
config
->
GetEvaluatedValue
();
});
(
void
)
std
::
transform
(
args_conf_list
.
begin
()
+
1
,
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
config
)
->
AbstractBasePtr
{
return
config
->
GetEvaluatedValue
()
->
abstract
();
});
AbstractBasePtrList
args
(
args_spec_list
.
begin
()
+
1
,
args_spec_list
.
end
());
auto
cnode
=
out_conf
->
node
()
->
cast
<
CNodePtr
>
();
...
...
@@ -989,16 +999,17 @@ class PartialEvaluator : public Evaluator {
func
->
Visit
(
build_partial
);
auto
ret
=
AbstractFunction
::
MakeAbstractFunction
(
partial_funcs_list
);
(
*
cache_
)[
args_spec_list
]
=
ret
;
return
ret
;
auto
infer_result
=
std
::
make_shared
<
EvalResult
>
(
ret
,
std
::
make_shared
<
AttrValueMap
>
());
(
*
cache_
)[
args_spec_list
]
=
infer_result
;
return
infer_result
;
}
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
MS_LOG
(
EXCEPTION
)
<<
"Eval() should not be called, Run() method should be called"
;
}
AbstractBase
Ptr
HandleDoSignature
(
const
AnalysisEnginePtr
&
engine
,
const
ValuePtr
&
signature_value
,
const
AnfNodeConfigPtr
&
out_conf
=
nullptr
)
const
{
EvalResult
Ptr
HandleDoSignature
(
const
AnalysisEnginePtr
&
engine
,
const
ValuePtr
&
signature_value
,
const
AnfNodeConfigPtr
&
out_conf
=
nullptr
)
const
{
MS_EXCEPTION_IF_NULL
(
out_conf
);
MS_EXCEPTION_IF_NULL
(
out_conf
->
node
());
auto
cnode
=
out_conf
->
node
()
->
cast
<
CNodePtr
>
();
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.h
浏览文件 @
5b9c145f
...
...
@@ -45,7 +45,7 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
:
TrivialPrimEvaluator
(
"StandardPrimEvaluator"
),
prim_
(
primitive
),
eval_impl_
(
eval_impl
)
{}
~
StandardPrimEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
StandardPrimEvaluator
,
TrivialPrimEvaluator
);
AbstractBase
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
override
;
EvalResult
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
override
;
PrimitivePtr
prim
()
{
return
prim_
;
}
std
::
string
ToString
()
const
override
{
return
identifier_
+
prim_
->
name
();
}
...
...
@@ -63,7 +63,7 @@ class PythonPrimEvaluator : public TrivialPrimEvaluator {
:
TrivialPrimEvaluator
(
"PythonPrimEvaluator"
),
prim_py_
(
primitive
)
{}
~
PythonPrimEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
PythonPrimEvaluator
,
TrivialPrimEvaluator
);
AbstractBase
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
override
;
EvalResult
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
override
;
PrimitivePtr
prim
()
{
return
dyn_cast
<
Primitive
>
(
prim_py_
);
}
std
::
string
ToString
()
const
override
{
return
identifier_
+
prim_py_
->
name
();
}
...
...
@@ -76,10 +76,10 @@ class DoSignatureEvaluator : public Evaluator {
public:
explicit
DoSignatureEvaluator
(
const
PrimitivePtr
primitive
)
:
Evaluator
(
"DoSignatureEvaluator"
),
prim_
(
primitive
)
{}
~
DoSignatureEvaluator
()
override
=
default
;
AbstractBase
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
argrefs
,
AnfNodeConfigPtr
out_config
=
nullptr
)
override
;
EvalResult
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
argrefs
,
AnfNodeConfigPtr
out_config
=
nullptr
)
override
;
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
MS_LOG
(
EXCEPTION
)
<<
"Eval() should not be called, Run() method should be called"
;
}
...
...
@@ -91,10 +91,10 @@ class UnpackGraphEvaluator : public Evaluator {
public:
explicit
UnpackGraphEvaluator
(
const
PrimitivePtr
primitive
)
:
Evaluator
(
"UnpackGraphEvaluator"
),
prim_
(
primitive
)
{}
~
UnpackGraphEvaluator
()
override
=
default
;
AbstractBase
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
argrefs
,
AnfNodeConfigPtr
out_config
=
nullptr
)
override
;
EvalResult
Ptr
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
argrefs
,
AnfNodeConfigPtr
out_config
=
nullptr
)
override
;
AbstractBase
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
EvalResult
Ptr
Eval
(
AnalysisEnginePtr
,
const
AbstractBasePtrList
&
)
override
{
MS_LOG
(
EXCEPTION
)
<<
"Eval() should not be called, Run() method should be called"
;
}
...
...
@@ -131,7 +131,7 @@ class UniformPrimEvaluator : public TrivialPrimEvaluator {
~
UniformPrimEvaluator
()
override
=
default
;
MS_DECLARE_PARENT
(
UniformPrimEvaluator
,
TrivialPrimEvaluator
);
AbstractBase
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
override
;
EvalResult
Ptr
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
override
;
ValuePtr
RunImpl
(
const
ValuePtrList
&
args
)
const
;
// If eval_value_ is False, return broadened arguments.
...
...
mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc
浏览文件 @
5b9c145f
...
...
@@ -36,7 +36,7 @@ inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
if
(
conf
->
node
()
->
intermediate_abstract
())
{
return
conf
->
node
()
->
intermediate_abstract
();
}
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
()
->
abstract
()
;
}
AnfNodePtr
BuildValueNode
(
const
ValuePtr
&
v
,
const
AbstractBasePtr
&
abs_base
)
{
...
...
@@ -212,7 +212,7 @@ void FuncGraphSpecializer::FirstPass() {
// Specialize CNode in func graphs
void
FuncGraphSpecializer
::
SecondPass
()
{
for
(
auto
&
node
:
DeepLinkedGraphSearch
(
specialized_func_graph_
->
get_return
()))
{
for
(
auto
&
node
:
BroadFirstSearchGraphCNodes
(
specialized_func_graph_
->
get_return
()))
{
if
(
node
->
isa
<
CNode
>
())
{
ProcessCNode
(
node
->
cast
<
CNodePtr
>
());
}
...
...
@@ -225,7 +225,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
AnfNodeConfigPtr
conf
=
MakeConfig
(
node
);
AnfNodePtr
new_node
=
GetReplicatedNode
(
node
);
MS_EXCEPTION_IF_NULL
(
new_node
);
if
(
new_node
->
func_graph
()
!=
specialized_func_graph_
)
{
MS_LOG
(
EXCEPTION
)
<<
"Error in specializer [A] node: "
<<
node
->
DebugString
()
<<
", new_node: "
<<
new_node
->
DebugString
()
...
...
@@ -244,6 +243,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
MS_LOG
(
DEBUG
)
<<
"Set new_node: "
<<
new_node
->
ToString
()
<<
", abstract as: "
<<
new_node
->
abstract
()
->
ToString
();
if
(
node
->
isa
<
CNode
>
())
{
auto
attrs
=
conf
->
GetEvaluatedValue
()
->
attribute
();
auto
c_old
=
node
->
cast
<
CNodePtr
>
();
auto
c_new
=
new_node
->
cast
<
CNodePtr
>
();
auto
new_inputs
=
c_new
->
inputs
();
...
...
@@ -254,7 +254,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
AbstractBasePtr
ival
=
GetEvaluatedValueWrap
(
iconf
);
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
AnfNodePtr
replace_node
=
BuildPossibleValueNode
(
iconf
->
node
(),
ival
);
AnfNodePtr
replace_node
=
BuildPossibleValueNode
(
iconf
->
node
(),
ival
,
attrs
);
if
(
replace_node
==
nullptr
)
{
replace_node
=
BuildReplacedNode
(
iconf
);
MS_EXCEPTION_IF_NULL
(
replace_node
);
...
...
@@ -424,9 +424,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
MS_LOG
(
EXCEPTION
)
<<
"Size of cnode: "
<<
cnode
->
DebugString
()
<<
" is not equal to 2 added to size of args: "
<<
mindspore
::
ToString
(
partial_closure
->
args
());
}
auto
attrs
=
std
::
make_shared
<
AttrValueMap
>
();
for
(
size_t
i
=
0
;
i
<
partial_closure
->
args
().
size
();
i
++
)
{
auto
old_node
=
cnode
->
input
(
i
+
2
);
auto
possibile_value_node
=
BuildPossibleValueNode
(
old_node
,
partial_closure
->
args
()[
i
]);
auto
possibile_value_node
=
BuildPossibleValueNode
(
old_node
,
partial_closure
->
args
()[
i
]
,
attrs
);
if
(
possibile_value_node
!=
nullptr
)
{
partial_node_list
.
push_back
(
possibile_value_node
);
}
else
{
...
...
@@ -455,7 +456,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
const
EvaluatorPtr
&
eval
)
{
MS_EXCEPTION_IF_NULL
(
eval
);
std
::
unordered_set
<
AbstractBasePtrList
,
AbstractBasePtrListHasher
,
AbstractBasePtrListEqual
>
choices
;
AbstractBase
Ptr
ret
=
nullptr
;
EvalResult
Ptr
ret
=
nullptr
;
AbstractBasePtrList
broaded_argvals
;
for
(
auto
&
argvals_map
:
*
evalcaches_
[
eval
])
{
auto
argvals
=
argvals_map
.
first
;
...
...
@@ -478,7 +479,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
(
*
real
)[
broaded_argvals
]
=
ret
;
evalcaches_
[
eval
]
=
real
;
return
std
::
make_pair
(
broaded_argvals
,
ret
);
return
std
::
make_pair
(
broaded_argvals
,
ret
->
abstract
()
);
}
else
{
MS_LOG
(
DEBUG
)
<<
"Choices.size: "
<<
choices
.
size
();
return
std
::
make_pair
(
AbstractBasePtrList
(),
nullptr
);
...
...
@@ -491,7 +492,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
return
;
}
specializer_
->
AddSeen
(
new_node
);
auto
new_inputs
=
new_node
->
inputs
();
if
(
new_inputs
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Inputs of CNode is empty"
;
...
...
@@ -530,7 +530,13 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
}
if
(
CanSpecializeNode
(
func
))
{
new_inputs
[
0
]
=
BuildSpecializedNode
(
func
,
fnval
,
argvals
);
// for primitive node , we build the primitive node with infered attributes in the first pass
// so we do not build replaced node again here in second pass
if
(
IsValueNode
<
Primitive
>
(
func
))
{
new_inputs
[
0
]
=
func
;
}
else
{
new_inputs
[
0
]
=
BuildSpecializedNode
(
func
,
fnval
,
argvals
);
}
}
for
(
size_t
i
=
0
;
i
<
argvals
.
size
();)
{
...
...
@@ -540,7 +546,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
}
i
=
next
;
}
new_node
->
set_inputs
(
new_inputs
);
}
...
...
@@ -582,7 +587,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
EvaluatorCacheMap
evaluator_cache_map
=
*
eval
->
cache
();
if
(
evaluator_cache_map
.
find
(
argvals
)
!=
evaluator_cache_map
.
end
())
{
*
result
=
std
::
make_pair
(
argvals
,
evaluator_cache_map
[
argvals
]);
*
result
=
std
::
make_pair
(
argvals
,
evaluator_cache_map
[
argvals
]
->
abstract
()
);
return
kSpecializeSuccess
;
}
DumpEvaluatorCache
(
evaluator_cache_map
,
argvals
);
...
...
@@ -591,11 +596,11 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
MS_EXCEPTION_IF_NULL
(
choices
);
if
(
choices
->
count
(
argvals
))
{
*
result
=
std
::
make_pair
(
argvals
,
(
*
choices
)[
argvals
]);
*
result
=
std
::
make_pair
(
argvals
,
(
*
choices
)[
argvals
]
->
abstract
()
);
return
kSpecializeSuccess
;
}
else
if
(
choices
->
size
()
==
1
)
{
MS_LOG
(
DEBUG
)
<<
"Evaluator cache has a single item, just use it."
;
*
result
=
std
::
make_pair
(
choices
->
begin
()
->
first
,
choices
->
begin
()
->
second
);
*
result
=
std
::
make_pair
(
choices
->
begin
()
->
first
,
choices
->
begin
()
->
second
->
abstract
()
);
return
kSpecializeSuccess
;
}
else
if
(
choices
->
empty
())
{
MS_LOG
(
DEBUG
)
<<
"Find DEAD code, it may be optimized in later phase."
;
...
...
@@ -614,8 +619,43 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
return
kSpecializeFindUniqueArgvalPoly
;
}
}
static
PrimitivePtr
BuildPrimtiveValueWithAttributes
(
const
PrimitivePtr
&
prim
,
const
AttrValueMapPtr
&
attrs
)
{
auto
&
prim_attrs
=
prim
->
attrs
();
bool
is_attr_same
=
true
;
for
(
auto
&
item
:
*
attrs
)
{
auto
itr
=
prim_attrs
.
find
(
item
.
first
);
if
(
itr
!=
prim_attrs
.
end
())
{
if
(
!
(
*
(
itr
->
second
)
==
*
(
item
.
second
)))
{
is_attr_same
=
false
;
break
;
}
}
else
{
is_attr_same
=
false
;
break
;
}
}
if
(
!
is_attr_same
)
{
if
(
prim
->
isa
<
PrimitivePy
>
())
{
PrimitivePyPtr
prim_py
=
prim
->
cast
<
PrimitivePyPtr
>
();
auto
clone_fn
=
prim_py
->
GetPyObj
().
attr
(
"_clone"
);
py
::
object
new_obj
=
clone_fn
();
auto
cloned_prim
=
new_obj
.
cast
<
PrimitivePyPtr
>
();
for
(
auto
&
item
:
*
attrs
)
{
cloned_prim
->
AddAttr
(
item
.
first
,
item
.
second
);
}
return
cloned_prim
;
}
auto
cloned_prim
=
std
::
make_shared
<
Primitive
>
(
*
prim
);
for
(
auto
&
item
:
*
attrs
)
{
cloned_prim
->
AddAttr
(
item
.
first
,
item
.
second
);
}
return
cloned_prim
;
}
return
prim
;
}
AnfNodePtr
FuncGraphSpecializer
::
BuildPossibleValueNode
(
const
AnfNodePtr
&
origin_node
,
const
AbstractBasePtr
&
ival
)
{
AnfNodePtr
FuncGraphSpecializer
::
BuildPossibleValueNode
(
const
AnfNodePtr
&
origin_node
,
const
AbstractBasePtr
&
ival
,
const
AttrValueMapPtr
&
attrs
)
{
MS_EXCEPTION_IF_NULL
(
origin_node
);
MS_EXCEPTION_IF_NULL
(
ival
);
...
...
@@ -628,7 +668,12 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
ValuePtr
value
=
nullptr
;
if
(
abs
->
isa
<
PrimitiveAbstractClosure
>
())
{
auto
real_fn
=
dyn_cast
<
PrimitiveAbstractClosure
>
(
abs
);
value
=
real_fn
->
prim
();
// for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one
if
(
attrs
!=
nullptr
)
{
value
=
BuildPrimtiveValueWithAttributes
(
real_fn
->
prim
(),
attrs
);
}
else
{
value
=
real_fn
->
prim
();
}
}
else
if
(
abs
->
isa
<
MetaFuncGraphAbstractClosure
>
())
{
auto
real_fn
=
dyn_cast
<
MetaFuncGraphAbstractClosure
>
(
abs
);
value
=
real_fn
->
meta_func_graph
();
...
...
mindspore/ccsrc/pipeline/static_analysis/program_specialize.h
浏览文件 @
5b9c145f
...
...
@@ -110,7 +110,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
AnfNodePtr
BuildSpecializedParameterNode
(
const
CNodePtr
&
new_node
);
// Build a value node if ival is constant and not any-value
AnfNodePtr
BuildPossibleValueNode
(
const
AnfNodePtr
&
origin_node
,
const
AbstractBasePtr
&
ival
);
AnfNodePtr
BuildPossibleValueNode
(
const
AnfNodePtr
&
origin_node
,
const
AbstractBasePtr
&
ival
,
const
AttrValueMapPtr
&
attrs
);
// Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a
// replicated node.
AnfNodePtr
BuildReplacedNode
(
const
AnfNodeConfigPtr
&
conf
);
...
...
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
浏览文件 @
5b9c145f
...
...
@@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
return
nullptr
;
}
void
AnalysisCache
::
set_value
(
const
AnfNodeConfigPtr
&
conf
,
const
AbstractBasePtr
&
arg
)
{
void
AnalysisCache
::
set_value
(
const
AnfNodeConfigPtr
&
conf
,
const
EvalResultPtr
&
result
)
{
MS_LOG
(
DEBUG
)
<<
"AnalysisCache set for NodeConfig: "
<<
conf
->
node
()
->
DebugString
()
<<
", Context: "
<<
conf
->
context
()
->
ToString
()
<<
", Value: "
<<
arg
->
ToString
()
<<
", Pointer: "
<<
arg
.
get
();
cache_
[
conf
]
=
arg
;
<<
", Context: "
<<
conf
->
context
()
->
ToString
()
<<
", Value: "
<<
result
->
abstract
()
->
ToString
()
<<
", Pointer: "
<<
result
->
abstract
()
.
get
();
cache_
[
conf
]
=
result
;
// Set intermediate abstract value.
if
(
IsIntermediateAbstract
(
arg
))
{
if
(
IsIntermediateAbstract
(
result
->
abstract
()
))
{
if
(
conf
->
node
()
->
intermediate_abstract
()
==
nullptr
)
{
conf
->
node
()
->
set_intermediate_abstract
(
arg
);
MS_LOG
(
DEBUG
)
<<
"Set intermediate abstract: "
<<
arg
->
ToString
();
conf
->
node
()
->
set_intermediate_abstract
(
result
->
abstract
()
);
MS_LOG
(
DEBUG
)
<<
"Set intermediate abstract: "
<<
result
->
abstract
()
->
ToString
();
}
else
{
auto
old_spec
=
conf
->
node
()
->
intermediate_abstract
();
auto
joined_spec
=
IntermediateJoin
(
arg
,
old_spec
);
auto
joined_spec
=
IntermediateJoin
(
result
->
abstract
()
,
old_spec
);
conf
->
node
()
->
set_intermediate_abstract
(
joined_spec
);
MS_LOG
(
DEBUG
)
<<
"Set joined intermediate abstract:
\n
old_spec:
\t\t
"
<<
old_spec
->
ToString
()
<<
"
\n
new_spec:
\t\t
"
<<
arg
->
ToString
()
<<
"
\n
joined_spec:
\t
"
<<
result
->
abstract
()
->
ToString
()
<<
"
\n
joined_spec:
\t
"
<<
(
joined_spec
!=
nullptr
?
joined_spec
->
ToString
()
:
"nullptr"
);
}
}
}
AbstractBase
Ptr
AnalysisCache
::
GetValue
(
const
AnfNodeConfigPtr
&
conf
)
{
EvalResult
Ptr
AnalysisCache
::
GetValue
(
const
AnfNodeConfigPtr
&
conf
)
{
auto
value
=
cache_
.
find
(
conf
);
if
(
value
==
cache_
.
end
())
{
return
nullptr
;
...
...
@@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
return
eval
->
graph_context
();
}
AbstractBase
Ptr
AnalysisEngine
::
GetEvaluatedValue
(
const
AnfNodeConfigPtr
&
conf
)
{
EvalResult
Ptr
AnalysisEngine
::
GetEvaluatedValue
(
const
AnfNodeConfigPtr
&
conf
)
{
MS_EXCEPTION_IF_NULL
(
conf
);
auto
value
=
cache_
.
GetValue
(
conf
);
if
(
value
!=
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"Evaluate cache hit for NodeConfig: "
<<
conf
->
ToString
()
<<
", Value: "
<<
value
.
get
()
<<
", "
<<
value
->
ToString
();
MS_LOG
(
DEBUG
)
<<
"Evaluate cache hit for NodeConfig: "
<<
conf
->
ToString
()
<<
", Value: "
<<
value
->
abstract
().
get
()
<<
", "
<<
value
->
abstract
()
->
ToString
();
return
value
;
}
...
...
@@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf)
return
value
;
}
AbstractBase
Ptr
AnalysisEngine
::
Eval
(
const
AnfNodeConfigPtr
&
conf
)
{
EvalResult
Ptr
AnalysisEngine
::
Eval
(
const
AnfNodeConfigPtr
&
conf
)
{
MS_EXCEPTION_IF_NULL
(
conf
);
AnfNodePtr
node
=
conf
->
node
();
AbstractBasePtr
ret_abstrac
t
=
nullptr
;
EvalResultPtr
eval_resul
t
=
nullptr
;
#ifdef DEBUG
compute_conf_stack_
.
push_back
(
node
);
std
::
ostringstream
buffer
;
...
...
@@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL
(
node
);
if
(
node
->
abstract
()
!=
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"Return old abstract: "
<<
node
->
DebugString
();
ret_abstract
=
node
->
abstract
(
);
eval_result
=
std
::
make_shared
<
EvalResult
>
(
node
->
abstract
(),
std
::
make_shared
<
AttrValueMap
>
()
);
}
else
if
(
node
->
isa
<
ValueNode
>
())
{
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
ret_abstract
=
EvalValueNode
(
value_node
,
conf
);
eval_result
=
std
::
make_shared
<
EvalResult
>
(
EvalValueNode
(
value_node
,
conf
),
nullptr
);
}
else
if
(
node
->
isa
<
CNode
>
())
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
trace
::
TraceEvalCNodeEnter
(
conf
);
ret_abstrac
t
=
EvalCNode
(
cnode
,
conf
);
eval_resul
t
=
EvalCNode
(
cnode
,
conf
);
trace
::
TraceEvalCNodeLeave
();
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Illegal AnfNode for evaluating, "
<<
node
->
DebugString
()
...
...
@@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
#ifdef DEBUG
compute_conf_stack_
.
pop_back
();
if
(
ret_abstrac
t
==
nullptr
)
{
if
(
eval_resul
t
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Compute Config failed, node: "
<<
node
->
DebugString
()
<<
" NodeInfo: "
<<
trace
::
GetDebugInfo
(
node
->
debug_info
());
}
#endif
MS_LOG
(
DEBUG
)
<<
"End Eval NodeConfig "
<<
conf
->
ToString
()
<<
", res: "
<<
ret_abstract
->
ToString
();
return
ret_abstrac
t
;
MS_LOG
(
DEBUG
)
<<
"End Eval NodeConfig "
<<
conf
->
ToString
()
<<
", res: "
<<
eval_result
->
abstract
()
->
ToString
();
return
eval_resul
t
;
}
AbstractBasePtr
AnalysisEngine
::
EvalValueNode
(
const
ValueNodePtr
&
value_node
,
const
AnfNodeConfigPtr
&
conf
)
{
...
...
@@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
return
ToAbstract
(
value_node
->
value
(),
conf
->
context
(),
conf
);
}
AbstractBase
Ptr
AnalysisEngine
::
EvalCNode
(
const
CNodePtr
&
cnode
,
const
AnfNodeConfigPtr
&
conf
)
{
EvalResult
Ptr
AnalysisEngine
::
EvalCNode
(
const
CNodePtr
&
cnode
,
const
AnfNodeConfigPtr
&
conf
)
{
MS_EXCEPTION_IF_NULL
(
conf
);
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
&
inputs
=
cnode
->
inputs
();
...
...
@@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
AnfNodeConfigPtr
func_conf
=
MakeConfig
(
func_node
,
context
);
MS_EXCEPTION_IF_NULL
(
func_conf
);
// Keep it in a local variable, otherwise smart pointer will free it.
AbstractBasePtr
maybe_func
=
func_conf
->
GetEvaluatedValue
();
AbstractBasePtr
maybe_func
=
func_conf
->
GetEvaluatedValue
()
->
abstract
()
;
if
(
maybe_func
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"func_conf.GetEvaluatedValue() return null, func_conf: "
<<
func_conf
->
ToString
()
<<
" NodeInfo: "
<<
trace
::
GetDebugInfo
(
cnode
->
debug_info
());
...
...
@@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
return
ExecuteEvaluators
(
infs
,
conf
,
args_conf_list
);
}
AbstractBase
Ptr
AnalysisEngine
::
Execute
(
const
AbstractFunctionPtr
&
func
,
const
AbstractBasePtrList
&
args_spec_list
)
{
EvalResult
Ptr
AnalysisEngine
::
Execute
(
const
AbstractFunctionPtr
&
func
,
const
AbstractBasePtrList
&
args_spec_list
)
{
ConfigPtrList
args_conf_list
;
(
void
)
std
::
transform
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
std
::
back_inserter
(
args_conf_list
),
[](
const
AbstractBasePtr
&
arg
)
->
ConfigPtr
{
return
std
::
make_shared
<
VirtualConfig
>
(
arg
);
});
...
...
@@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
return
tracked_eval
;
}
AbstractBasePtr
AnalysisEngine
::
ExecuteEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
)
{
EvalResultPtr
AnalysisEngine
::
ExecuteEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
)
{
if
(
evaluators
.
size
()
==
1
)
{
EvaluatorPtr
eval
=
evaluators
[
0
];
MS_EXCEPTION_IF_NULL
(
eval
);
...
...
@@ -465,9 +464,9 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr
return
ExecuteMultipleEvaluators
(
evaluators
,
out_conf
,
args_conf_list
);
}
AbstractBase
Ptr
AnalysisEngine
::
ExecuteMultipleEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
)
{
EvalResult
Ptr
AnalysisEngine
::
ExecuteMultipleEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
)
{
AbstractBasePtrList
out_specs
;
if
(
!
multi_poss_
.
count
(
evaluators
[
0
]))
{
multi_poss_
[
evaluators
[
0
]]
=
evaluators
[
1
];
...
...
@@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
MS_EXCEPTION_IF_NULL
(
conf
);
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
()
->
abstract
()
;
});
for
(
auto
eval
:
evaluators
)
{
auto
fg_eval
=
eval
->
cast
<
FuncGraphEvaluatorPtr
>
();
...
...
@@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
eval_trace_
.
push_back
(
current_inf
);
MS_LOG
(
DEBUG
)
<<
"Trace Evaluator "
<<
eval
->
ToString
()
<<
" ptr: "
<<
eval
.
get
();
MS_EXCEPTION_IF_NULL
(
eval
);
auto
out_spec
=
eval
->
Run
(
shared_from_this
(),
args_conf_list
,
out_conf
);
MS_EXCEPTION_IF_NULL
(
out_spec
);
MS_LOG
(
DEBUG
)
<<
"Evaluator "
<<
eval
->
ToString
()
<<
" return out_spec: "
<<
out_spec
->
ToString
();
out_specs
.
push_back
(
out_spec
);
MS_LOG
(
DEBUG
)
<<
"Pop Evaluator "
<<
eval
->
ToString
();
auto
eval_result
=
eval
->
Run
(
shared_from_this
(),
args_conf_list
,
out_conf
);
MS_EXCEPTION_IF_NULL
(
eval_result
->
abstract
());
MS_LOG
(
DEBUG
)
<<
"Evaluator "
<<
eval
->
ToString
()
<<
" return out_spec: "
<<
eval_result
->
abstract
()
->
ToString
();
out_specs
.
push_back
(
eval_result
->
abstract
());
eval_trace_
.
pop_back
();
if
(
eval_trace_
.
empty
())
{
multi_poss_
.
clear
();
...
...
@@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
// Try to travel the latest undetermined.
if
(
latest_entry
!=
eval_trace_
.
rbegin
()
->
first
)
{
MS_LOG
(
DEBUG
)
<<
"Direct Run Evaluator "
<<
eval
->
ToString
();
auto
out_spec
=
latest_entry
->
Run
(
shared_from_this
(),
args_conf_list
,
out_conf
);
MS_EXCEPTION_IF_NULL
(
out_spec
);
MS_LOG
(
DEBUG
)
<<
"Evaluator "
<<
latest_entry
->
ToString
()
<<
" return out_spec: "
<<
out_spec
->
ToString
();
return
out_spec
;
auto
eval_result
=
latest_entry
->
Run
(
shared_from_this
(),
args_conf_list
,
out_conf
);
MS_EXCEPTION_IF_NULL
(
eval_result
->
abstract
());
MS_LOG
(
DEBUG
)
<<
"Evaluator "
<<
latest_entry
->
ToString
()
<<
" return out_spec: "
<<
eval_result
->
abstract
()
->
ToString
();
return
eval_result
;
}
}
}
...
...
@@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
if
(
out_specs
.
size
()
==
1
)
{
MS_EXCEPTION_IF_NULL
(
out_specs
[
0
]);
// If only one result derived, then broaden it to avoid wrong constant propagation.
return
out_specs
[
0
]
->
Broaden
(
);
return
std
::
make_shared
<
EvalResult
>
(
out_specs
[
0
]
->
Broaden
(),
std
::
make_shared
<
AttrValueMap
>
()
);
}
auto
joined_spec
=
AbstractJoin
(
out_specs
);
MS_EXCEPTION_IF_NULL
(
joined_spec
);
MS_LOG
(
DEBUG
)
<<
"Multiple evaluators joined: "
<<
joined_spec
->
ToString
();
return
joined_spec
;
return
std
::
make_shared
<
EvalResult
>
(
joined_spec
,
std
::
make_shared
<
AttrValueMap
>
())
;
}
AbstractBase
Ptr
AnfNodeConfig
::
GetEvaluatedValue
()
{
EvalResult
Ptr
AnfNodeConfig
::
GetEvaluatedValue
()
{
AnfNodeConfigPtr
self
=
shared_from_base
<
AnfNodeConfig
>
();
return
engine_
.
lock
()
->
GetEvaluatedValue
(
self
);
}
...
...
@@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
return
a
;
}
AbstractBase
Ptr
EvalOnePrim
(
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
arg_specs
)
{
EvalResult
Ptr
EvalOnePrim
(
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
arg_specs
)
{
auto
evaluator
=
GetPrimEvaluator
(
primitive
,
nullptr
);
MS_EXCEPTION_IF_NULL
(
evaluator
);
if
(
!
evaluator
->
isa
<
TrivialPrimEvaluator
>
())
{
...
...
@@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr
<<
evaluator
->
ToString
();
}
auto
trivial_evaluator
=
dyn_cast
<
TrivialPrimEvaluator
>
(
evaluator
);
auto
res_spec
=
trivial_evaluator
->
EvalPrim
(
nullptr
,
arg_specs
);
return
res_spec
;
auto
eval_result
=
trivial_evaluator
->
EvalPrim
(
nullptr
,
arg_specs
);
return
eval_result
;
}
}
// namespace abstract
}
// namespace mindspore
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
浏览文件 @
5b9c145f
...
...
@@ -40,13 +40,33 @@
namespace
mindspore
{
namespace
abstract
{
// define attribute value map
using
AttrValueMap
=
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
;
using
AttrValueMapPtr
=
std
::
shared_ptr
<
AttrValueMap
>
;
// the class to save evaluated result: abstract value and modified attribute
class
EvalResult
:
public
Base
{
public:
EvalResult
(
AbstractBasePtr
abs
,
AttrValueMapPtr
attr
)
:
abstract_
(
abs
),
attribute_
(
attr
)
{}
~
EvalResult
()
override
=
default
;
MS_DECLARE_PARENT
(
EvalResult
,
Base
);
AbstractBasePtr
abstract
()
{
return
abstract_
;
}
AttrValueMapPtr
attribute
()
{
return
attribute_
;
}
private:
AbstractBasePtr
abstract_
;
AttrValueMapPtr
attribute_
;
};
using
EvalResultPtr
=
std
::
shared_ptr
<
EvalResult
>
;
// Superclass for AnfNodeConfig and VirtualConfig.
class
Config
:
public
Base
{
public:
Config
()
=
default
;
~
Config
()
override
=
default
;
MS_DECLARE_PARENT
(
Config
,
Base
);
virtual
AbstractBase
Ptr
GetEvaluatedValue
()
=
0
;
virtual
EvalResult
Ptr
GetEvaluatedValue
()
=
0
;
};
// Config will be stored in AnalysisCache
...
...
@@ -74,7 +94,7 @@ class AnfNodeConfig : public Config {
~
AnfNodeConfig
()
override
=
default
;
MS_DECLARE_PARENT
(
AnfNodeConfig
,
Config
);
AbstractBase
Ptr
GetEvaluatedValue
()
override
;
EvalResult
Ptr
GetEvaluatedValue
()
override
;
AnalysisContextPtr
context
()
const
{
return
context_
;
}
...
...
@@ -123,7 +143,9 @@ class VirtualConfig : public Config {
~
VirtualConfig
()
override
=
default
;
MS_DECLARE_PARENT
(
VirtualConfig
,
Config
);
AbstractBasePtr
GetEvaluatedValue
()
override
{
return
abstract_
;
}
EvalResultPtr
GetEvaluatedValue
()
override
{
return
std
::
make_shared
<
EvalResult
>
(
abstract_
,
std
::
make_shared
<
AttrValueMap
>
());
}
private:
AbstractBasePtr
abstract_
;
...
...
@@ -135,11 +157,11 @@ class AnalysisCache {
AnalysisCache
()
=
default
;
~
AnalysisCache
()
=
default
;
void
Clear
()
{
cache_
.
clear
();
}
void
set_value
(
const
AnfNodeConfigPtr
&
conf
,
const
AbstractBase
Ptr
&
arg
);
AbstractBase
Ptr
GetValue
(
const
AnfNodeConfigPtr
&
conf
);
void
set_value
(
const
AnfNodeConfigPtr
&
conf
,
const
EvalResult
Ptr
&
arg
);
EvalResult
Ptr
GetValue
(
const
AnfNodeConfigPtr
&
conf
);
private:
std
::
unordered_map
<
AnfNodeConfigPtr
,
AbstractBase
Ptr
,
AnfNodeConfigHasher
,
AnfNodeConfigEqual
>
cache_
;
std
::
unordered_map
<
AnfNodeConfigPtr
,
EvalResult
Ptr
,
AnfNodeConfigHasher
,
AnfNodeConfigEqual
>
cache_
;
};
using
PrimEvaluatorMap
=
std
::
unordered_map
<
PrimitivePtr
,
EvaluatorPtr
,
PrimitiveHasher
,
PrimitiveEqual
>
;
...
...
@@ -147,7 +169,7 @@ using AnfNodeConfigMap =
std
::
unordered_map
<
AnfNodeConfigPtr
,
AnfNodeConfigPtr
,
AnfNodeConfigHasher
,
AnfNodeConfigEqual
>
;
struct
AnalysisResult
{
AbstractBase
Ptr
inferred
;
EvalResult
Ptr
inferred
;
AnalysisContextPtr
context
;
};
...
...
@@ -160,14 +182,14 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// func_graph: The func_graph to analyze.
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
AnalysisResult
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBase
Ptr
GetEvaluatedValue
(
const
AnfNodeConfigPtr
&
conf
);
EvalResult
Ptr
GetEvaluatedValue
(
const
AnfNodeConfigPtr
&
conf
);
// Return the Evaluator for the given function.
EvaluatorPtr
GetEvaluatorFor
(
const
AbstractFunctionPtr
&
fn
);
AbstractBasePtr
EvalValueNode
(
const
ValueNodePtr
&
value_node
,
const
AnfNodeConfigPtr
&
conf
);
AbstractBase
Ptr
EvalCNode
(
const
CNodePtr
&
cnode
,
const
AnfNodeConfigPtr
&
conf
);
EvalResult
Ptr
EvalCNode
(
const
CNodePtr
&
cnode
,
const
AnfNodeConfigPtr
&
conf
);
// Infer the result of fn(args).
AbstractBase
Ptr
Execute
(
const
AbstractFunctionPtr
&
fn
,
const
AbstractBasePtrList
&
args_spec_list
);
EvalResult
Ptr
Execute
(
const
AbstractFunctionPtr
&
fn
,
const
AbstractBasePtrList
&
args_spec_list
);
void
Clear
();
void
ClearEvaluatorCache
();
AnalysisCache
&
cache
()
{
return
cache_
;
}
...
...
@@ -188,7 +210,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// Set the analysis result for orig to the result for new.
// This sets an entry in anfnode_config_map from orig to new.
AbstractBase
Ptr
ForwardConfig
(
const
AnfNodeConfigPtr
&
orig_conf
,
const
AnfNodeConfigPtr
new_conf
)
{
EvalResult
Ptr
ForwardConfig
(
const
AnfNodeConfigPtr
&
orig_conf
,
const
AnfNodeConfigPtr
new_conf
)
{
// Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
(
void
)
anfnode_config_map_
.
emplace
(
orig_conf
,
new_conf
);
MS_LOG
(
DEBUG
)
<<
"Forward orig_conf: "
<<
orig_conf
->
node
()
->
DebugString
()
...
...
@@ -211,12 +233,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
AnalysisContextPtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnalysisContextPtr
&
context
,
const
ConfigPtrList
&
args_conf_list
);
AbstractBase
Ptr
Eval
(
const
AnfNodeConfigPtr
&
conf
);
EvalResult
Ptr
Eval
(
const
AnfNodeConfigPtr
&
conf
);
EvaluatorPtr
_GetEvaluatorFor
(
const
AbstractFunctionPtr
&
fn
);
AbstractBase
Ptr
ExecuteEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
);
AbstractBasePtr
ExecuteMultipleEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
);
EvalResult
Ptr
ExecuteEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
);
EvalResultPtr
ExecuteMultipleEvaluators
(
const
std
::
vector
<
EvaluatorPtr
>
&
evaluators
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
);
#ifdef DEBUG
std
::
vector
<
AnfNodePtr
>
compute_conf_stack_
;
...
...
@@ -244,7 +266,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
return
FromValueInside
(
MakeValue
(
value
),
broaden
);
}
AbstractBase
Ptr
EvalOnePrim
(
const
PrimitivePtr
&
p
,
const
AbstractBasePtrList
&
arg_specs
);
EvalResult
Ptr
EvalOnePrim
(
const
PrimitivePtr
&
p
,
const
AbstractBasePtrList
&
arg_specs
);
}
// namespace abstract
}
// namespace mindspore
...
...
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
5b9c145f
...
...
@@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI
args_spec_list
.
emplace_back
(
abstract
::
FromValueInside
(
input_value
,
false
));
}
}
AbstractBasePtr
infer_res
=
EvalOnePrim
(
prim
,
args_spec_list
);
AbstractBasePtr
infer_res
=
EvalOnePrim
(
prim
,
args_spec_list
)
->
abstract
()
;
op_exec_info
->
abstract
=
infer_res
;
}
...
...
mindspore/ccsrc/utils/graph_utils.cc
浏览文件 @
5b9c145f
...
...
@@ -26,6 +26,8 @@
#include <list>
#include <string>
#include <fstream>
#include <queue>
#include <set>
#include "ir/visitor.h"
#include "utils/log_adapter.h"
...
...
@@ -223,6 +225,31 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
return
res
;
}
// search the cnodes inside this graph only
std
::
vector
<
CNodePtr
>
BroadFirstSearchGraphCNodes
(
CNodePtr
ret
)
{
std
::
queue
<
CNodePtr
>
todo
;
todo
.
push
(
ret
);
std
::
vector
<
CNodePtr
>
sorted_nodes
;
auto
seen
=
NewSeenGeneration
();
while
(
!
todo
.
empty
())
{
CNodePtr
top
=
todo
.
front
();
todo
.
pop
();
sorted_nodes
.
push_back
(
top
);
auto
inputs
=
top
->
inputs
();
for
(
auto
&
item
:
inputs
)
{
if
(
item
->
seen_
==
seen
)
{
continue
;
}
if
(
item
->
isa
<
CNode
>
())
{
todo
.
push
(
item
->
cast
<
CNodePtr
>
());
}
item
->
seen_
=
seen
;
}
}
return
sorted_nodes
;
}
std
::
vector
<
AnfNodePtr
>
SuccDeeper
(
const
AnfNodePtr
&
node
)
{
std
::
vector
<
AnfNodePtr
>
vecs
;
if
(
node
==
nullptr
)
{
...
...
mindspore/ccsrc/utils/graph_utils.h
浏览文件 @
5b9c145f
...
...
@@ -57,6 +57,7 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl
std
::
vector
<
AnfNodePtr
>
TopoSort
(
const
AnfNodePtr
&
root
,
const
SuccFunc
&
succ
=
SuccIncoming
,
const
IncludeFunc
&
include
=
AlwaysInclude
);
std
::
vector
<
CNodePtr
>
BroadFirstSearchGraphCNodes
(
CNodePtr
ret
);
class
FuncGraphIndex
{
public:
explicit
FuncGraphIndex
(
const
FuncGraphPtr
&
fg
,
const
SearchFunc
&
search
=
DeepScopedGraphSearch
,
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
5b9c145f
...
...
@@ -71,7 +71,6 @@ class ExpandDims(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
):
"""init ExpandDims"""
self
.
__setattr_flag__
=
True
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'axis'
],
outputs
=
[
'output'
])
def
__infer__
(
self
,
x
,
axis
):
...
...
@@ -182,7 +181,6 @@ class Cast(PrimitiveWithInfer):
# if primitive need setattr in __infer__ need add this flag
"""init Cast"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'dst_type'
],
outputs
=
[
'output'
])
self
.
__setattr_flag__
=
True
def
__infer__
(
self
,
x
,
t
):
src_type
=
x
[
'dtype'
]
...
...
@@ -308,7 +306,6 @@ class Reshape(PrimitiveWithInfer):
def
__init__
(
self
):
"""init Reshape"""
self
.
init_prim_io_names
(
inputs
=
[
'tensor'
,
'shape'
],
outputs
=
[
'output'
])
self
.
__setattr_flag__
=
True
def
__infer__
(
self
,
x
,
shape
):
shape_v
=
shape
[
'value'
]
...
...
@@ -453,7 +450,6 @@ class Transpose(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
):
"""init Transpose"""
self
.
__setattr_flag__
=
True
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'perm'
],
outputs
=
[
'output'
])
def
__infer__
(
self
,
x
,
perm
):
...
...
@@ -508,7 +504,6 @@ class GatherV2(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
):
"""init index_select"""
self
.
__setattr_flag__
=
True
self
.
init_prim_io_names
(
inputs
=
[
'params'
,
'indices'
,
'axis'
],
outputs
=
[
'output'
])
def
__infer__
(
self
,
params
,
indices
,
axis
):
...
...
@@ -1402,7 +1397,6 @@ class Concat(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
axis
=
0
):
"""init Tile"""
self
.
__setattr_flag__
=
True
validator
.
check_value_type
(
"axis"
,
axis
,
[
int
],
self
.
name
)
def
__infer__
(
self
,
input_x
):
...
...
@@ -1476,7 +1470,6 @@ class Pack(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
axis
=
0
):
"""init Pack"""
self
.
__setattr_flag__
=
True
validator
.
check_value_type
(
"axis"
,
axis
,
[
int
],
self
.
name
)
self
.
axis
=
axis
...
...
@@ -1526,7 +1519,6 @@ class Unpack(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
axis
=
0
):
"""init Unpack"""
self
.
__setattr_flag__
=
True
validator
.
check_value_type
(
"axis"
,
axis
,
[
int
],
self
.
name
)
self
.
axis
=
axis
...
...
@@ -1656,7 +1648,6 @@ class Select(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
):
"""init"""
self
.
__setattr_flag__
=
True
def
infer_shape
(
self
,
cond_shape
,
x_shape
,
y_shape
):
if
cond_shape
!=
x_shape
or
x_shape
!=
y_shape
:
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
5b9c145f
...
...
@@ -516,7 +516,6 @@ class MatMul(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
transpose_a
=
False
,
transpose_b
=
False
):
self
.
init_prim_io_names
(
inputs
=
[
'x1'
,
'x2'
],
outputs
=
[
'output'
])
self
.
__setattr_flag__
=
True
cls_name
=
self
.
name
validator
.
check_value_type
(
"transpose_a"
,
transpose_a
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_b"
,
transpose_b
,
[
bool
],
cls_name
)
...
...
@@ -596,7 +595,6 @@ class BatchMatMul(MatMul):
@
prim_attr_register
def
__init__
(
self
,
transpose_a
=
False
,
transpose_b
=
False
):
self
.
init_prim_io_names
(
inputs
=
[
'x1'
,
'x2'
],
outputs
=
[
'output'
])
self
.
__setattr_flag__
=
True
cls_name
=
self
.
name
validator
.
check_value_type
(
"transpose_a"
,
transpose_a
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_b"
,
transpose_b
,
[
bool
],
cls_name
)
...
...
@@ -682,7 +680,6 @@ class AddN(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
):
self
.
__setattr_flag__
=
True
self
.
init_prim_io_names
(
inputs
=
[
"inputs"
],
outputs
=
[
"sum"
])
def
infer_shape
(
self
,
inputs
):
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
5b9c145f
...
...
@@ -730,8 +730,8 @@ class Conv2D(PrimitiveWithInfer):
"""init Conv2D"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'w'
],
outputs
=
[
'output'
])
self
.
kernel_size
=
_check_positive_int_or_tuple
(
'kernel_size'
,
kernel_size
,
self
.
name
)
self
.
stride
=
_check_positive_int_or_tuple
(
'stride'
,
stride
,
self
.
name
)
self
.
add_prim_attr
(
'stride'
,
(
1
,
1
,
self
.
stride
[
0
],
self
.
stride
[
1
])
)
self
.
stride
=
_check_positive_int_or_tuple
(
'stride'
,
stride
,
self
.
name
,
allow_four
=
True
,
ret_four
=
True
)
self
.
add_prim_attr
(
'stride'
,
self
.
stride
)
self
.
dilation
=
_check_positive_int_or_tuple
(
'dilation'
,
dilation
,
self
.
name
,
allow_four
=
True
,
ret_four
=
True
)
self
.
add_prim_attr
(
'dilation'
,
self
.
dilation
)
validator
.
check_value_type
(
'pad'
,
pad
,
(
int
,),
self
.
name
)
...
...
@@ -787,7 +787,6 @@ class Conv2D(PrimitiveWithInfer):
self
.
pad_list
=
[
pad_top
,
pad_bottom
,
pad_left
,
pad_right
]
self
.
add_prim_attr
(
'pad_list'
,
(
pad_top
,
pad_bottom
,
pad_left
,
pad_right
))
out_channel
=
self
.
out_channel
out_shape
=
[
x_shape
[
0
],
out_channel
,
h_out
,
w_out
]
return
out_shape
...
...
tests/st/ops/ascend/test_ops_infer.py
0 → 100644
浏览文件 @
5b9c145f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import
functools
import
numpy
as
np
import
mindspore.nn
as
nn
import
mindspore.context
as
context
import
mindspore.common.dtype
as
mstype
from
mindspore
import
Tensor
,
Parameter
from
mindspore.common.initializer
import
initializer
from
mindspore.ops
import
Primitive
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
prim_attr_register
,
PrimitiveWithInfer
from
mindspore.ops.primitive
import
constexpr
from
mindspore
import
context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
def
test_cast_op_attr
():
class
CastNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
CastNet
,
self
).
__init__
()
self
.
cast
=
P
.
Cast
()
def
construct
(
self
,
x
,
t
):
return
self
.
cast
(
x
,
t
)
class
CastTypeTest
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
CastTypeTest
,
self
).
__init__
()
self
.
net
=
net
self
.
cast
=
P
.
Cast
()
def
construct
(
self
,
x
,
y
,
z
):
cast_op
=
self
.
cast
t1
=
cast_op
(
x
,
mstype
.
float32
)
t2
=
cast_op
(
y
,
mstype
.
int32
)
cast_net
=
self
.
net
t3
=
cast_net
(
x
,
mstype
.
float16
)
t4
=
cast_net
(
y
,
mstype
.
int32
)
t5
=
cast_net
(
z
,
mstype
.
float16
)
return
(
t1
,
t2
,
t3
,
t4
,
t5
)
net
=
CastTypeTest
(
CastNet
())
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
int32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
t3
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
int32
))
out
=
net
(
t1
,
t2
,
t3
)
assert
out
[
0
].
asnumpy
().
dtype
==
np
.
float32
assert
out
[
1
].
asnumpy
().
dtype
==
np
.
int32
assert
out
[
2
].
asnumpy
().
dtype
==
np
.
float16
assert
out
[
3
].
asnumpy
().
dtype
==
np
.
int32
assert
out
[
4
].
asnumpy
().
dtype
==
np
.
float16
tests/ut/cpp/operator/composite_test.cc
浏览文件 @
5b9c145f
...
...
@@ -153,7 +153,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) {
auto
slice
=
std
::
make_shared
<
AbstractSlice
>
(
start_index
,
stop_index
,
step
);
AbstractBasePtrList
args_spec_list
=
{
tuple_tensor
,
slice
};
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
tupleSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
tupleSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract tuple failed."
;
}
...
...
@@ -179,7 +179,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
auto
slice
=
std
::
make_shared
<
AbstractSlice
>
(
start_index
,
stop_index
,
step
);
AbstractBasePtrList
args_spec_list
=
{
tuple_tensor
,
slice
};
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
tupleSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
tupleSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract tuple failed."
;
}
...
...
@@ -205,7 +205,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
auto
slice
=
std
::
make_shared
<
AbstractSlice
>
(
start_index
,
stop_index
,
step
);
AbstractBasePtrList
args_spec_list
=
{
tuple_tensor
,
slice
};
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
tupleSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
tupleSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract tuple failed."
;
}
...
...
@@ -231,7 +231,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
auto
slice
=
std
::
make_shared
<
AbstractSlice
>
(
start_index
,
stop_index
,
step
);
AbstractBasePtrList
args_spec_list
=
{
tuple_tensor
,
slice
};
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
tupleSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
tupleSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract tuple failed."
;
}
...
...
@@ -253,7 +253,7 @@ TEST_F(TestComposite, test_TensorSliceBySlice) {
AbstractSlicePtr
slice
=
std
::
make_shared
<
AbstractSlice
>
(
start_index
,
stop_index
,
step
);
AbstractBasePtrList
args_spec_list
=
{
tensor
,
slice
};
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSlicePtrGraphPtr
,
args_spec_list
).
inferred
);
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSlicePtrGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract array failed."
;
}
...
...
@@ -288,7 +288,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
AbstractTuplePtr
slice_tuple
=
std
::
make_shared
<
AbstractTuple
>
(
eles
);
AbstractBasePtrList
args_spec_list
=
{
tensor
,
slice_tuple
};
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract array failed."
;
}
...
...
@@ -320,7 +320,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
AbstractTuplePtr
slice_tuple
=
std
::
make_shared
<
AbstractTuple
>
(
eles
);
AbstractBasePtrList
args_spec_list
=
{
tensor
,
slice_tuple
};
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract array failed."
;
}
...
...
@@ -336,7 +336,7 @@ TEST_F(TestComposite, test_TensorSliceByScalar) {
AbstractScalarPtr
start_index
=
std
::
make_shared
<
AbstractScalar
>
(
2
);
AbstractBasePtrList
args_spec_list
=
{
tensor
,
start_index
};
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract array failed."
;
}
...
...
@@ -358,7 +358,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
AbstractTuplePtr
slice_tuple
=
std
::
make_shared
<
AbstractTuple
>
(
eles
);
AbstractBasePtrList
args_spec_list
=
{
tensor
,
slice_tuple
};
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract array failed."
;
}
...
...
@@ -382,7 +382,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
AbstractTuplePtr
slice_tuple
=
std
::
make_shared
<
AbstractTuple
>
(
eles
);
AbstractBasePtrList
args_spec_list
=
{
tensor
,
slice_tuple
};
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
);
AbstractTensorPtr
ret
=
dyn_cast
<
AbstractTensor
>
(
engine_
->
Run
(
tensorSliceGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract array failed."
;
}
...
...
@@ -408,7 +408,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
abstract
::
AbstractDictionaryPtr
tensor_dict
=
std
::
make_shared
<
abstract
::
AbstractDictionary
>
(
tensor_map
);
AbstractBasePtrList
args_spec_list
=
{
fn_arg
,
tensor_tuple
,
tensor_dict
};
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
unPackCallGraphPtr
,
args_spec_list
).
inferred
);
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
unPackCallGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract tuple failed."
;
}
...
...
@@ -435,7 +435,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) {
abstract
::
AbstractDictionaryPtr
tensor_dict
=
std
::
make_shared
<
abstract
::
AbstractDictionary
>
(
tensor_map
);
AbstractBasePtrList
args_spec_list
=
{
fn_arg
,
tensor_dict
,
tensor_tuple
,
tensor_dict
,
tensor_tuple
};
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
unPackCallGraphPtr
,
args_spec_list
).
inferred
);
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
unPackCallGraphPtr
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract tuple failed."
;
}
...
...
@@ -457,7 +457,7 @@ TEST_F(TestComposite, test_ZipOperation) {
auto
tuple
=
std
::
make_shared
<
AbstractTuple
>
(
eles
);
AbstractBasePtrList
args_spec_list
=
{
tuple
};
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
zip_op_graph
,
args_spec_list
).
inferred
);
AbstractTuplePtr
ret
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
zip_op_graph
,
args_spec_list
).
inferred
->
abstract
()
);
if
(
ret
==
nullptr
)
{
FAIL
()
<<
"Cast ret to abstract tuple failed."
;
}
...
...
tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc
浏览文件 @
5b9c145f
...
...
@@ -41,11 +41,11 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
AbstractBasePtr
abstract_v2
=
FromValue
(
2
,
false
);
AbstractBasePtrList
args_spec_list
=
{
abstract_v1
,
abstract_v2
};
AbstractBasePtr
abstract_val
=
FromValue
(
10
,
false
);
cache
[
args_spec_list
]
=
abstract_val
;
cache
[
args_spec_list
]
=
std
::
make_shared
<
EvalResult
>
(
abstract_val
,
std
::
make_shared
<
AttrValueMap
>
())
;
auto
iter
=
cache
.
find
(
args_spec_list
);
ASSERT_TRUE
(
iter
!=
cache
.
end
());
ASSERT_TRUE
(
iter
->
second
==
abstract_val
);
ASSERT_TRUE
(
iter
->
second
->
abstract
()
==
abstract_val
);
AbstractBasePtr
abstract_v1_variant1
=
FromValue
(
1
,
false
);
AbstractBasePtr
abstract_v2_variant1
=
FromValue
(
2
,
false
);
...
...
@@ -53,7 +53,7 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
iter
=
cache
.
find
(
args_spec_list_variant1
);
ASSERT_TRUE
(
iter
!=
cache
.
end
());
ASSERT_TRUE
(
iter
->
second
==
abstract_val
);
ASSERT_TRUE
(
iter
->
second
->
abstract
()
==
abstract_val
);
AbstractBasePtr
abstract_v1_variant2
=
FromValue
(
1
,
false
);
AbstractBasePtr
abstract_v2_variant2
=
FromValue
(
3
,
false
);
...
...
@@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) {
std::vector<int> shape = {2, 2, 6, 6};
expected->set_shape(std::make_shared<Shape>(shape));
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred
->abstract()
;
MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected: " << expected->ToString();
...
...
@@ -144,7 +144,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) {
AbstractBasePtr abstract_x = FromValue(x, false);
args_spec_list.push_back(abstract_x);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred
->abstract()
;
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
}
...
...
@@ -160,7 +160,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) {
AbstractBasePtr abstract_x = FromValue(x, false);
args_spec_list.push_back(abstract_x);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred
->abstract()
;
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
}
...
...
@@ -179,7 +179,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred
->abstract()
;
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
...
...
@@ -198,7 +198,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred
->abstract()
;
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
...
...
@@ -217,7 +217,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred
->abstract()
;
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
...
...
@@ -237,7 +237,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred
->abstract()
;
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
...
...
tests/ut/cpp/pipeline/static_analysis/prim_test.cc
浏览文件 @
5b9c145f
此差异已折叠。
点击以展开。
tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc
浏览文件 @
5b9c145f
...
...
@@ -163,7 +163,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {
auto
prim_scalar_add
=
std
::
make_shared
<
Primitive
>
(
"scalar_add"
);
FuncGraphPtr
func_graph
=
MakeFuncGraph
(
prim_scalar_add
);
AbstractBasePtr
abs_base_got
=
engine_
->
Run
(
func_graph
,
args_spec_list
).
inferred
;
AbstractBasePtr
abs_base_got
=
engine_
->
Run
(
func_graph
,
args_spec_list
).
inferred
->
abstract
()
;
ASSERT_TRUE
(
abs_base_got
.
get
()
==
abstract_v1
.
get
());
}
...
...
@@ -261,7 +261,7 @@ TEST_F(TestInferGraph, test_inferred) {
MS_LOG
(
INFO
)
<<
""
<<
graph_f_
->
get_return
()
->
ToString
();
AbstractBasePtr
abstract_v1
=
FromValue
(
1
,
false
);
args_spec_list
.
push_back
(
abstract_v1
);
AbstractBasePtr
abs_base_got
=
engine_
->
Run
(
graph_f_
,
args_spec_list
).
inferred
;
AbstractBasePtr
abs_base_got
=
engine_
->
Run
(
graph_f_
,
args_spec_list
).
inferred
->
abstract
()
;
ASSERT_TRUE
(
abs_base_got
.
get
()
==
abstract_v1
.
get
());
// now this test case failed randomly, have to debug.
...
...
@@ -272,7 +272,7 @@ TEST_F(TestInferGraph, test_inferred) {
args_spec_list
.
clear
();
args_spec_list
.
push_back
(
abstract_v1
);
args_spec_list
.
push_back
(
abstract_v2
);
abs_base_got
=
engine_
->
Run
(
graph_alpha_
,
args_spec_list
).
inferred
;
abs_base_got
=
engine_
->
Run
(
graph_alpha_
,
args_spec_list
).
inferred
->
abstract
()
;
ASSERT_TRUE
(
abs_base_got
.
get
()
==
abstract_v1
.
get
());
}
...
...
@@ -358,7 +358,7 @@ TEST_F(TestInferMetaGraph, test_inferred) {
AbstractBasePtr
abstract_v2
=
FromValue
(
v1
,
false
);
args_spec_list
.
push_back
(
abstract_v1
);
args_spec_list
.
push_back
(
abstract_v2
);
AbstractBasePtr
abs_base_got
=
engine_
->
Run
(
func_graph_
,
args_spec_list
).
inferred
;
AbstractBasePtr
abs_base_got
=
engine_
->
Run
(
func_graph_
,
args_spec_list
).
inferred
->
abstract
()
;
ASSERT_TRUE
(
abs_base_got
.
get
()
==
abstract_v1
.
get
());
}
...
...
@@ -390,7 +390,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
auto
prim_scalar_add
=
std
::
make_shared
<
Primitive
>
(
"scalar_add"
);
FuncGraphPtr
func_graph
=
MakeFuncGraph
(
prim_scalar_add
);
AbstractBasePtr
abs_base_got
=
engine_
->
Run
(
func_graph
,
args_spec
).
inferred
;
AbstractBasePtr
abs_base_got
=
engine_
->
Run
(
func_graph
,
args_spec
).
inferred
->
abstract
()
;
ASSERT_TRUE
(
*
(
abs_base_got
->
GetTypeTrack
())
==
*
(
abstract_v1
->
GetTypeTrack
()));
ASSERT_TRUE
(
abs_base_got
->
GetTypeTrack
()
->
type_id
()
==
kNumberTypeInt32
);
}
...
...
@@ -418,7 +418,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) {
AbstractBasePtr
base1
=
FromValue
(
x1
,
false
);
AbstractBasePtr
base2
=
FromValue
(
x2
,
false
);
AbstractBasePtrList
base_list
=
{
base1
,
base2
};
auto
res
=
EvalOnePrim
(
std
::
make_shared
<
Primitive
>
(
"scalar_add"
),
base_list
);
auto
res
=
EvalOnePrim
(
std
::
make_shared
<
Primitive
>
(
"scalar_add"
),
base_list
)
->
abstract
()
;
MS_LOG
(
INFO
)
<<
"result spec: "
<<
res
->
ToString
();
AbstractBasePtr
exp
=
FromValue
(
x3
,
false
);
MS_LOG
(
INFO
)
<<
"result exp: "
<<
exp
->
ToString
();
...
...
@@ -446,7 +446,7 @@ void TestGraphEval::TearDown() {
TEST_F(TestGraphInfer, test_graph_infer_defaults) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred
->abstract()
;
AbstractBasePtr expect = FromValue(MakeValue(50), false);
ASSERT_EQ(*res, *expect);
}
...
...
@@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred
->abstract()
;
AbstractBasePtr expect = FromValue(MakeValue(1), false);
ASSERT_EQ(*res, *expect);
}
...
...
@@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
TEST_F(TestGraphInfer, test_graph_infer_vararg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred
->abstract()
;
AbstractBasePtr expect = FromValue(MakeValue(9), false);
ASSERT_EQ(*res, *expect);
}
...
...
@@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred
->abstract()
;
AbstractBasePtr expect = FromValue(MakeValue(48), false);
ASSERT_EQ(*res, *expect);
}
...
...
@@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred
->abstract()
;
AbstractBasePtr expect = FromValue(MakeValue(7), false);
ASSERT_EQ(*res, *expect);
}
...
...
@@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred
->abstract()
;
AbstractBasePtr expect = FromValue(MakeValue(46), false);
ASSERT_EQ(*res, *expect);
}
...
...
@@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred
->abstract()
;
AbstractBasePtr expect = FromValue(MakeValue(57), false);
ASSERT_EQ(*res, *expect);
}
...
...
tests/ut/python/ops/test_nn_ops.py
浏览文件 @
5b9c145f
...
...
@@ -31,7 +31,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from
....mindspore_test_framework.pipeline.forward.verify_exception
\
import
pipeline_for_verify_exception_for_case_by_case_config
from
mindspore
import
context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
def
conv3x3
(
in_channels
,
out_channels
,
stride
=
1
,
padding
=
1
):
"""3x3 convolution """
...
...
@@ -377,6 +378,21 @@ class StateNet(nn.Cell):
return
x
def
test_conv2d_same_primitive
():
class
Conv2DSameNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Conv2DSameNet
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
16
,
64
,
(
1
,
41
),
(
1
,
4
),
"same"
,
0
,
1
,
has_bias
=
True
)
self
.
conv2
=
nn
.
Conv2d
(
16
,
64
,
(
1
,
41
),
(
1
,
4
),
"same"
,
0
,
1
,
has_bias
=
True
)
def
construct
(
self
,
x
,
y
):
r1
=
self
.
conv1
(
x
)
r2
=
self
.
conv2
(
y
)
return
(
r1
,
r2
)
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
net
=
Conv2DSameNet
()
out
=
net
(
t1
,
t2
)
class
ComparisonNet
(
nn
.
Cell
):
def
__init__
(
self
):
""" ComparisonNet definition """
...
...
tests/ut/python/ops/test_ops_attr_infer.py
0 → 100644
浏览文件 @
5b9c145f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import
functools
import
numpy
as
np
import
mindspore
import
mindspore.nn
as
nn
import
mindspore.context
as
context
import
mindspore.common.dtype
as
mstype
from
mindspore
import
Tensor
,
Parameter
from
mindspore.common.initializer
import
initializer
from
mindspore.ops
import
Primitive
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
prim_attr_register
,
PrimitiveWithInfer
from
mindspore.ops.primitive
import
constexpr
from
..ut_filter
import
non_graph_engine
from
....mindspore_test_framework.mindspore_test
import
mindspore_test
from
....mindspore_test_framework.pipeline.forward.compile_forward
\
import
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from
....mindspore_test_framework.pipeline.forward.verify_exception
\
import
pipeline_for_verify_exception_for_case_by_case_config
from
mindspore
import
context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
class
FakeOp
(
PrimitiveWithInfer
):
@
prim_attr_register
def
__init__
(
self
):
""""""
def
infer_shape
(
self
,
x
,
y
):
self
.
second_shape
=
y
self
.
add_prim_attr
(
"second_shape"
,
y
)
return
x
def
infer_dtype
(
self
,
x
,
y
):
return
x
# test the normal case that should generate independent primitive because of different
# generated attributes after inference
def
test_conv2d_same_primitive
():
class
Conv2DSameNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Conv2DSameNet
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
16
,
64
,
(
1
,
41
),
(
1
,
4
),
"same"
,
0
,
1
,
has_bias
=
True
)
self
.
conv2
=
nn
.
Conv2d
(
16
,
64
,
(
1
,
41
),
(
1
,
4
),
"same"
,
0
,
1
,
has_bias
=
True
)
def
construct
(
self
,
x
,
y
):
r1
=
self
.
conv1
(
x
)
r2
=
self
.
conv2
(
y
)
return
(
r1
,
r2
)
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
net
=
Conv2DSameNet
()
out
=
net
(
t1
,
t2
)
# test cell as high order argument
# The graph with free variables used as argument is not supported yet
# because of the limit of inference specialize system
def
Xtest_conv2d_op_with_arg
():
class
Conv2dNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Conv2dNet
,
self
).
__init__
()
def
construct
(
self
,
op
,
x
):
return
op
(
x
)
class
OpsNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
OpsNet
,
self
).
__init__
()
self
.
opnet
=
net
self
.
conv2
=
nn
.
Conv2d
(
16
,
64
,
(
1
,
41
),
(
1
,
4
),
"same"
,
0
,
1
,
has_bias
=
True
)
def
construct
(
self
,
x
,
y
):
conv_op
=
self
.
conv2
a
=
self
.
opnet
(
conv_op
,
x
)
b
=
self
.
opnet
(
conv_op
,
y
)
return
(
a
,
b
)
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
net
=
OpsNet
(
Conv2dNet
())
out
=
net
(
t1
,
t2
)
def
test_conv2d_op_with_arg
():
class
FackOpNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
FackOpNet
,
self
).
__init__
()
self
.
op
=
FakeOp
()
def
construct
(
self
,
x
,
y
):
return
self
.
op
(
x
,
y
)
class
OpNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
OpNet
,
self
).
__init__
()
def
construct
(
self
,
op
,
x
,
y
):
return
op
(
x
,
y
)
class
OpsNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
OpsNet
,
self
).
__init__
()
self
.
opnet
=
net
self
.
op
=
FackOpNet
()
def
construct
(
self
,
x
,
y
):
op
=
self
.
op
a
=
self
.
opnet
(
op
,
x
,
y
)
b
=
self
.
opnet
(
op
,
y
,
x
)
return
(
a
,
b
)
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
net
=
OpsNet
(
OpNet
())
out
=
net
(
t1
,
t2
)
def
test_conv2d_op_with_arg_same_input
():
class
FackOpNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
FackOpNet
,
self
).
__init__
()
self
.
op
=
FakeOp
()
def
construct
(
self
,
x
,
y
):
return
self
.
op
(
x
,
y
)
class
OpNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
OpNet
,
self
).
__init__
()
def
construct
(
self
,
op
,
x
,
y
):
return
op
(
x
,
y
)
class
OpsNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
OpsNet
,
self
).
__init__
()
self
.
opnet
=
net
self
.
op
=
FackOpNet
()
def
construct
(
self
,
x
,
y
):
op
=
self
.
op
a
=
self
.
opnet
(
op
,
x
,
x
)
b
=
self
.
opnet
(
op
,
y
,
x
)
return
(
a
,
b
)
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
net
=
OpsNet
(
OpNet
())
out
=
net
(
t1
,
t2
)
# test op with partial
def
test_op_as_partial
():
class
OpAsPartial
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
OpAsPartial
,
self
).
__init__
()
self
.
op
=
FakeOp
()
def
construct
(
self
,
x
,
y
,
z
):
partial_op
=
F
.
partial
(
self
.
op
,
x
)
a
=
partial_op
(
y
)
b
=
partial_op
(
z
)
return
a
,
b
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
t3
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1234
]).
astype
(
np
.
float32
))
net
=
OpAsPartial
()
out
=
net
(
t1
,
t2
,
t3
)
# test op with partial
def
test_op_as_partial_inside
():
class
OpAsPartial
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
OpAsPartial
,
self
).
__init__
()
self
.
op
=
FakeOp
()
def
construct
(
self
,
x
,
y
,
z
):
partial_op
=
F
.
partial
(
self
.
op
,
x
)
a
=
partial_op
(
y
)
b
=
partial_op
(
z
)
return
a
,
b
class
OuterNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
OuterNet
,
self
).
__init__
()
self
.
net
=
OpAsPartial
()
def
construct
(
self
,
x
,
y
,
z
):
a
,
b
=
self
.
net
(
x
,
y
,
z
)
return
a
,
b
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
t3
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1234
]).
astype
(
np
.
float32
))
net
=
OuterNet
()
out
=
net
(
t1
,
t2
,
t3
)
# test op with partial case 2
def
test_op_as_partial_independent
():
class
OpAsPartial
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
OpAsPartial
,
self
).
__init__
()
self
.
op
=
FakeOp
()
def
construct
(
self
,
x
,
y
,
z
):
partial_op1
=
F
.
partial
(
self
.
op
,
x
)
a
=
partial_op1
(
y
)
partial_op2
=
F
.
partial
(
self
.
op
,
x
)
b
=
partial_op2
(
z
)
return
a
,
b
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
t3
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1234
]).
astype
(
np
.
float32
))
net
=
OpAsPartial
()
out
=
net
(
t1
,
t2
,
t3
)
def
test_nest_partial
():
class
NestPartial
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NestPartial
,
self
).
__init__
()
self
.
op
=
FakeOp
()
def
construct
(
self
,
x
,
y
,
z
):
partial_op1
=
F
.
partial
(
self
.
op
)
partial_op2
=
F
.
partial
(
partial_op1
,
x
)
a
=
partial_op2
(
y
)
partial_op3
=
F
.
partial
(
self
.
op
)
partial_op4
=
F
.
partial
(
partial_op3
,
x
)
b
=
partial_op4
(
z
)
return
a
,
b
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
t3
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1234
]).
astype
(
np
.
float32
))
net
=
NestPartial
()
out
=
net
(
t1
,
t2
,
t3
)
# high order argument
# op and op args as network arguments
def
test_op_with_arg_as_input
():
class
WithOpArgNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
WithOpArgNet
,
self
).
__init__
()
def
construct
(
self
,
op
,
x
,
y
):
return
op
(
x
,
y
)
class
OpsNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
OpsNet
,
self
).
__init__
()
self
.
opnet
=
net
self
.
op
=
FakeOp
()
def
construct
(
self
,
x
,
y
,
z
):
op
=
self
.
op
a
=
self
.
opnet
(
op
,
x
,
z
)
b
=
self
.
opnet
(
op
,
x
,
y
)
return
(
a
,
b
)
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
t3
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1234
]).
astype
(
np
.
float32
))
net
=
OpsNet
(
WithOpArgNet
())
out
=
net
(
t1
,
t2
,
t3
)
# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
def
Xtest_partial_as_arg
():
class
PartialArgNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
PartialArgNet
,
self
).
__init__
()
def
construct
(
self
,
partial_op
,
y
):
return
partial_op
(
y
)
class
OpsNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
OpsNet
,
self
).
__init__
()
self
.
partial_net
=
net
self
.
op
=
FakeOp
()
def
construct
(
self
,
x
,
y
,
z
):
partial_op
=
F
.
partial
(
self
.
op
,
x
)
a
=
self
.
partial_net
(
partial_op
,
z
)
b
=
self
.
partial_net
(
partial_op
,
y
)
return
(
a
,
b
)
t1
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1918
]).
astype
(
np
.
float32
))
t2
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
3840
]).
astype
(
np
.
float32
))
t3
=
Tensor
(
np
.
ones
([
1
,
16
,
1
,
1234
]).
astype
(
np
.
float32
))
net
=
OpsNet
(
PartialArgNet
())
out
=
net
(
t1
,
t2
,
t3
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录