magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit ea3ddea3
Authored on Aug 08, 2020 by Wei Luning

    remove ref origin

Parent: e7df5416

Showing 15 changed files with 182 additions and 98 deletions (+182 / -98)
mindspore/ccsrc/frontend/operator/composite/composite.cc       +34 -34
mindspore/ccsrc/frontend/operator/composite/do_signature.cc    +20 -6
mindspore/ccsrc/frontend/operator/composite/unpack_call.cc     +4  -3
mindspore/ccsrc/frontend/operator/prim_others.cc               +9  -17
mindspore/ccsrc/frontend/optimizer/irpass.cc                   +3  -3
mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h      +0  -4
mindspore/ccsrc/pipeline/jit/parse/parse.cc                    +11 -0
mindspore/ccsrc/pipeline/jit/parse/parse.h                     +1  -0
mindspore/ccsrc/pipeline/jit/parse/resolve.cc                  +4  -2
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc           +0  -1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc          +3  -2
mindspore/core/abstract/abstract_value.cc                      +59 -9
mindspore/core/abstract/abstract_value.h                       +33 -13
mindspore/core/abstract/analysis_context.cc                    +1  -3
mindspore/core/base/core_ops.h                                 +0  -1
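Taken together, the diff removes the third "origin value" slot from MindSpore's Ref machinery: the get_ref_origin primitive and InferImplGetRefOrigin disappear, AbstractRef loses its ref_origin_ member, and mixed-precision casting is instead recorded on the Ref as a (need_cast, target_type) pair and applied lazily at read sites. A minimal sketch of the constructor-level change, assuming the abstract types from this tree (the helper name MakeCastingRef is invented for illustration, not part of the commit):

// Before: AbstractRef(ref_key, ref_value, ref_origin) carried a pre-cast value.
// After:  AbstractRef(ref_key, ref_value, need_cast = false, cast_target = nullptr)
//         merely records the requested dtype; readers insert the cast on demand.
abstract::AbstractBasePtr MakeCastingRef(const abstract::AbstractBasePtr &key,
                                         const abstract::AbstractBasePtr &value) {
  // Request float16 reads; the constructor keeps need_cast_ false unless the
  // referenced value is a float tensor whose dtype actually differs from kFloat16.
  return std::make_shared<abstract::AbstractRef>(key, value, /*need_cast=*/true, kFloat16);
}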
mindspore/ccsrc/frontend/operator/composite/composite.cc

@@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
 }
 
 FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
-  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
-  ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
-  ptrGraph->debug_info()->set_name("hyper_map");
+  FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
+  ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
+  ptr_graph->debug_info()->set_name("hyper_map");
 
   AnfNodePtr ptrFnArg = nullptr;
   std::size_t i = 0;
   ArgsPairList argmap;
   ArgsPairList argmap2;
   if (fn_leaf_ == nullptr) {
-    ptrFnArg = ptrGraph->add_parameter();
+    ptrFnArg = ptr_graph->add_parameter();
     i = 1;
   }
 
   std::size_t size = args_spec_list.size();
   for (; i < size; ++i) {
-    argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
+    argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
   }
 
-  argmap2 = Harmonize(ptrGraph, argmap);
-  ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2));
-  return ptrGraph;
+  argmap2 = Harmonize(ptr_graph, argmap);
+  ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
+  return ptr_graph;
 }
 
 abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {

@@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
   inputs.push_back(opsTupleItem);
   inputs.push_back(cnode);
   inputs.push_back(NewValueNode(1));
-  AnfNodePtr ptrBprop = ret->NewCNode(inputs);
+  AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
 
-  doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
+  doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
   return ret;
 }
 
-void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
+void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
                               ValueNodePtr opsTupleItem) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  AnfNodePtr ptrBPropArg = nullptr;
+  AnfNodePtr ptr_bprop_arg = nullptr;
   if (sens_param_) {
-    ptrBPropArg = func_graph->add_parameter();
+    ptr_bprop_arg = func_graph->add_parameter();
   } else {
     auto ones_like = prim::GetPythonOps("ones_like");
-    ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out});
+    ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
   }
 
-  AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg});
+  AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
 
   CNodePtr fv_bprop = nullptr;
   if (get_by_list_) {
     // python code: grads = hyper_map(F.partial(env_get, env), weights)
-    AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)});
+    AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)});
     AnfNodePtr partial_env_get =
       func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
     MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();

@@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
   CNodePtr inputs_bprop = nullptr;
   if (get_all_) {
-    inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp});
+    inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp});
   }
 
   // Gradients wrt inputs and parameters

@@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
   }
 
   // Gradients wrt first input.
-  // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
-  func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)}));
+  // ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
+  func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)}));
 }
 
 // Generate the graph.

@@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
   auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
   MS_EXCEPTION_IF_NULL(real_fn);
 
-  FuncGraphPtr ptrGraph = real_fn->func_graph();
-  MS_EXCEPTION_IF_NULL(ptrGraph);
-  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
-  FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>();
+  FuncGraphPtr ptr_graph = real_fn->func_graph();
+  MS_EXCEPTION_IF_NULL(ptr_graph);
+  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
+  FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
   TraceManager::EndTrace();
 
-  auto nparam = ptrGraph->parameters().size();
+  auto nparam = ptr_graph->parameters().size();
   std::ostringstream ss;
   ss << "grad{" << nparam << "}";
-  dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  dfBuilder->debug_info()->set_name(ss.str());
-  ParameterPtr param_graph = dfBuilder->add_parameter();
+  df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  df_builder->debug_info()->set_name(ss.str());
+  ParameterPtr param_graph = df_builder->add_parameter();
 
   AnfNodePtr weights = nullptr;
   if (get_by_list_) {
-    weights = dfBuilder->add_parameter();
+    weights = df_builder->add_parameter();
   }
 
   std::vector<AnfNodePtr> inputs;
   inputs.push_back(NewValueNode(prim::kPrimJ));
   inputs.push_back(param_graph);
-  auto jf = dfBuilder->NewCNode(inputs);
+  auto jf = df_builder->NewCNode(inputs);
   // df is checked in GetGrad
-  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
-  auto df = GetGrad(jf, weights, ptrGraph->parameters());
+  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
+  auto df = GetGrad(jf, weights, ptr_graph->parameters());
   TraceManager::EndTrace();
-  dfBuilder->set_output(NewValueNode(df));
-  return dfBuilder;
+  df_builder->set_output(NewValueNode(df));
+  return df_builder;
 }
 
 REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
mindspore/ccsrc/frontend/operator/composite/do_signature.cc

@@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
 bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
                                TypeId *arg_type = nullptr) {
   if (arg_value->isa<abstract::AbstractRef>()) {
-    if (is_write) {
-      arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
-    } else {
-      arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
-    }
+    auto ref = arg_value->cast<abstract::AbstractRefPtr>();
+    arg_value = ref->ref();
+    if (!is_write && ref->need_cast()) {
+      auto tensor_type = ref->target_type();
+      *arg_type_id = tensor_type->type_id();
+      if (arg_type != nullptr) {
+        *arg_type = kObjectTypeTensorType;
+      }
+      return true;
+    }
   }
 
   if (arg_value->isa<abstract::AbstractTensor>()) {

@@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
+    if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
+      continue;
+    }
     MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
                   << " to " << it->second;
     (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
   }
 }

@@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
     TypePtr type = args_spec_list[i]->GetTypeTrack();
     if (type && type->type_id() == kObjectTypeRef) {
+      auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
       if (sig == SignatureEnumRW::kRWRead) {
-        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
+        param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
+        if (ref_abs && ref_abs->need_cast()) {
+          auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
+          param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
+        }
       } else if (sig == SignatureEnumRW::kRWWrite) {
-        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
+        param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
         write_indices.insert(i);
       }
       // If sig is SignatureEnumRW::kRWRef, not do anything.
     } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
       MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
     }
     MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
                   << args_spec_list[i]->ToString();
     op_inputs.push_back(param);
   }
   // process default
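The read-site half of the change is easiest to see in isolation: a kRWRead parameter is unwrapped with get_ref_value and, only when the Ref records a mixed-precision target, wrapped in a functional cast. A minimal sketch of that path, using the APIs visible in the hunk above (ReadRefValue is an invented name, not a function in the tree):

AnfNodePtr ReadRefValue(const FuncGraphPtr &fg, AnfNodePtr param, const abstract::AbstractRefPtr &ref_abs) {
  // Unwrap the Ref to its value.
  param = fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
  if (ref_abs != nullptr && ref_abs->need_cast()) {
    // The cast that resolve.cc used to apply eagerly now happens here, on read.
    auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
    param = fg->NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())});
  }
  return param;
}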
mindspore/ccsrc/frontend/operator/composite/unpack_call.cc

@@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
     MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
   }
-  (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
+  // No need to check, check will be done in infer.
   auto ret_graph = std::make_shared<FuncGraph>();
   ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
   ret_graph->debug_info()->set_name("UnpackCall");
-  AnfNodePtr fnNode = ret_graph->add_parameter();
+  AnfNodePtr fn_node = ret_graph->add_parameter();
   std::vector<AnfNodePtr> elems;
-  elems.push_back(fnNode);
+  elems.push_back(fn_node);
   for (size_t index = 1; index < arg_length; index++) {
     MS_EXCEPTION_IF_NULL(args_spec_list[index]);
     if (args_spec_list[index]->isa<AbstractTuple>()) {
mindspore/ccsrc/frontend/operator/prim_others.cc

@@ -129,16 +129,22 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt
 AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &,
                                  const AbstractBasePtrList &args_spec_list) {
-  // arguments: key, value, original value
+  // arguments: key, value, target type(None if no target type)
   if (args_spec_list.size() != 3) {
     MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
                       << ".";
   }
   TypePtr type = args_spec_list[0]->GetTypeTrack();
+  ValuePtr tensor_target_v = args_spec_list[2]->BuildValue();
   if (type->type_id() != kObjectTypeRefKey) {
     MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
   }
-  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
+  auto need_cast = !tensor_target_v->isa<None>();
+  if (need_cast && !tensor_target_v->isa<Type>()) {
+    MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString();
+  }
+  TypePtr cast_target = tensor_target_v->cast<TypePtr>();
+  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target);
 }
 
 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,

@@ -163,25 +169,11 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP
   }
   TypePtr type = args_spec_list[0]->GetTypeTrack();
   if (type->type_id() != kObjectTypeRef) {
-    MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
+    return args_spec_list[0];
   }
   return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
 }
 
-AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &,
-                                      const AbstractBasePtrList &args_spec_list) {
-  // arguments: value
-  if (args_spec_list.size() != 1) {
-    MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
-                      << ".";
-  }
-  TypePtr type = args_spec_list[0]->GetTypeTrack();
-  if (type->type_id() != kObjectTypeRef) {
-    MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
-  }
-  return args_spec_list[0]->cast<AbstractRefPtr>()->ref_origin();
-}
-
 AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list) {
   // args: Two objects of a subclass of AbstractBase, key and value.
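InferImplMakeRef now reads the third argument as a requested target type rather than an original value, with None meaning "no cast". A condensed restatement of that rule under the same types (DeriveCast is an invented name, not in the tree):

std::pair<bool, TypePtr> DeriveCast(const ValuePtr &tensor_target_v) {
  // None in make_ref's third slot means no cast was requested.
  bool need_cast = !tensor_target_v->isa<None>();
  // InferImplMakeRef raises an exception when a non-None, non-Type value shows up,
  // so the cast below is safe whenever need_cast is true.
  TypePtr cast_target = need_cast ? tensor_target_v->cast<TypePtr>() : nullptr;
  return {need_cast, cast_target};
}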
mindspore/ccsrc/frontend/optimizer/irpass.cc

@@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Ref eliminate
   make_ref_eliminate_ =
     MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
   get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
-                                              {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
+                                              {prim::kPrimGetRefValue});
   get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
-                                             {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
+                                             {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
   replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
                                               IsValueNode<RefKey>, opt::FORCE_RENORM);
mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h

@@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller {
 };
 
 // {prim::kPrimGetRefValue, Parameter} -> Parameter
-// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
 class GetRefParamEliminater : public OptimizerCaller {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     PatternNode<AnfNodePtr> x;
     MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
-    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x);
     return nullptr;
   }
 };
 
 // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
 // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
-// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
 class GetMakeRefEliminater : public OptimizerCaller {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     PatternNode<AnfNodePtr> x, y, z;
     MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
     MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
-    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
     return nullptr;
   }
mindspore/ccsrc/pipeline/jit/parse/parse.cc

@@ -60,6 +60,17 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
   return func_graph;
 }
 
+ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
+  TypePtr dst_type;
+  if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
+    return kFloat32;
+  } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
+    return kFloat16;
+  } else {
+    return kNone;
+  }
+}
+
 // if any mixed precision flag add a cast node after the parameter node.
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
   TypePtr dst_type;
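The new helper reports the target dtype as a value instead of rewriting the graph, which is what lets the make_ref node carry it as an ordinary input. A hypothetical call site, assuming func_graph and param are in scope as above:

// Returns kFloat32 under GRAPH_FLAG_MIX_PRECISION_FP32, kFloat16 under
// GRAPH_FLAG_MIX_PRECISION_FP16, and kNone when neither flag is set.
ValuePtr target_type = parse::GetMixedPrecisionTargetType(func_graph, param);
AnfNodePtr target_type_node = NewValueNode(target_type);  // feeds make_ref's third input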
mindspore/ccsrc/pipeline/jit/parse/parse.h

@@ -359,6 +359,7 @@ class ParseAst {
 bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
 
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
+ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
 
 }  // namespace parse
 }  // namespace mindspore
mindspore/ccsrc/pipeline/jit/parse/resolve.cc

@@ -70,6 +70,7 @@ bool SymbolResolver::Resolve() {
 }
 
 namespace {
+// if any mixed precision flag add a cast node after the parameter node.
 // argument obj should be python Parameter object
 // it will be converted to Parameter node here
 AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {

@@ -112,11 +113,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
   }
 
   auto iter = func_graph->make_ref_params().find(para_node);
   if (iter == func_graph->make_ref_params().end()) {
-    AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node);
+    ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
     AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
     AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
-    AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node});
+    AnfNodePtr target_type_node = NewValueNode(target_type);
+    AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
     func_graph->make_ref_params()[para_node] = ref_node;
     func_graph->add_parameter_obj_node(ref_node);
    return ref_node;
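resolve.cc shows the construction-site half of the change: the make_ref node now carries the raw parameter plus the requested dtype, instead of an eagerly cast value plus the original parameter, deferring the actual cast to the read sites rewritten in do_signature.cc. A sketch of the new node layout (BuildRefNode is an invented name):

AnfNodePtr BuildRefNode(const FuncGraphPtr &fg, const AnfNodePtr &para_node, const std::string &param_name) {
  AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
  AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
  AnfNodePtr target_type_node = NewValueNode(parse::GetMixedPrecisionTargetType(fg, para_node));
  // The Ref now holds the *uncast* parameter and the requested dtype; before this
  // commit it held the cast value in slot 2 and the original parameter in slot 3.
  return fg->NewCNode({make_ref, ref_key, para_node, target_type_node});
}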
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc

@@ -125,7 +125,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMakeRef, {InferImplMakeRef, true}},
     {prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
     {prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
-    {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}},
     {prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
     {prim::kPrimDepend, {InferImplDepend, true}},
     {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

@@ -1117,11 +1117,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
         free_param->debug_info()->set_name(param_name);
         para_node = free_param;
       }
-      AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
+      ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
       AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
       auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
       AnfNodePtr ref_key_node = NewValueNode(refkey);
-      AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
+      AnfNodePtr target_type_node = NewValueNode(target_type);
+      AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
       w_args.push_back(ref_node);
     }
   } else {
mindspore/core/abstract/abstract_value.cc

@@ -808,14 +808,40 @@ std::string AbstractJTagged::ToString() const {
   return buffer.str();
 }
 
+AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast,
+                         TypePtr cast_target)
+    : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) {
+  set_type(std::make_shared<RefType>());
+  auto origin_type = ref_value->BuildType();
+  if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) {
+    auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
+    if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) {
+      if (cast_target != tensor_dtype) {
+        need_cast_ = true;
+        target_type_ = cast_target;
+      }
+    }
+  }
+  if (ref_key && ref_key->isa<AbstractRefKey>()) {
+    ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
+  }
+}
+
 BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); }
 
 TypePtr AbstractRef::BuildType() const {
   TypePtr subtype = ref_->BuildType();
-  TypePtr subtype_origin = ref_origin_->BuildType();
+  TypePtr subtype_origin = subtype;
+  if (need_cast_) {
+    subtype_origin = std::make_shared<TensorType>(target_type_);
+  }
   return std::make_shared<RefType>(subtype, subtype_origin);
 }
 
 bool AbstractRef::operator==(const AbstractRef &other) const {
-  return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_);
+  return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) &&
+         (!need_cast_ || (*target_type_ == *other.target_type_));
+  // not compare the key for reuse the graph (*ref_key_ == *other.ref_key_);
 }
 
 bool AbstractRef::operator==(const AbstractBase &other) const {

@@ -826,27 +852,45 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
   return false;
 }
 
+AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
+  MS_EXCEPTION_IF_NULL(other);
+  if (*this == *other) {
+    auto ret = shared_from_base<AbstractBase>();
+    return ret;
+  }
+  auto value_self = GetValueTrack();
+  MS_EXCEPTION_IF_NULL(value_self);
+  ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
+  if (res_value == value_self) {
+    auto ret = shared_from_base<AbstractBase>();
+    return ret;
+  }
+  auto ret = std::make_shared<AbstractRefKey>();
+  ret->set_value(res_value);
+  return ret;
+}
+
 AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
   auto other_ref = other->cast<AbstractRefPtr>();
   if (other_ref == nullptr) {
     auto new_ref = ref_->Join(other);
-    return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_);
+    return std::make_shared<AbstractRef>(ref_key_, new_ref);
   }
-  if (*this == *other) {
+  if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
     return shared_from_base<AbstractBase>();
   }
   auto ref_key = ref_key_->Join(other_ref->ref_key_);
   auto ref = ref_->Join(other_ref->ref());
-  auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);
-  return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
+  return std::make_shared<AbstractRef>(ref_key, ref);
 }
 
 std::string AbstractRef::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
-         << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString()
-         << " origin_value: " << ref_origin_->ToString();
+         << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString();
+  if (need_cast_) {
+    buffer << " cast to: " << target_type_->ToString();
+  }
   auto value = GetValueTrack();
   if (value) {
     buffer << ", value: " << value->ToString();

@@ -873,6 +917,12 @@ std::string AbstractNone::ToString() const {
 ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
 
+AbstractBasePtr AbstractRefKey::Broaden() const {
+  auto refkey = std::make_shared<AbstractRefKey>();
+  refkey->set_value(kAnyValue);
+  return refkey;
+}
+
 bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
   ValuePtr value_self = GetValueTrack();
   ValuePtr value_other = other.GetValueTrack();
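Note how the constructor makes need_cast_ strictly narrower than the caller's request: the cast is recorded only for float tensors whose dtype actually differs from the target. A standalone restatement of that decision, using only calls visible in the hunk above (ShouldRecordCast is an invented name):

bool ShouldRecordCast(const TypePtr &origin_type, const TypePtr &cast_target) {
  if (origin_type == nullptr || cast_target == nullptr || !origin_type->isa<TensorType>()) {
    return false;  // only tensor values can carry a mixed-precision cast
  }
  auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
  // Only floating-point tensors participate in mixed precision, and casting a
  // tensor to the dtype it already has would be a no-op.
  return tensor_dtype != nullptr && IsSubType(tensor_dtype, kFloat) && cast_target != tensor_dtype;
}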
mindspore/core/abstract/abstract_value.h

@@ -535,50 +535,70 @@ using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
 class AbstractRefKey : public AbstractBase {
  public:
-  AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); }
+  AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); }
   ~AbstractRefKey() override = default;
   MS_DECLARE_PARENT(AbstractRefKey, AbstractBase)
   TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); }
   bool operator==(const AbstractRefKey &other) const;
   bool operator==(const AbstractBase &other) const override;
-  AbstractBasePtr Clone() const override { return std::make_shared<AbstractRefKey>(); }
+  AbstractBasePtr Clone() const override {
+    auto cloned = std::make_shared<AbstractRefKey>();
+    cloned->set_value(GetValueTrack());
+    return cloned;
+  }
+  inline void set_value(const ValuePtr &value) {
+    AbstractBase::set_value(value);
+    ref_key_value_ = value->cast<RefKeyPtr>();
+  }
+  RefKeyPtr ref_key_value() const { return ref_key_value_; }
+  AbstractBasePtr Join(const AbstractBasePtr &other) override;
+  AbstractBasePtr Broaden() const override;
   std::string ToString() const override;
+
+ private:
+  // cache for ref_key after build value, when value is null, return nullptr.
+  RefKeyPtr ref_key_value_{nullptr};
 };
 using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
 
 class AbstractRef : public AbstractBase {
  public:
-  AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin)
-      : ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) {
-    set_type(std::make_shared<RefType>());
-  }
+  AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false,
+              TypePtr cast_target = nullptr);
   ~AbstractRef() override = default;
   MS_DECLARE_PARENT(AbstractRef, AbstractBase)
   TypePtr BuildType() const override;
   BaseShapePtr BuildShape() const override;
   bool operator==(const AbstractRef &other) const;
   bool operator==(const AbstractBase &other) const override;
   AbstractBasePtr Clone() const override {
-    return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone());
+    return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_);
   }
   std::string ToString() const override;
-  AbstractBasePtr ref() { return ref_; }
-  AbstractBasePtr ref_origin() { return ref_origin_; }
-  AbstractBasePtr ref_key() { return ref_key_; }
+  inline AbstractBasePtr ref() const { return ref_; }
+  inline AbstractBasePtr ref_key() const { return ref_key_; }
+  inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
+  inline TypePtr target_type() const { return target_type_; }
+  inline bool need_cast() const { return need_cast_; }
   AbstractBasePtr Broaden() const override {
-    return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
+    return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_);
   }
   AbstractBasePtr Join(const AbstractBasePtr &other) override;
   std::size_t hash() const override {
-    return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
+    return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);  // ref_key_->hash() ^
   }
 
  private:
   AbstractBasePtr ref_key_;
   AbstractBasePtr ref_;
-  AbstractBasePtr ref_origin_;
+  // For mixed precision, only float types need to be cast to float16 or float32
+  bool need_cast_;
+  TypePtr target_type_;
+  // cache for ref_key after build value, when value is null, return nullptr.
+  RefKeyPtr ref_key_value_;
 };
 using AbstractRefPtr = std::shared_ptr<AbstractRef>;
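With ref_origin_ gone, equality and hashing deliberately ignore the key (see the inline "not compare the key for reuse the graph" comment), so two Refs that differ only in which parameter they name now compare equal and can share a specialized graph. A hypothetical check of that semantics, with key_a, key_b, and value_abs assumed to be suitable abstracts:

auto a = std::make_shared<abstract::AbstractRef>(key_a, value_abs);  // no cast requested
auto b = std::make_shared<abstract::AbstractRef>(key_b, value_abs);  // different key, same value
assert(*a == *b);  // keys are intentionally not compared, enabling graph reuse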
mindspore/core/abstract/analysis_context.cc

@@ -171,9 +171,7 @@ AnalysisContextPtr AnalysisContext::SpecializeKey() const {
     }
     if (arg->isa<AbstractRef>()) {
       MS_LOG(DEBUG) << "refkey broaden";
-      auto arg_spec = dyn_cast<AbstractRef>(arg);
-      auto ret_spec = arg_spec->Broaden();
-      return ret_spec;
+      return arg->Broaden();
     }
     return arg;
   });
mindspore/core/base/core_ops.h

@@ -121,7 +121,6 @@ inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
 inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
 inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
 inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
-inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
 inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
 inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
 inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");