Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
358982a9
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
358982a9
编写于
6月 16, 2020
作者:
K
kingfo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix hook and bprop debug issue
上级
fe797aaf
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
158 addition
and
74 deletion
+158
-74
mindspore/_extends/builtin_operations.py
mindspore/_extends/builtin_operations.py
+18
-0
mindspore/ccsrc/pipeline/parse/data_converter.cc
mindspore/ccsrc/pipeline/parse/data_converter.cc
+31
-31
mindspore/ccsrc/pipeline/parse/data_converter.h
mindspore/ccsrc/pipeline/parse/data_converter.h
+1
-0
mindspore/ccsrc/pipeline/parse/parse_base.h
mindspore/ccsrc/pipeline/parse/parse_base.h
+1
-0
mindspore/ccsrc/pynative/base.h
mindspore/ccsrc/pynative/base.h
+1
-1
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+32
-12
mindspore/ccsrc/pynative/pynative_execute.h
mindspore/ccsrc/pynative/pynative_execute.h
+2
-2
mindspore/common/_register_for_tensor.py
mindspore/common/_register_for_tensor.py
+10
-2
mindspore/common/tensor.py
mindspore/common/tensor.py
+11
-10
mindspore/nn/cell.py
mindspore/nn/cell.py
+1
-1
mindspore/nn/layer/container.py
mindspore/nn/layer/container.py
+10
-0
mindspore/ops/composite/base.py
mindspore/ops/composite/base.py
+1
-1
mindspore/ops/functional.py
mindspore/ops/functional.py
+5
-2
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+2
-0
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+1
-7
tests/ut/cpp/pynative/pynative_execute_test.cc
tests/ut/cpp/pynative/pynative_execute_test.cc
+2
-3
tests/ut/python/ir/test_tensor.py
tests/ut/python/ir/test_tensor.py
+1
-1
tests/ut/python/pynative_mode/test_hook.py
tests/ut/python/pynative_mode/test_hook.py
+28
-1
未找到文件。
mindspore/_extends/builtin_operations.py
浏览文件 @
358982a9
...
@@ -113,6 +113,24 @@ def bool_or(x, y):
...
@@ -113,6 +113,24 @@ def bool_or(x, y):
"""Implement `bool_or`."""
"""Implement `bool_or`."""
return
x
or
y
return
x
or
y
def
vm_compare
(
*
args
):
"""Implement `vm_compare` for tensor."""
obj_str
=
args
[
-
1
]
if
obj_str
==
"shape"
:
fn
=
getattr
(
args
[
0
].
asnumpy
(),
obj_str
)
return
fn
if
len
(
args
)
==
2
:
fn
=
getattr
(
args
[
0
].
asnumpy
(),
obj_str
)
return
Tensor
(
fn
())
if
isinstance
(
args
[
0
],
Tensor
):
fn
=
getattr
(
args
[
0
].
asnumpy
(),
obj_str
)
y
=
args
[
1
].
asnumpy
()
if
isinstance
(
args
[
1
],
Tensor
)
else
args
[
1
]
else
:
obj_str
=
"__r"
+
obj_str
[
2
:]
fn
=
getattr
(
args
[
1
].
asnumpy
(),
obj_str
)
y
=
args
[
0
]
return
Tensor
(
np
.
array
(
fn
(
y
)))
def
make_list
(
*
xs
):
def
make_list
(
*
xs
):
"""Implement `make_list`."""
"""Implement `make_list`."""
...
...
mindspore/ccsrc/pipeline/parse/data_converter.cc
浏览文件 @
358982a9
...
@@ -41,6 +41,35 @@ using TensorPtr = mindspore::tensor::TensorPtr;
...
@@ -41,6 +41,35 @@ using TensorPtr = mindspore::tensor::TensorPtr;
using
MetaTensor
=
mindspore
::
tensor
::
MetaTensor
;
using
MetaTensor
=
mindspore
::
tensor
::
MetaTensor
;
using
MetaTensorPtr
=
mindspore
::
tensor
::
MetaTensorPtr
;
using
MetaTensorPtr
=
mindspore
::
tensor
::
MetaTensorPtr
;
FuncGraphPtr
ConvertToBpropCut
(
const
py
::
object
&
obj
)
{
std
::
vector
<
std
::
string
>
results
=
data_converter
::
GetObjKey
(
obj
);
std
::
string
obj_key
=
results
[
0
];
py
::
function
bprop_func
=
py
::
getattr
(
obj
,
CUSTOM_BPROP_NAME
);
auto
bprop_graph
=
std
::
make_shared
<
FuncGraph
>
();
std
::
vector
<
AnfNodePtr
>
outputs
;
auto
fake_bprop
=
std
::
make_shared
<
PrimitivePy
>
(
"bprop_cut"
,
py
::
object
());
fake_bprop
->
set_hook
(
bprop_func
);
(
void
)
fake_bprop
->
AddAttr
(
CUSTOM_BPROP_NAME
,
MakeValue
(
true
));
outputs
.
push_back
(
NewValueNode
(
fake_bprop
));
py
::
object
code_obj
=
py
::
getattr
(
bprop_func
,
"__code__"
);
size_t
inputs_num
=
py
::
cast
<
int
>
(
py
::
getattr
(
code_obj
,
"co_argcount"
))
-
3
;
for
(
size_t
i
=
0
;
i
<
inputs_num
;
++
i
)
{
auto
param
=
bprop_graph
->
add_parameter
();
outputs
.
push_back
(
param
);
}
auto
p1
=
bprop_graph
->
add_parameter
();
auto
p2
=
bprop_graph
->
add_parameter
();
outputs
.
push_back
(
p1
);
outputs
.
push_back
(
p2
);
bprop_graph
->
set_output
(
bprop_graph
->
NewCNode
(
outputs
));
data_converter
::
SetObjGraphValue
(
obj_key
,
bprop_graph
);
return
bprop_graph
;
}
namespace
{
namespace
{
bool
ConvertTuple
(
const
py
::
object
&
obj
,
ValuePtr
*
const
data
,
bool
use_signature
)
{
bool
ConvertTuple
(
const
py
::
object
&
obj
,
ValuePtr
*
const
data
,
bool
use_signature
)
{
MS_LOG
(
DEBUG
)
<<
"Converting python tuple"
;
MS_LOG
(
DEBUG
)
<<
"Converting python tuple"
;
...
@@ -231,35 +260,6 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
...
@@ -231,35 +260,6 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
return
true
;
return
true
;
}
}
FuncGraphPtr
ConvertToBpropCut
(
py
::
object
obj
)
{
std
::
vector
<
std
::
string
>
results
=
data_converter
::
GetObjKey
(
obj
);
std
::
string
obj_key
=
results
[
0
];
py
::
function
bprop_func
=
py
::
getattr
(
obj
,
"bprop"
);
FuncGraphPtr
bprop_graph
=
std
::
make_shared
<
FuncGraph
>
();
std
::
vector
<
AnfNodePtr
>
outputs
;
auto
fake_bprop
=
std
::
make_shared
<
PrimitivePy
>
(
"bprop_cut"
,
py
::
object
());
fake_bprop
->
set_hook
(
bprop_func
);
(
void
)
fake_bprop
->
AddAttr
(
"bprop"
,
MakeValue
(
true
));
outputs
.
push_back
(
NewValueNode
(
fake_bprop
));
py
::
object
code_obj
=
py
::
getattr
(
bprop_func
,
"__code__"
);
size_t
inputs_num
=
py
::
cast
<
int
>
(
py
::
getattr
(
code_obj
,
"co_argcount"
))
-
3
;
for
(
size_t
i
=
0
;
i
<
inputs_num
;
++
i
)
{
auto
param
=
bprop_graph
->
add_parameter
();
outputs
.
push_back
(
param
);
}
auto
p1
=
bprop_graph
->
add_parameter
();
auto
p2
=
bprop_graph
->
add_parameter
();
outputs
.
push_back
(
p1
);
outputs
.
push_back
(
p2
);
bprop_graph
->
set_output
(
bprop_graph
->
NewCNode
(
outputs
));
data_converter
::
SetObjGraphValue
(
obj_key
,
bprop_graph
);
return
bprop_graph
;
}
bool
ConvertCellObjToFuncGraph
(
py
::
object
obj
,
ValuePtr
*
const
data
)
{
bool
ConvertCellObjToFuncGraph
(
py
::
object
obj
,
ValuePtr
*
const
data
)
{
FuncGraphPtr
func_graph
=
ConvertToFuncGraph
(
obj
);
FuncGraphPtr
func_graph
=
ConvertToFuncGraph
(
obj
);
if
(
func_graph
==
nullptr
)
{
if
(
func_graph
==
nullptr
)
{
...
@@ -267,7 +267,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
...
@@ -267,7 +267,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
return
false
;
return
false
;
}
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if
(
py
::
hasattr
(
obj
,
"bprop"
))
{
if
(
py
::
hasattr
(
obj
,
CUSTOM_BPROP_NAME
))
{
FuncGraphPtr
bprop_graph
=
nullptr
;
FuncGraphPtr
bprop_graph
=
nullptr
;
bool
enable_bprop_debug
=
py
::
cast
<
bool
>
(
py
::
getattr
(
obj
,
"bprop_debug"
));
bool
enable_bprop_debug
=
py
::
cast
<
bool
>
(
py
::
getattr
(
obj
,
"bprop_debug"
));
if
(
enable_bprop_debug
)
{
if
(
enable_bprop_debug
)
{
...
@@ -276,7 +276,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
...
@@ -276,7 +276,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
bprop_graph
=
ConvertToFuncGraph
(
obj
,
PYTHON_MOD_GET_BPROP_METHOD
);
bprop_graph
=
ConvertToFuncGraph
(
obj
,
PYTHON_MOD_GET_BPROP_METHOD
);
}
}
if
(
bprop_graph
!=
nullptr
)
{
if
(
bprop_graph
!=
nullptr
)
{
(
void
)
func_graph
->
transforms
().
insert
(
std
::
make_pair
(
"bprop"
,
FuncGraphTransform
(
bprop_graph
)));
(
void
)
func_graph
->
transforms
().
insert
(
std
::
make_pair
(
CUSTOM_BPROP_NAME
,
FuncGraphTransform
(
bprop_graph
)));
(
void
)
bprop_graph
->
transforms
().
insert
(
std
::
make_pair
(
"primal"
,
FuncGraphTransform
(
func_graph
)));
(
void
)
bprop_graph
->
transforms
().
insert
(
std
::
make_pair
(
"primal"
,
FuncGraphTransform
(
func_graph
)));
func_graph
->
set_flags
(
FUNC_GRAPH_FLAG_DEFER_INLINE
,
true
);
func_graph
->
set_flags
(
FUNC_GRAPH_FLAG_DEFER_INLINE
,
true
);
}
}
...
...
mindspore/ccsrc/pipeline/parse/data_converter.h
浏览文件 @
358982a9
...
@@ -51,6 +51,7 @@ void ClearObjectCache();
...
@@ -51,6 +51,7 @@ void ClearObjectCache();
}
// namespace data_converter
}
// namespace data_converter
ClassPtr
ParseDataClass
(
const
py
::
object
&
cls_obj
);
ClassPtr
ParseDataClass
(
const
py
::
object
&
cls_obj
);
FuncGraphPtr
ConvertToBpropCut
(
const
py
::
object
&
obj
);
void
CleanDataClassToClassMap
();
void
CleanDataClassToClassMap
();
...
...
mindspore/ccsrc/pipeline/parse/parse_base.h
浏览文件 @
358982a9
...
@@ -109,6 +109,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
...
@@ -109,6 +109,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
// define the parse constant
const
int
MAX_COMPARISON_OPS_SUPPORTED
=
1
;
const
int
MAX_COMPARISON_OPS_SUPPORTED
=
1
;
const
char
CUSTOM_BPROP_NAME
[]
=
"bprop"
;
// define the Namespace name
// define the Namespace name
const
char
RESOLVE_NAMESPACE_NAME_AST
[]
=
"Ast"
;
// for ast type namespace
const
char
RESOLVE_NAMESPACE_NAME_AST
[]
=
"Ast"
;
// for ast type namespace
...
...
mindspore/ccsrc/pynative/base.h
浏览文件 @
358982a9
...
@@ -45,7 +45,7 @@ enum PynativeStatusCode {
...
@@ -45,7 +45,7 @@ enum PynativeStatusCode {
PYNATIVE_UNKNOWN_STATE
=
0XFF
PYNATIVE_UNKNOWN_STATE
=
0XFF
};
};
enum
RunOpArgsEnum
{
PY_PRIM
=
0
,
PY_NAME
,
PY_INPUTS
,
PY_
INPUT_MASK
,
PY_
ARGS_NUM
};
enum
RunOpArgsEnum
{
PY_PRIM
=
0
,
PY_NAME
,
PY_INPUTS
,
PY_ARGS_NUM
};
struct
OpExecInfo
{
struct
OpExecInfo
{
PrimitivePyPtr
py_primitive
;
PrimitivePyPtr
py_primitive
;
...
...
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
358982a9
...
@@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) {
...
@@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) {
return
obj_tuple
;
return
obj_tuple
;
}
}
void
ConvertInputs
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
args
,
py
::
tuple
*
out_args
)
{
py
::
tuple
ConvertInputs
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
args
,
py
::
tuple
*
out_args
)
{
auto
&
py_args
=
*
out_args
;
auto
&
py_args
=
*
out_args
;
py
::
tuple
input_mask
(
args
.
size
());
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
if
(
py
::
hasattr
(
args
[
i
],
"__parameter__"
))
{
input_mask
[
i
]
=
true
;
}
else
{
input_mask
[
i
]
=
false
;
}
py_args
[
i
]
=
GetTupleObj
(
args
[
i
]);
py_args
[
i
]
=
GetTupleObj
(
args
[
i
]);
}
}
auto
signature
=
prim
->
signatures
();
auto
signature
=
prim
->
signatures
();
...
@@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
...
@@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
[](
const
Signature
&
sig
)
{
return
sig
.
dtype
;
});
[](
const
Signature
&
sig
)
{
return
sig
.
dtype
;
});
int
empty_dtype_count
=
std
::
count
(
dtypes
.
begin
(),
dtypes
.
end
(),
SignatureEnumDType
::
kDTypeEmptyDefaultValue
);
int
empty_dtype_count
=
std
::
count
(
dtypes
.
begin
(),
dtypes
.
end
(),
SignatureEnumDType
::
kDTypeEmptyDefaultValue
);
if
(
dtypes
.
size
()
==
0
||
static_cast
<
int
>
(
dtypes
.
size
())
==
empty_dtype_count
)
{
if
(
dtypes
.
size
()
==
0
||
static_cast
<
int
>
(
dtypes
.
size
())
==
empty_dtype_count
)
{
return
;
return
input_mask
;
}
}
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
type_indexs
;
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
type_indexs
;
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
...
@@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
...
@@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
continue
;
continue
;
}
}
}
}
return
input_mask
;
}
}
void
PynativeInfer
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
py_args
,
OpExecInfo
*
const
op_exec_info
)
{
void
PynativeInfer
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
py_args
,
OpExecInfo
*
const
op_exec_info
)
{
...
@@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
...
@@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
AbstractBasePtrList
args_spec_list
;
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
ValuePtr
input_value
=
PyAttrValue
(
py_args
[
i
]);
ValuePtr
input_value
=
PyAttrValue
(
py_args
[
i
]);
if
(
input_value
->
isa
<
tensor
::
Tensor
>
())
{
if
(
!
py
::
hasattr
(
prim
->
GetPyObj
(),
"const_value"
)
&&
input_value
->
isa
<
tensor
::
Tensor
>
())
{
args_spec_list
.
emplace_back
(
abstract
::
FromValueInside
(
input_value
,
true
));
args_spec_list
.
emplace_back
(
abstract
::
FromValueInside
(
input_value
,
true
));
}
else
{
}
else
{
args_spec_list
.
emplace_back
(
abstract
::
FromValueInside
(
input_value
,
false
));
args_spec_list
.
emplace_back
(
abstract
::
FromValueInside
(
input_value
,
false
));
...
@@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
...
@@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
OpExecInfoPtr
GenerateOpExecInfo
(
const
py
::
args
&
args
)
{
OpExecInfoPtr
GenerateOpExecInfo
(
const
py
::
args
&
args
)
{
if
(
args
.
size
()
!=
PY_ARGS_NUM
)
{
if
(
args
.
size
()
!=
PY_ARGS_NUM
)
{
MS_LOG
(
ERROR
)
<<
"
Four
args are needed by RunOp"
;
MS_LOG
(
ERROR
)
<<
"
Three
args are needed by RunOp"
;
return
nullptr
;
return
nullptr
;
}
}
auto
op_exec_info
=
std
::
make_shared
<
OpExecInfo
>
();
auto
op_exec_info
=
std
::
make_shared
<
OpExecInfo
>
();
...
@@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
...
@@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
size_t
input_num
=
a
.
size
();
size_t
input_num
=
a
.
size
();
op_exec_info
->
op_inputs
=
py
::
tuple
(
input_num
);
op_exec_info
->
op_inputs
=
py
::
tuple
(
input_num
);
ConvertInputs
(
prim
,
args
[
PY_INPUTS
],
&
op_exec_info
->
op_inputs
);
op_exec_info
->
inputs_mask
=
ConvertInputs
(
prim
,
args
[
PY_INPUTS
],
&
op_exec_info
->
op_inputs
);
// use python infer method
// use python infer method
if
(
ignore_infer_prim
.
find
(
op_exec_info
->
op_name
)
==
ignore_infer_prim
.
end
())
{
if
(
ignore_infer_prim
.
find
(
op_exec_info
->
op_name
)
==
ignore_infer_prim
.
end
())
{
PynativeInfer
(
prim
,
op_exec_info
->
op_inputs
,
op_exec_info
.
get
());
PynativeInfer
(
prim
,
op_exec_info
->
op_inputs
,
op_exec_info
.
get
());
}
}
op_exec_info
->
py_primitive
=
prim
;
op_exec_info
->
py_primitive
=
prim
;
op_exec_info
->
op_attrs
=
py
::
getattr
(
args
[
PY_PRIM
],
"attrs"
);
op_exec_info
->
op_attrs
=
py
::
getattr
(
args
[
PY_PRIM
],
"attrs"
);
op_exec_info
->
inputs_mask
=
args
[
PY_INPUT_MASK
];
if
(
op_exec_info
->
op_inputs
.
size
()
!=
op_exec_info
->
inputs_mask
.
size
())
{
if
(
op_exec_info
->
op_inputs
.
size
()
!=
op_exec_info
->
inputs_mask
.
size
())
{
MS_LOG
(
ERROR
)
<<
"Op:"
<<
op_exec_info
->
op_name
<<
" inputs size not equal op_mask"
;
MS_LOG
(
ERROR
)
<<
"Op:"
<<
op_exec_info
->
op_name
<<
" inputs size not equal op_mask"
;
return
nullptr
;
return
nullptr
;
...
@@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
...
@@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
return
result
;
return
result
;
}
}
AnfNodePtr
PynativeExecutor
::
MakeCNode
(
const
py
::
args
&
args
,
const
py
::
tuple
&
out
)
{
AnfNodePtr
PynativeExecutor
::
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
args
&
args
,
const
py
::
tuple
&
out
)
{
if
(
!
grad_flag_
||
graph_info_map_
.
size
()
==
0
)
{
if
(
!
grad_flag_
||
graph_info_map_
.
size
()
==
0
)
{
return
nullptr
;
return
nullptr
;
}
}
std
::
vector
<
AnfNodePtr
>
inputs
;
std
::
vector
<
AnfNodePtr
>
inputs
;
auto
prim
=
py
::
cast
<
PrimitivePyPtr
>
(
args
[
PY_PRIM
])
;
auto
prim
=
op_exec_info
->
py_primitive
;
inputs
.
push_back
(
NewValueNode
(
prim
));
inputs
.
push_back
(
NewValueNode
(
prim
));
py
::
tuple
op_masks
=
args
[
PY_INPUT_MASK
]
;
py
::
tuple
op_masks
=
op_exec_info
->
inputs_mask
;
py
::
list
op_args
=
args
[
PY_INPUTS
];
py
::
list
op_args
=
args
[
PY_INPUTS
];
AbstractBasePtrList
args_spec_list
;
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
op_args
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op_args
.
size
();
i
++
)
{
...
@@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) {
...
@@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) {
return
err_ret
;
return
err_ret
;
}
}
auto
node
=
PynativeExecutor
::
GetInstance
()
->
MakeCNode
(
args
,
result
);
auto
node
=
PynativeExecutor
::
GetInstance
()
->
MakeCNode
(
op_exec_info
,
args
,
result
);
if
(
node
!=
nullptr
)
{
if
(
node
!=
nullptr
)
{
node
->
set_abstract
(
op_exec_info
->
abstract
);
node
->
set_abstract
(
op_exec_info
->
abstract
);
MS_LOG
(
DEBUG
)
<<
"RunOp MakeCnode,new node is: "
<<
node
->
DebugString
();
MS_LOG
(
DEBUG
)
<<
"RunOp MakeCnode,new node is: "
<<
node
->
DebugString
();
...
@@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
...
@@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
}
}
cell_graph_map_
[
cell_id
]
=
curr_g_
;
cell_graph_map_
[
cell_id
]
=
curr_g_
;
auto
out_id
=
GetId
(
out
);
auto
out_id
=
GetId
(
out
);
if
(
!
graph_info_map_
[
curr_g_
].
obj_node_map
.
count
(
out_id
))
{
if
(
!
graph_info_map_
[
curr_g_
].
obj_node_map
.
count
(
out_id
)
&&
!
graph_info_map_
[
curr_g_
].
param_map
.
count
(
out_id
)
)
{
// cell construct return x, y
// cell construct return x, y
if
(
py
::
isinstance
<
py
::
tuple
>
(
out
))
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
out
))
{
std
::
vector
<
AnfNodePtr
>
args
;
std
::
vector
<
AnfNodePtr
>
args
;
...
@@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
...
@@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
}
}
}
}
auto
output_node
=
GetObjNode
(
out
);
AnfNodePtr
output_node
;
if
(
graph_info_map_
[
curr_g_
].
param_map
.
count
(
out_id
))
{
output_node
=
graph_info_map_
[
curr_g_
].
param_map
[
out_id
];
}
else
{
output_node
=
GetObjNode
(
out
);
}
curr_g_
->
set_output
(
output_node
);
curr_g_
->
set_output
(
output_node
);
std
::
vector
<
AnfNodePtr
>
inputs
;
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
push_back
(
NewValueNode
(
curr_g_
));
inputs
.
push_back
(
NewValueNode
(
curr_g_
));
MS_LOG
(
DEBUG
)
<<
"Current graph"
<<
curr_g_
->
output
()
->
DebugString
();
MS_LOG
(
DEBUG
)
<<
"Current graph"
<<
curr_g_
->
output
()
->
DebugString
();
resource_
->
manager
()
->
AddFuncGraph
(
curr_g_
);
resource_
->
manager
()
->
AddFuncGraph
(
curr_g_
);
// custom bprop debug
if
(
py
::
hasattr
(
cell
,
parse
::
CUSTOM_BPROP_NAME
))
{
MS_LOG
(
DEBUG
)
<<
"Use cell custom bprop function."
;
FuncGraphPtr
bprop_graph
=
parse
::
ConvertToBpropCut
(
cell
);
if
(
bprop_graph
!=
nullptr
)
{
(
void
)
curr_g_
->
transforms
().
insert
(
std
::
make_pair
(
parse
::
CUSTOM_BPROP_NAME
,
FuncGraphTransform
(
bprop_graph
)));
(
void
)
bprop_graph
->
transforms
().
insert
(
std
::
make_pair
(
"primal"
,
FuncGraphTransform
(
curr_g_
)));
}
}
auto
newfg
=
ad
::
Grad
(
curr_g_
,
resource_
,
curr_g_
==
top_g_
);
auto
newfg
=
ad
::
Grad
(
curr_g_
,
resource_
,
curr_g_
==
top_g_
);
if
(
curr_g_
!=
top_g_
)
{
if
(
curr_g_
!=
top_g_
)
{
Popp
();
Popp
();
...
...
mindspore/ccsrc/pynative/pynative_execute.h
浏览文件 @
358982a9
...
@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
...
@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py
::
tuple
RunOp
(
const
py
::
args
&
args
);
py
::
tuple
RunOp
(
const
py
::
args
&
args
);
void
ConvertInputs
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
py_args
,
py
::
tuple
*
out_args
);
py
::
tuple
ConvertInputs
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
py_args
,
py
::
tuple
*
out_args
);
void
ClearPyNativeSession
();
void
ClearPyNativeSession
();
...
@@ -83,7 +83,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
...
@@ -83,7 +83,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void
set_obj_node_map
(
FuncGraphPtr
g
,
const
std
::
string
obj
,
AnfNodePtr
node
,
int
index
)
{
void
set_obj_node_map
(
FuncGraphPtr
g
,
const
std
::
string
obj
,
AnfNodePtr
node
,
int
index
)
{
graph_info_map_
[
g
].
obj_node_map
[
obj
]
=
std
::
make_pair
(
node
,
index
);
graph_info_map_
[
g
].
obj_node_map
[
obj
]
=
std
::
make_pair
(
node
,
index
);
}
}
AnfNodePtr
MakeCNode
(
const
py
::
args
&
args
,
const
py
::
tuple
&
out
);
AnfNodePtr
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
args
&
args
,
const
py
::
tuple
&
out
);
py
::
object
Run
(
const
py
::
tuple
&
args
,
const
py
::
object
&
phase
);
py
::
object
Run
(
const
py
::
tuple
&
args
,
const
py
::
object
&
phase
);
void
Pushp
();
void
Pushp
();
...
...
mindspore/common/_register_for_tensor.py
浏览文件 @
358982a9
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
"""Registry the relation."""
"""Registry the relation."""
from
collections
import
UserDict
from
collections
import
UserDict
from
..
import
context
class
Registry
(
UserDict
):
class
Registry
(
UserDict
):
...
@@ -27,9 +28,16 @@ class Registry(UserDict):
...
@@ -27,9 +28,16 @@ class Registry(UserDict):
def
get
(
self
,
obj_str
):
def
get
(
self
,
obj_str
):
"""Get the value by str."""
"""Get the value by str."""
if
isinstance
(
obj_str
,
str
):
if
not
isinstance
(
obj_str
,
str
):
raise
TypeError
(
"key for tensor registry must be string."
)
if
context
.
get_context
(
"enable_ge"
):
def
wrap
(
*
args
):
new_args
=
list
(
args
)
new_args
.
append
(
obj_str
)
return
self
[
"vm_compare"
](
*
new_args
)
obj
=
wrap
else
:
obj
=
self
[
obj_str
]
obj
=
self
[
obj_str
]
return
obj
return
obj
tensor_operator_registry
=
Registry
()
tensor_operator_registry
=
Registry
()
mindspore/common/tensor.py
浏览文件 @
358982a9
...
@@ -19,7 +19,6 @@ from .._c_expression import Tensor as Tensor_
...
@@ -19,7 +19,6 @@ from .._c_expression import Tensor as Tensor_
from
.._c_expression
import
MetaTensor
from
.._c_expression
import
MetaTensor
from
.._checkparam
import
check_type
,
check_typename
from
.._checkparam
import
check_type
,
check_typename
from
.
import
dtype
as
mstype
from
.
import
dtype
as
mstype
from
..
import
context
from
._register_for_tensor
import
tensor_operator_registry
from
._register_for_tensor
import
tensor_operator_registry
__all__
=
[
'Tensor'
,
'MetaTensor'
]
__all__
=
[
'Tensor'
,
'MetaTensor'
]
...
@@ -76,17 +75,19 @@ class Tensor(Tensor_):
...
@@ -76,17 +75,19 @@ class Tensor(Tensor_):
return
out
return
out
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
if
not
isinstance
(
other
,
(
int
,
float
,
Tensor
)
):
return
False
return
False
# The GE backend don't support single `Equal` operator execution.
# bool type is not supported for `Equal` operator in backend.
# bool type is not supported for `Equal` operator in backend.
if
context
.
get_context
(
"enable_ge"
)
or
self
.
dtype
==
mstype
.
bool_
or
other
.
dtype
==
mstype
.
bool_
:
if
self
.
dtype
==
mstype
.
bool_
or
(
isinstance
(
other
,
Tensor
)
and
other
.
dtype
==
mstype
.
bool_
)
:
return
Tensor
(
np
.
array
(
self
.
asnumpy
()
==
other
.
asnumpy
()))
return
Tensor
(
np
.
array
(
self
.
asnumpy
()
==
other
.
asnumpy
()))
return
tensor_operator_registry
.
get
(
'__eq__'
)(
self
,
other
)
return
tensor_operator_registry
.
get
(
'__eq__'
)(
self
,
other
)
def
__ne__
(
self
,
other
):
def
__ne__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
if
not
isinstance
(
other
,
(
int
,
float
,
Tensor
)
):
return
True
return
True
# bool type is not supported for `NotEqual` operator in backend.
if
self
.
dtype
==
mstype
.
bool_
or
(
isinstance
(
other
,
Tensor
)
and
other
.
dtype
==
mstype
.
bool_
):
return
Tensor
(
np
.
array
(
self
.
asnumpy
()
!=
other
.
asnumpy
()))
return
tensor_operator_registry
.
get
(
'__ne__'
)(
self
,
other
)
return
tensor_operator_registry
.
get
(
'__ne__'
)(
self
,
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
...
@@ -105,7 +106,7 @@ class Tensor(Tensor_):
...
@@ -105,7 +106,7 @@ class Tensor(Tensor_):
return
out
return
out
def
__radd__
(
self
,
other
):
def
__radd__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
other
,
self
)
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
self
,
other
)
return
out
return
out
def
__imul__
(
self
,
other
):
def
__imul__
(
self
,
other
):
...
@@ -113,15 +114,15 @@ class Tensor(Tensor_):
...
@@ -113,15 +114,15 @@ class Tensor(Tensor_):
return
out
return
out
def
__rmul__
(
self
,
other
):
def
__rmul__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__mul__'
)(
other
,
self
)
out
=
tensor_operator_registry
.
get
(
'__mul__'
)(
self
,
other
)
return
out
return
out
def
__truediv__
(
self
,
other
):
def
__truediv__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__div__'
)(
self
,
other
)
out
=
tensor_operator_registry
.
get
(
'__
true
div__'
)(
self
,
other
)
return
out
return
out
def
__rtruediv__
(
self
,
other
):
def
__rtruediv__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__div__'
)(
other
,
self
)
out
=
tensor_operator_registry
.
get
(
'__
true
div__'
)(
other
,
self
)
return
out
return
out
def
__sub__
(
self
,
other
):
def
__sub__
(
self
,
other
):
...
@@ -160,7 +161,7 @@ class Tensor(Tensor_):
...
@@ -160,7 +161,7 @@ class Tensor(Tensor_):
return
out
return
out
def
__len__
(
self
):
def
__len__
(
self
):
out
=
tensor_operator_registry
.
get
(
'
__shape__
'
)(
self
)
out
=
tensor_operator_registry
.
get
(
'
shape
'
)(
self
)
if
not
out
:
if
not
out
:
return
1
return
1
return
out
[
0
]
return
out
[
0
]
...
...
mindspore/nn/cell.py
浏览文件 @
358982a9
...
@@ -819,4 +819,4 @@ class Cell:
...
@@ -819,4 +819,4 @@ class Cell:
"""
"""
self
.
_backward_hook
=
HookBackward
(
fn
,
self
.
cls_name
+
"("
+
str
(
id
(
self
))
+
")"
)
self
.
_backward_hook
=
HookBackward
(
fn
,
self
.
cls_name
+
"("
+
str
(
id
(
self
))
+
")"
)
self
.
_
enable_hook
=
True
self
.
enable_hook
=
True
mindspore/nn/layer/container.py
浏览文件 @
358982a9
...
@@ -140,6 +140,11 @@ class SequentialCell(Cell):
...
@@ -140,6 +140,11 @@ class SequentialCell(Cell):
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
_cells
)
return
len
(
self
.
_cells
)
def
set_grad
(
self
,
flag
=
True
):
self
.
requires_grad
=
flag
for
cell
in
self
.
_cells
.
values
():
cell
.
set_grad
(
flag
)
def
construct
(
self
,
input_data
):
def
construct
(
self
,
input_data
):
for
cell
in
self
.
cell_list
:
for
cell
in
self
.
cell_list
:
input_data
=
cell
(
input_data
)
input_data
=
cell
(
input_data
)
...
@@ -246,5 +251,10 @@ class CellList(_CellListBase, Cell):
...
@@ -246,5 +251,10 @@ class CellList(_CellListBase, Cell):
self
.
_cells
[
str
(
len
(
self
))]
=
cell
self
.
_cells
[
str
(
len
(
self
))]
=
cell
return
self
return
self
def
set_grad
(
self
,
flag
=
True
):
self
.
requires_grad
=
flag
for
cell
in
self
.
_cells
.
values
():
cell
.
set_grad
(
flag
)
def
construct
(
self
,
*
inputs
):
def
construct
(
self
,
*
inputs
):
raise
NotImplementedError
raise
NotImplementedError
mindspore/ops/composite/base.py
浏览文件 @
358982a9
...
@@ -112,7 +112,7 @@ class GradOperation(GradOperation_):
...
@@ -112,7 +112,7 @@ class GradOperation(GradOperation_):
grad_
=
GradOperation
(
'grad'
,
self
.
get_all
,
self
.
get_by_list
,
self
.
sens_param
)
grad_
=
GradOperation
(
'grad'
,
self
.
get_all
,
self
.
get_by_list
,
self
.
sens_param
)
if
self
.
grad_fn
is
None
or
self
.
fn
!=
fn
:
if
self
.
grad_fn
is
None
or
self
.
fn
!=
fn
:
if
self
.
get_by_list
:
if
self
.
get_by_list
:
if
context
.
get_context
(
"mode"
)
==
context
.
GRAPH_MODE
or
fn
.
bprop_debug
:
if
context
.
get_context
(
"mode"
)
==
context
.
GRAPH_MODE
:
@
ms_function
(
obj
=
fn
)
@
ms_function
(
obj
=
fn
)
def
after_grad
(
*
args
):
def
after_grad
(
*
args
):
return
grad_
(
fn
,
weights
)(
*
args
)
return
grad_
(
fn
,
weights
)(
*
args
)
...
...
mindspore/ops/functional.py
浏览文件 @
358982a9
...
@@ -21,6 +21,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
...
@@ -21,6 +21,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
from
.primitive
import
Primitive
from
.primitive
import
Primitive
from
.
import
operations
as
P
from
.
import
operations
as
P
from
.operations
import
_grad_ops
from
.operations
import
_grad_ops
from
.._extends
import
builtin_operations
as
BP
typeof
=
Primitive
(
'typeof'
)
typeof
=
Primitive
(
'typeof'
)
hastype
=
Primitive
(
'hastype'
)
hastype
=
Primitive
(
'hastype'
)
...
@@ -155,7 +156,7 @@ stop_gradient = Primitive("stop_gradient")
...
@@ -155,7 +156,7 @@ stop_gradient = Primitive("stop_gradient")
tensor_operator_registry
.
register
(
'__add__'
,
tensor_add
)
tensor_operator_registry
.
register
(
'__add__'
,
tensor_add
)
tensor_operator_registry
.
register
(
'__sub__'
,
tensor_sub
)
tensor_operator_registry
.
register
(
'__sub__'
,
tensor_sub
)
tensor_operator_registry
.
register
(
'__mul__'
,
tensor_mul
)
tensor_operator_registry
.
register
(
'__mul__'
,
tensor_mul
)
tensor_operator_registry
.
register
(
'__div__'
,
tensor_div
)
tensor_operator_registry
.
register
(
'__
true
div__'
,
tensor_div
)
#ms cannot support Tensor(True) compare
#ms cannot support Tensor(True) compare
tensor_operator_registry
.
register
(
'__eq__'
,
equal
)
tensor_operator_registry
.
register
(
'__eq__'
,
equal
)
tensor_operator_registry
.
register
(
'__ne__'
,
not_equal
)
tensor_operator_registry
.
register
(
'__ne__'
,
not_equal
)
...
@@ -164,4 +165,6 @@ tensor_operator_registry.register('__lt__', tensor_lt)
...
@@ -164,4 +165,6 @@ tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry
.
register
(
'__le__'
,
tensor_le
)
tensor_operator_registry
.
register
(
'__le__'
,
tensor_le
)
tensor_operator_registry
.
register
(
'__gt__'
,
tensor_gt
)
tensor_operator_registry
.
register
(
'__gt__'
,
tensor_gt
)
tensor_operator_registry
.
register
(
'__ge__'
,
tensor_ge
)
tensor_operator_registry
.
register
(
'__ge__'
,
tensor_ge
)
tensor_operator_registry
.
register
(
'__shape__'
,
shape
)
tensor_operator_registry
.
register
(
'shape'
,
shape
)
#support GE backend for no compare operators
tensor_operator_registry
.
register
(
'vm_compare'
,
BP
.
vm_compare
)
mindspore/ops/operations/array_ops.py
浏览文件 @
358982a9
...
@@ -933,6 +933,8 @@ class TupleToArray(PrimitiveWithInfer):
...
@@ -933,6 +933,8 @@ class TupleToArray(PrimitiveWithInfer):
args
=
list
()
args
=
list
()
if
isinstance
(
x
,
range
):
if
isinstance
(
x
,
range
):
args
.
append
(
tuple
(
x
))
args
.
append
(
tuple
(
x
))
else
:
args
.
append
(
x
)
return
_run_op
(
self
,
self
.
name
,
args
)
return
_run_op
(
self
,
self
.
name
,
args
)
...
...
mindspore/ops/primitive.py
浏览文件 @
358982a9
...
@@ -341,13 +341,7 @@ def constexpr(fn=None, get_instance=True, name=None):
...
@@ -341,13 +341,7 @@ def constexpr(fn=None, get_instance=True, name=None):
@
_wrap_func
@
_wrap_func
def
_run_op
(
obj
,
op_name
,
args
):
def
_run_op
(
obj
,
op_name
,
args
):
"""Single op execution function supported by ge in PyNative mode."""
"""Single op execution function supported by ge in PyNative mode."""
op_mask
=
[
0
]
*
len
(
args
)
output
=
real_run_op
(
obj
,
op_name
,
args
)
op_inputs
=
[]
for
i
,
arg
in
enumerate
(
args
):
if
hasattr
(
arg
,
'__parameter__'
):
op_mask
[
i
]
=
1
op_inputs
.
append
(
arg
)
output
=
real_run_op
(
obj
,
op_name
,
args
,
tuple
(
op_mask
))
if
not
output
:
if
not
output
:
raise
RuntimeError
(
"Pynative run op %s failed!"
%
op_name
)
raise
RuntimeError
(
"Pynative run op %s failed!"
%
op_name
)
if
len
(
output
)
==
1
:
if
len
(
output
)
==
1
:
...
...
tests/ut/cpp/pynative/pynative_execute_test.cc
浏览文件 @
358982a9
...
@@ -63,8 +63,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
...
@@ -63,8 +63,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
auto
conv_obj
=
prim
::
GetPythonOps
(
"conv2d_prim"
,
"gtest_input.pynative"
);
auto
conv_obj
=
prim
::
GetPythonOps
(
"conv2d_prim"
,
"gtest_input.pynative"
);
py
::
none
py_none
;
py
::
none
py_none
;
py
::
tuple
op_mask
=
py
::
make_tuple
(
0
,
1
);
return
GenerateOpExecInfo
(
py
::
make_tuple
(
conv_obj
,
op_name
,
op_inputs
));
return
GenerateOpExecInfo
(
py
::
make_tuple
(
conv_obj
,
op_name
,
op_inputs
,
op_mask
));
}
}
TEST_F
(
TestPynativeExecute
,
TestRunOpInVM
)
{
TEST_F
(
TestPynativeExecute
,
TestRunOpInVM
)
{
...
@@ -79,7 +78,7 @@ TEST_F(TestPynativeExecute, TestRunOp) {
...
@@ -79,7 +78,7 @@ TEST_F(TestPynativeExecute, TestRunOp) {
py
::
none
py_none
;
py
::
none
py_none
;
auto
op_exec_info_ptr
=
ConstructOpExecInfo
();
auto
op_exec_info_ptr
=
ConstructOpExecInfo
();
py
::
tuple
outputs
=
pynative
::
RunOp
(
py
::
make_tuple
(
op_exec_info_ptr
->
py_primitive
,
op_exec_info_ptr
->
op_name
,
py
::
tuple
outputs
=
pynative
::
RunOp
(
py
::
make_tuple
(
op_exec_info_ptr
->
py_primitive
,
op_exec_info_ptr
->
op_name
,
op_exec_info_ptr
->
op_inputs
,
op_exec_info_ptr
->
inputs_mask
));
op_exec_info_ptr
->
op_inputs
));
if
(
outputs
.
size
()
==
0
)
{
if
(
outputs
.
size
()
==
0
)
{
FAIL
();
FAIL
();
}
else
{
}
else
{
...
...
tests/ut/python/ir/test_tensor.py
浏览文件 @
358982a9
...
@@ -452,5 +452,5 @@ def test_tensor_operation():
...
@@ -452,5 +452,5 @@ def test_tensor_operation():
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
res
=
8
/
x
res
=
8
/
x
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
with
pytest
.
raises
(
Typ
eError
):
with
pytest
.
raises
(
Valu
eError
):
res
=
x
*
(
2
,
3
)
res
=
x
*
(
2
,
3
)
tests/ut/python/pynative_mode/test_hook.py
浏览文件 @
358982a9
...
@@ -8,6 +8,9 @@ from mindspore.nn import WithLossCell, Momentum
...
@@ -8,6 +8,9 @@ from mindspore.nn import WithLossCell, Momentum
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
composite
as
C
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"GPU"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"GPU"
)
cell_hook_done
=
False
var_hook_done
=
False
cell_bprop_done
=
False
def
conv
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
):
def
conv
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
0
):
...
@@ -32,15 +35,35 @@ def weight_variable():
...
@@ -32,15 +35,35 @@ def weight_variable():
def
cell_hook_function
(
cell_id
,
grad_input
,
grad_output
):
def
cell_hook_function
(
cell_id
,
grad_input
,
grad_output
):
print
(
cell_id
)
print
(
cell_id
)
global
cell_hook_done
cell_hook_done
=
True
assert
(
grad_output
[
0
].
asnumpy
().
shape
==
(
32
,
6
,
14
,
14
))
assert
(
grad_output
[
0
].
asnumpy
().
shape
==
(
32
,
6
,
14
,
14
))
assert
(
grad_input
[
0
].
asnumpy
().
shape
==
(
32
,
16
,
10
,
10
))
assert
(
grad_input
[
0
].
asnumpy
().
shape
==
(
32
,
16
,
10
,
10
))
def
var_hook_function
(
grad_out
):
def
var_hook_function
(
grad_out
):
print
(
"grad:"
,
grad_out
)
print
(
"grad:"
,
grad_out
)
global
var_hook_done
var_hook_done
=
True
assert
(
grad_out
[
0
].
asnumpy
().
shape
==
(
32
,
120
))
assert
(
grad_out
[
0
].
asnumpy
().
shape
==
(
32
,
120
))
class
Block
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Block
,
self
).
__init__
()
self
.
relu
=
nn
.
ReLU
()
def
construct
(
self
,
x
):
x
=
self
.
relu
(
x
)
return
x
def
bprop
(
self
,
x
,
out
,
dout
):
global
cell_bprop_done
cell_bprop_done
=
True
grad
=
out
.
asnumpy
()
*
dout
.
asnumpy
()
grad
=
Tensor
(
grad
)
return
(
grad
,)
class
LeNet5
(
nn
.
Cell
):
class
LeNet5
(
nn
.
Cell
):
"""
"""
Lenet network
Lenet network
...
@@ -59,6 +82,7 @@ class LeNet5(nn.Cell):
...
@@ -59,6 +82,7 @@ class LeNet5(nn.Cell):
self
.
conv1
=
conv
(
1
,
6
,
5
)
self
.
conv1
=
conv
(
1
,
6
,
5
)
self
.
conv2
=
conv
(
6
,
16
,
5
)
self
.
conv2
=
conv
(
6
,
16
,
5
)
self
.
conv2
.
register_backward_hook
(
cell_hook_function
)
self
.
conv2
.
register_backward_hook
(
cell_hook_function
)
self
.
block
=
Block
()
self
.
fc1
=
fc_with_initialize
(
16
*
5
*
5
,
120
)
self
.
fc1
=
fc_with_initialize
(
16
*
5
*
5
,
120
)
self
.
fc2
=
fc_with_initialize
(
120
,
84
)
self
.
fc2
=
fc_with_initialize
(
120
,
84
)
self
.
fc3
=
fc_with_initialize
(
84
,
self
.
num_class
)
self
.
fc3
=
fc_with_initialize
(
84
,
self
.
num_class
)
...
@@ -72,7 +96,7 @@ class LeNet5(nn.Cell):
...
@@ -72,7 +96,7 @@ class LeNet5(nn.Cell):
x
=
self
.
relu
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
max_pool2d
(
x
)
x
=
self
.
max_pool2d
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
block
(
x
)
x
=
self
.
max_pool2d
(
x
)
x
=
self
.
max_pool2d
(
x
)
x
=
self
.
reshape
(
x
,
(
self
.
batch_size
,
-
1
))
x
=
self
.
reshape
(
x
,
(
self
.
batch_size
,
-
1
))
x
=
self
.
fc1
(
x
)
x
=
self
.
fc1
(
x
)
...
@@ -110,6 +134,9 @@ def test_hook():
...
@@ -110,6 +134,9 @@ def test_hook():
loss_output
=
criterion
(
output
,
label
)
loss_output
=
criterion
(
output
,
label
)
grads
=
train_network
(
input_data
,
label
)
grads
=
train_network
(
input_data
,
label
)
success
=
optimizer
(
grads
)
success
=
optimizer
(
grads
)
assert
cell_hook_done
assert
var_hook_done
assert
cell_bprop_done
print
(
loss_output
.
asnumpy
().
shape
)
print
(
loss_output
.
asnumpy
().
shape
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录