magicwindyyd / mindspore
Forked from MindSpore / mindspore
Commit 840922e5
Authored May 14, 2020 by kingfo

    add backward hook function in pynative mode

Parent: d402b944

Showing 23 changed files with 422 additions and 21 deletions (+422 -21):
    mindspore/_extends/parse/parser.py                          +4   -1
    mindspore/ccsrc/ir/primitive.cc                             +1   -0
    mindspore/ccsrc/ir/primitive.h                              +0   -3
    mindspore/ccsrc/ir/primitive_base.h                         +7   -0
    mindspore/ccsrc/operator/ops.cc                             +2   -0
    mindspore/ccsrc/operator/ops.h                              +2   -0
    mindspore/ccsrc/operator/prim_nn.cc                         +10  -0
    mindspore/ccsrc/optimizer/ad/dfunctor.cc                    +1   -0
    mindspore/ccsrc/optimizer/ad/dfunctor.h                     +1   -0
    mindspore/ccsrc/optimizer/ad/kprim.cc                       +48  -4
    mindspore/ccsrc/optimizer/irpass.cc                         +4   -3
    mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h     +3   -1
    mindspore/ccsrc/pipeline/parse/data_converter.cc            +37  -1
    mindspore/ccsrc/pipeline/static_analysis/prim.cc            +1   -0
    mindspore/ccsrc/pipeline/static_analysis/prim.h             +2   -0
    mindspore/ccsrc/vm/backend.cc                               +3   -0
    mindspore/ccsrc/vm/transform.cc                             +10  -4
    mindspore/ccsrc/vm/vm.cc                                    +53  -3
    mindspore/ccsrc/vm/vm.h                                     +2   -0
    mindspore/nn/cell.py                                        +36  -0
    mindspore/ops/operations/__init__.py                        +2   -1
    mindspore/ops/operations/debug_ops.py                       +60  -0
    tests/ut/python/pynative_mode/test_hook.py (new file)       +133 -0
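Before the per-file diffs, a quick orientation. The user-facing surface of this commit is small: a new HookBackward primitive (debug_ops.py) for hooking the gradient of an intermediate variable, and a new Cell.register_backward_hook method (cell.py) for hooking a whole cell's boundary gradients, both meaningful in PyNative mode. A minimal usage sketch distilled from the new test file at the end of this diff (the network shape and hook bodies are illustrative, not part of the commit):

    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore import context

    context.set_context(mode=context.PYNATIVE_MODE)

    def var_hook_function(grad_out):
        # receives the gradient flowing into the hooked intermediate variable
        print("grad:", grad_out)

    def cell_hook_function(cell_id, grad_input, grad_output):
        # receives the gradients at the boundaries of the hooked cell
        print(cell_id)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(120, 84)
            self.dense.register_backward_hook(cell_hook_function)  # cell-level hook
            self.hook = P.HookBackward(var_hook_function)          # variable-level hook
            self.relu = nn.ReLU()

        def construct(self, x):
            x = self.dense(x)
            x = self.hook(x)
            return self.relu(x)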
mindspore/_extends/parse/parser.py (+4 -1)

@@ -102,7 +102,10 @@ def get_parse_method_of_class(obj, parse_method=None):
         method_name = parse_method
     else:
         if isinstance(obj, nn.Cell):
-            method_name = "construct"
+            if obj.enable_hook:
+                method_name = "_hook_construct"
+            else:
+                method_name = "construct"
     if method_name is not None:
         if hasattr(obj, method_name):
             method = getattr(obj, method_name)
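The effect of this change: once a hook is registered on a cell, the parser picks _hook_construct (added to mindspore/nn/cell.py below) as the graph entry point instead of construct. A simplified Python sketch of the routing, with a hypothetical helper name:

    def resolve_entry_method(cell, parse_method=None):
        # Hypothetical stand-in for get_parse_method_of_class's new branch.
        if parse_method is not None:
            return parse_method
        if cell.enable_hook:           # True after cell.register_backward_hook(fn)
            return "_hook_construct"   # wraps construct() between hook applications
        return "construct"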
mindspore/ccsrc/ir/primitive.cc (+1 -0)

@@ -115,6 +115,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
     .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
     .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
     .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
+    .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
     .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
 }));
 }  // namespace mindspore
mindspore/ccsrc/ir/primitive.h (+0 -3)

@@ -23,7 +23,6 @@
 #include <string>
 #include <tuple>
-#include "pybind11/pybind11.h"
 #include "pipeline/static_analysis/abstract_value.h"
 #include "utils/misc.h"
 #include "utils/log_adapter.h"

@@ -31,8 +30,6 @@
 #include "ir/signature.h"
 #include "parallel/ops_info/operator_info.h"
-
-namespace py = pybind11;
 
 namespace mindspore {
 class PrimitivePy : public Primitive {
  public:
mindspore/ccsrc/ir/primitive_base.h (+7 -0)

@@ -24,6 +24,9 @@
 #include <tuple>
 #include "ir/dtype/type.h"
+#include "pybind11/pybind11.h"
+
+namespace py = pybind11;
 
 namespace mindspore {
 // Supported meta type

@@ -73,6 +76,9 @@ class Primitive : public Named {
     return iter == attrs_.cend() ? nullptr : iter->second;
   }
 
+  void set_hook(const py::function &hook) { hook_ = hook; }
+  py::function hook() const { return hook_; }
+
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
   // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.

@@ -103,6 +109,7 @@ class Primitive : public Named {
  private:
   std::string instance_name_;
+  py::function hook_;
   bool is_base_;
   bool has_signature_;
   PrimType prim_type_;
mindspore/ccsrc/operator/ops.cc (+2 -0)

@@ -211,6 +211,7 @@ const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
 const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
 const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
 const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
+const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
 
 // Other miscellaneous
 const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");

@@ -224,6 +225,7 @@ const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
 const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
 const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
 const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
+const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
 const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
 const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
 const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
mindspore/ccsrc/operator/ops.h (+2 -0)

@@ -216,6 +216,7 @@ extern const PrimitivePtr kPrimReluV2;
 extern const PrimitivePtr kPrimActivation;
 extern const PrimitivePtr kPrimZerosLikeTensor;
 extern const PrimitivePtr kPrimFakeBprop;
+extern const PrimitivePtr kPrimBpropCut;
 
 // Other Miscellaneous
 extern const PrimitivePtr kPrimIdentity;

@@ -230,6 +231,7 @@ extern const PrimitivePtr kPrimGetRefKey;
 extern const PrimitivePtr kPrimGetRefValue;
 extern const PrimitivePtr kPrimGetRefOrigin;
 extern const PrimitivePtr kPrimInsertGradientOf;
+extern const PrimitivePtr kPrimHookBackward;
 extern const PrimitivePtr kPrimPrintShapeType;
 extern const PrimitivePtr kPrimPrint;
 extern const PrimitivePtr kPrimSameTypeShape;
mindspore/ccsrc/operator/prim_nn.cc (+10 -0)

@@ -285,6 +285,16 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
   return args_spec_list[0]->Broaden();
 }
 
+AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor.
+  AbstractBasePtrList args_list;
+  for (size_t i = 0; i < args_spec_list.size() - 2; i++) {
+    args_list.push_back(args_spec_list[i]->Broaden());
+  }
+  return std::make_shared<AbstractTuple>(args_list);
+}
+
 AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
   // Inputs: three tensors(x, gamma, beta).
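The inference rule above reflects the bprop calling convention: bprop_cut is called with (x_1, ..., x_n, out, dout) and must return one gradient per primal input, so the abstract value is a tuple over all arguments except the trailing two. A stripped-down Python rendering of that rule (broadening omitted, function name hypothetical):

    def infer_bprop_cut(arg_specs):
        # arg_specs describe (x_1, ..., x_n, out, dout); the result has one
        # element per primal input x_i.
        return tuple(arg_specs[:-2])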
mindspore/ccsrc/optimizer/ad/dfunctor.cc (+1 -0)

@@ -32,6 +32,7 @@
 #include "operator/ops.h"
 #include "operator/composite/composite.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
 #include "./common.h"
 
 namespace mindspore {
mindspore/ccsrc/optimizer/ad/dfunctor.h (+1 -0)

@@ -125,6 +125,7 @@ class KPrim {
   FuncGraphPtr GetBprop(const PrimitivePtr &prim);
   FuncGraphPtr GetFprop(const PrimitivePtr &prim);
   FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
+  FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
 
   // Given a bprop rule, do the K mapping.
   template <typename T>
   FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);
mindspore/ccsrc/optimizer/ad/kprim.cc (+48 -4)

@@ -115,10 +115,15 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
   }
   bool is_faked_bprop = false;
-  auto bprop_fg = GetBprop(prim);
-  if (bprop_fg == nullptr) {
-    bprop_fg = FakeBprop(value_node, resources);
-    is_faked_bprop = true;
+  FuncGraphPtr bprop_fg = nullptr;
+  if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
+    bprop_fg = BpropCut(value_node, resources);
+  } else {
+    bprop_fg = GetBprop(prim);
+    if (bprop_fg == nullptr) {
+      bprop_fg = FakeBprop(value_node, resources);
+      is_faked_bprop = true;
+    }
   }
 
   auto expanded_fg = BpropToK(prim, bprop_fg);

@@ -206,6 +211,45 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
   return expanded_fg;
 }
 
+FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
+  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto &node_users = resources->manager()->node_users();
+
+  auto &users = node_users[value_node];
+  auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
+    return IsPrimitiveCNode(user.first, prim);
+  });
+  if (cnode == users.end()) {
+    MS_LOG(EXCEPTION) << "Fail to find cnode.";
+  }
+  auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
+
+  auto func_graph = std::make_shared<FuncGraph>();
+  std::vector<AnfNodePtr> outputs;
+
+  auto bprop_cut = std::make_shared<Primitive>("bprop_cut");
+  bprop_cut->set_hook(prim->hook());
+  auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
+  if (cell_id != "") {
+    (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
+    (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
+  }
+
+  outputs.push_back(NewValueNode(bprop_cut));
+  for (size_t i = 0; i < inputs_num; ++i) {
+    auto param = func_graph->add_parameter();
+    outputs.push_back(param);
+  }
+  auto p1 = func_graph->add_parameter();
+  auto p2 = func_graph->add_parameter();
+  outputs.push_back(p1);
+  outputs.push_back(p2);
+
+  func_graph->set_output(func_graph->NewCNode(outputs));
+  return func_graph;
+}
+
 FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
   auto prim = value_node->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(prim);
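BpropCut replaces the normal bprop expansion for HookBackward nodes with a graph whose entire body is a single bprop_cut call taking every primal input plus out and dout. Because bprop_cut is later registered as a nonlinear op (see vm/transform.cc below), compilation is cut at that node and the VM hands control back to Python. Roughly, in Python terms (a sketch of the generated graph's shape, not generated code):

    def make_bprop_cut_graph(bprop_cut_prim, inputs_num):
        """Sketch of the graph KPrim::BpropCut builds: it simply forwards the
        primal inputs, out, and dout to the bprop_cut primitive."""
        def bprop(*args):
            # args = (x_1, ..., x_{inputs_num}, out, dout)
            assert len(args) == inputs_num + 2
            return bprop_cut_prim(*args)
        return bprop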
mindspore/ccsrc/optimizer/irpass.cc (+4 -3)

@@ -49,9 +49,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                           {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
                                            prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
-  special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
-                                           {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
-                                            prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
+  special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
+                                           {prim::kPrimInsertGradientOf, prim::kPrimHookBackward,
+                                            prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror,
+                                            prim::kPrimVirtualDiv});
   zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
   adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h (+3 -1)

@@ -35,11 +35,13 @@ class SpecialOpEliminater {
  public:
   SpecialOpEliminater()
       : insert_gradient_of_(prim::kPrimInsertGradientOf),
+        hook_backward_(prim::kPrimHookBackward),
         print_shape_type_(prim::kPrimPrintShapeType),
         get_ref_value_(prim::kPrimGetRefValue),
         mirror_(prim::kPrimMirror),
         virtual_div_(prim::kPrimVirtualDiv) {
     eliminaters_.emplace_back(insert_gradient_of_);
+    eliminaters_.emplace_back(hook_backward_);
     eliminaters_.emplace_back(print_shape_type_);
     eliminaters_.emplace_back(get_ref_value_);
     eliminaters_.emplace_back(mirror_);

@@ -59,7 +61,7 @@ class SpecialOpEliminater {
   }
 
  private:
-  PrimEliminater insert_gradient_of_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
+  PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
   std::vector<TransformFuncType> eliminaters_{};
 };
mindspore/ccsrc/pipeline/parse/data_converter.cc (+37 -1)

@@ -30,6 +30,7 @@
 #include "operator/composite/composite.h"
 #include "ir/func_graph_cloner.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
 #include "debug/trace.h"
 
 namespace mindspore {

@@ -207,6 +208,35 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
   return true;
 }
 
+FuncGraphPtr ConvertToBpropCut(py::object obj) {
+  std::vector<std::string> results = data_converter::GetObjKey(obj);
+  std::string obj_key = results[0];
+  py::function bprop_func = py::getattr(obj, "bprop");
+
+  FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
+  std::vector<AnfNodePtr> outputs;
+
+  auto fake_bprop = std::make_shared<Primitive>("bprop_cut");
+  fake_bprop->set_hook(bprop_func);
+  (void)fake_bprop->AddAttr("bprop", MakeValue(true));
+  outputs.push_back(NewValueNode(fake_bprop));
+
+  py::object code_obj = py::getattr(bprop_func, "__code__");
+  size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
+  for (size_t i = 0; i < inputs_num; ++i) {
+    auto param = bprop_graph->add_parameter();
+    outputs.push_back(param);
+  }
+  auto p1 = bprop_graph->add_parameter();
+  auto p2 = bprop_graph->add_parameter();
+  outputs.push_back(p1);
+  outputs.push_back(p2);
+
+  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
+  data_converter::SetObjGraphValue(obj_key, bprop_graph);
+  return bprop_graph;
+}
+
 bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
   auto obj_type = data_converter::GetObjType(obj);
   MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";

@@ -238,7 +268,13 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
   }
   // if the cell object has specified bprop, it has user-defined bprop function parse and record it
   if (py::hasattr(obj, "bprop")) {
-    FuncGraphPtr bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+    FuncGraphPtr bprop_graph = nullptr;
+    bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
+    if (enable_bprop_debug) {
+      bprop_graph = ConvertToBpropCut(obj);
+    } else {
+      bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+    }
     if (bprop_graph != nullptr) {
       (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
       (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
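ConvertToBpropCut is what backs the new Cell.bprop_debug flag: when a cell defines a custom bprop and sets bprop_debug = True, the bprop is not compiled into the graph but wrapped in a bprop_cut primitive and executed as a Python callback. Note the co_argcount - 3: the bprop's parameter count minus self, out, and dout gives the number of primal inputs. Usage as exercised by test_custom_bprop in the new test file:

    import mindspore.nn as nn
    from mindspore.ops import composite as C

    class MulAdd(nn.Cell):
        def construct(self, x, y):
            return 2 * x + y

        def bprop(self, x, y, out, dout):
            # runs in Python through bprop_cut instead of being compiled
            return 3 * dout, 2 * y

    mul_add = MulAdd()
    mul_add.bprop_debug = True                # route through ConvertToBpropCut
    assert C.grad_all(mul_add)(1, 2) == (3, 4)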
mindspore/ccsrc/pipeline/static_analysis/prim.cc (+1 -0)

@@ -108,6 +108,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimRelu, {InferImplRelu, true}},
     {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
     {prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
+    {prim::kPrimBpropCut, {InferImplBpropCut, true}},
     {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
     {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
mindspore/ccsrc/pipeline/static_analysis/prim.h (+2 -0)

@@ -210,6 +210,8 @@ AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const Primit
                                          const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
mindspore/ccsrc/vm/backend.cc (+3 -0)

@@ -64,6 +64,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
   result.outputs = outputs;
   result.graph_id = kInvalidGraphId;
   auto graph_id = sess_->CompileGraph(lst, outputs);
+  if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+    sess_->BuildGraph(graph_id);
+  }
   if (MsContext::GetInstance()->precompile_only()) {
     MS_LOG(INFO) << "PrecompileOnly, stop run graph";
     return result;
mindspore/ccsrc/vm/transform.cc (+10 -4)

@@ -40,9 +40,10 @@ using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
 using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
 
 std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
-                                           prim::kPrimMakeTuple};
+                                           prim::kPrimMakeTuple, prim::kPrimBpropCut};
 const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
-  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch};
+  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
+                                                             prim::kPrimBpropCut};
   return ms_nonlinear_ops;
 }

@@ -646,8 +647,13 @@ BackendPtr CreateBackend() {
   auto backend = std::make_shared<MsBackend>(name, target, device_id);
   std::string device_target = MsContext::GetInstance()->device_target();
   if (device_target == kAscendDevice) {
-    backend->set_is_multi_graph_sink(true);
-    context_ptr->set_is_multi_graph_sink(true);
+    if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+      backend->set_is_multi_graph_sink(false);
+      context_ptr->set_is_multi_graph_sink(false);
+    } else {
+      backend->set_is_multi_graph_sink(true);
+      context_ptr->set_is_multi_graph_sink(true);
+    }
   }
   return backend;
 }
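Registering kPrimBpropCut as a nonlinear op is what makes the cut happen: the compiler fuses runs of ordinary ops into backend subgraphs and breaks at nonlinear ops, which FinalVM executes itself (and for bprop_cut, dispatches to RunHook). A hypothetical illustration of that partitioning idea, not MindSpore's actual compiler code:

    # Fuse linear runs into backend segments; nonlinear ops become their own
    # VM-executed segments.
    NONLINEAR = {"return", "partial", "switch", "make_tuple", "bprop_cut"}

    def partition(ops):
        segments, current = [], []
        for op in ops:
            if op in NONLINEAR:
                if current:
                    segments.append(current)   # fused backend subgraph
                    current = []
                segments.append([op])          # executed by the VM itself
            else:
                current.append(op)
        if current:
            segments.append(current)
        return segments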
mindspore/ccsrc/vm/vm.cc (+53 -3)

@@ -587,15 +587,65 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
   VectorRef tuple;
   auto prim = utils::cast<PrimitivePtr>(args[0]);
   for (size_t i = 1; i < args.size(); ++i) {
-    auto index = utils::cast<int>(args[1]);
+    auto index = utils::cast<int>(args[i]);
     tuple.push_back(Ref(index));
   }
 
-  auto outs = RunOperation(prim, tuple);
-  Push(outs);
+  if (prim->name() == "bprop_cut") {
+    auto outs = RunHook(prim, tuple);
+    Push(outs);
+  } else {
+    auto outs = RunOperation(prim, tuple);
+    Push(outs);
+  }
   MS_LOG(DEBUG) << "End";
 }
 
+BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
+  py::tuple py_args = py::tuple(args.size());
+  MS_LOG(DEBUG) << "input for operation:";
+  size_t i = 0;
+  for (auto &arg : args) {
+    py_args[i] = BaseRefToPyData(arg);
+    MS_LOG(DEBUG) << "arg: " << i << ":";
+    i++;
+  }
+  py::object obj;
+  bool is_bprop = prim->HasAttr("bprop");
+  if (is_bprop) {
+    py::function fn_bprop = prim->hook();
+    obj = fn_bprop(*py_args);
+    return obj;
+  }
+  bool is_cell = prim->HasAttr("cell_hook");
+  if (is_cell) {
+    std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
+    if (_hook_grad.find(cell_id) != _hook_grad.end()) {
+      py::tuple hook_args = py::tuple(3);
+      hook_args[0] = cell_id;
+      hook_args[1] = _hook_grad[cell_id];
+      hook_args[2] = py_args[2];
+      py::function fn_hook = prim->hook();
+      obj = fn_hook(*hook_args);
+      if (py::isinstance<py::none>(obj)) {
+        obj = py_args[2];
+      }
+      _hook_grad.erase(cell_id);
+    } else {
+      _hook_grad[cell_id] = py_args[2];
+      obj = py_args[2];
+    }
+  } else {
+    py::function fn_hook = prim->hook();
+    obj = fn_hook(py_args[2]);
+    if (py::isinstance<py::none>(obj)) {
+      obj = py_args[2];
+    }
+  }
+  obj = py::make_tuple(obj);
+  return obj;
+}
 }  // namespace compile
 }  // namespace mindspore
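Note the incidental bug fix in InstPushPrim (args[1] becomes args[i]). RunHook then distinguishes three cases: a user-defined bprop (call it with all arguments), a cell hook (pair the gradients seen at the cell's two bprop_cut nodes before firing the hook), and a plain variable hook (call it with the incoming gradient, keeping that gradient if the hook returns None). The cell case is the subtle one; a simplified Python model of the pairing (function name hypothetical):

    _hook_grad = {}  # mirrors FinalVM::_hook_grad: cell_id -> first-seen gradient

    def run_cell_hook(cell_id, hook, grad):
        if cell_id in _hook_grad:
            # second bprop_cut for this cell: both gradients known, fire the hook
            result = hook(cell_id, _hook_grad.pop(cell_id), grad)
            return grad if result is None else result
        # first bprop_cut for this cell: cache the gradient, pass it through
        _hook_grad[cell_id] = grad
        return grad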
mindspore/ccsrc/vm/vm.h (+2 -0)

@@ -115,6 +115,7 @@ class FinalVM {
   void InstPushPrim(const VectorRef &args);
   void InstSwitchReturn(const VectorRef &args);
   void set_insts(const InstSet &value) { insts_ = value; }
+  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);
 
  protected:
   BaseRef Ref(int i);

@@ -156,6 +157,7 @@
     {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
     {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
   };
+  std::map<std::string, py::object> _hook_grad;
 };
 
 using FinalVMPtr = std::shared_ptr<FinalVM>;
mindspore/nn/cell.py (+36 -0)

@@ -24,6 +24,7 @@ from .._checkparam import _check_str_by_regular
 from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend
 from ..ops.primitive import Primitive
+from ..ops.operations import HookBackward
 from ..parallel._tensor import _load_tensor_by_layout
 from ..common.tensor import Tensor

@@ -75,6 +76,9 @@ class Cell:
         self._parallel_inputs_run = None
         if flags:
             self.add_flags(**flags)
+        self._backward_hook = None
+        self._enable_hook = False
+        self._bprop_debug = False
 
     @property
     def create_time(self):

@@ -91,6 +95,16 @@
         """
         return self._param_prefix
 
+    @property
+    def bprop_debug(self):
+        return self._bprop_debug
+
+    @bprop_debug.setter
+    def bprop_debug(self, value):
+        if not isinstance(value, bool):
+            raise TypeError("'bprop debug' value must be bool type.")
+        self._bprop_debug = value
+
     def update_cell_prefix(self):
         """
         Update the all child cells' self.param_prefix.

@@ -728,3 +742,25 @@
         self._auto_parallel_mode = True
         self.add_flags(auto_parallel=True)
         self._get_construct_inputs_number_and_name()
+
+    def _hook_construct(self, inputs):
+        """Hook construct method to replace original construct method when hook function enabled."""
+        inputs = self._backward_hook(inputs)
+        inputs = self.construct(inputs)
+        outputs = self._backward_hook(inputs)
+        return outputs
+
+    @property
+    def enable_hook(self):
+        """Whether the cell register hook function"""
+        return self._enable_hook
+
+    def register_backward_hook(self, fn):
+        """
+        Set the cell backward hook function.
+
+        Args:
+            fn (function): Specifies the hook function with grad as input.
+        """
+        self._backward_hook = HookBackward(fn, str(id(self)))
+        self._enable_hook = True
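Because register_backward_hook constructs the HookBackward primitive with str(id(self)) as cell_id, a cell-level hook is invoked with three arguments, unlike the single-argument variable hook. A hedged sketch matching the signature used in the new test:

    import mindspore.nn as nn

    def cell_hook_function(cell_id, grad_input, grad_output):
        # cell_id is str(id(cell)); the gradients are Tensors at the cell boundary
        print(cell_id, grad_input.asnumpy().shape, grad_output.asnumpy().shape)

    conv2 = nn.Conv2d(6, 16, 5, pad_mode="valid")
    conv2.register_backward_hook(cell_hook_function)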
mindspore/ops/operations/__init__.py (+2 -1)

@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice)
-from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary,
+from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import ScalarCast

@@ -155,6 +155,7 @@ __all__ = [
     'HistogramSummary',
     "Print",
     'InsertGradientOf',
+    'HookBackward',
     'InvertPermutation',
     'Shape',
     'DropoutDoMask',
mindspore/ops/operations/debug_ops.py (+60 -0)

@@ -14,6 +14,7 @@
 # ============================================================================
 
 """debug_ops"""
+from types import FunctionType
 from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer

@@ -193,6 +194,65 @@ class InsertGradientOf(PrimitiveWithInfer):
         return x_type
 
 
+class HookBackward(PrimitiveWithInfer):
+    """
+    Used as tag to hook gradient in intermediate variables.
+
+    Note:
+        The hook function should have one input of gradient of the variable.
+        hook function will be executed in python environment, while callback
+        of InsertGradientOf will be parsed and added to the graph.
+
+    Args:
+        hook_fn (Function): Python function. hook function.
+
+    Inputs:
+        - **inputs** (Tensor) - The variable to hook.
+
+    Examples:
+        >>> def hook_fn(grad_out):
+        >>>     print(grad_out)
+        >>>
+        >>> hook = P.HookBackward(hook_fn)
+        >>>
+        >>> def hook_test(x, y):
+        >>>     z = x * y
+        >>>     z = hook(z)
+        >>>     z = z * y
+        >>>     return z
+        >>>
+        >>> def backward(x, y):
+        >>>     return C.grad_all(hook_test)(x, y)
+        >>>
+        >>> backward(1, 2)
+    """
+
+    def __init__(self, hook_fn, cell_id=""):
+        super(HookBackward, self).__init__(self.__class__.__name__)
+        self.add_prim_attr("cell_id", cell_id)
+        self.init_attrs["cell_id"] = cell_id
+        if not isinstance(hook_fn, FunctionType):
+            raise TypeError("Hook function should be python function type.")
+        self.register_hook(hook_fn)
+        self.cell_id = cell_id
+
+    def __call__(self, *inputs):
+        """run in PyNative mode."""
+        if len(inputs) == 1:
+            return inputs[0]
+        return inputs
+
+    def infer_shape(self, *inputs_shape):
+        if len(inputs_shape) == 1:
+            return inputs_shape[0]
+        return inputs_shape
+
+    def infer_dtype(self, *inputs_type):
+        if len(inputs_type) == 1:
+            return inputs_type[0]
+        return inputs_type
+
+
 class Print(PrimitiveWithInfer):
     """
     Output tensor or string to stdout.
tests/ut/python/pynative_mode/test_hook.py (new file, mode 100644, +133)

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import Dense, WithLossCell, SoftmaxCrossEntropyWithLogits, Momentum

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)


def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id)
    assert (grad_output.asnumpy().shape == (32, 6, 14, 14))
    assert (grad_input.asnumpy().shape == (32, 16, 10, 10))


def var_hook_function(grad_out):
    print("grad:", grad_out)
    assert (grad_out.asnumpy().shape == (32, 120))


class LeNet5(nn.Cell):
    """
    Lenet network

    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet(num_class=10)
    """
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.conv2.register_backward_hook(cell_hook_function)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.hook(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class GradWrap(nn.Cell):
    """ GradWrap definition """
    def __init__(self, network):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)


def test_hook():
    net = LeNet5()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()

    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
    output = net(Tensor(input_data))
    loss_output = criterion(output, label)
    grads = train_network(input_data, label)
    success = optimizer(grads)
    print(loss_output.asnumpy().shape)


class MulAdd(nn.Cell):
    def __init__(self):
        super(MulAdd, self).__init__()

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        assert (x == 1)
        assert (y == 2)
        assert (out == 4)
        assert (dout == 1)
        return 3 * dout, 2 * y


def test_custom_bprop():
    mul_add = MulAdd()
    mul_add.bprop_debug = True
    assert C.grad_all(mul_add)(1, 2) == (3, 4)