Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0b191615
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0b191615
编写于
5月 26, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1427 fix check bprop attr error
Merge pull request !1427 from panyifeng/fix_check_bprop_attr_error
上级
963f7ee5
6a57eeb9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
38 addition
and
8 deletion
+38
-8
mindspore/ccsrc/optimizer/ad/kprim.cc
mindspore/ccsrc/optimizer/ad/kprim.cc
+14
-4
mindspore/ccsrc/pipeline/init.cc
mindspore/ccsrc/pipeline/init.cc
+3
-1
mindspore/ccsrc/utils/context/ms_context.h
mindspore/ccsrc/utils/context/ms_context.h
+3
-0
mindspore/context.py
mindspore/context.py
+10
-1
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+2
-1
tests/ut/python/pynative_mode/test_cell_bprop.py
tests/ut/python/pynative_mode/test_cell_bprop.py
+3
-0
tests/ut/python/pynative_mode/test_framstruct.py
tests/ut/python/pynative_mode/test_framstruct.py
+3
-1
未找到文件。
mindspore/ccsrc/optimizer/ad/kprim.cc
浏览文件 @
0b191615
...
...
@@ -32,6 +32,7 @@
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/context/ms_context.h"
#include "debug/info.h"
#include "debug/trace.h"
...
...
@@ -181,10 +182,19 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
}
void
KPrim
::
CheckBprop
(
const
FuncGraphPtr
&
bprop_fg
,
const
string
&
prim_to_check
)
{
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
check_bprop_flag
=
context
->
check_bprop_flag
();
// Skip checking if check_bprop not set
if
(
!
check_bprop_flag
)
{
return
;
}
// bprop_fg has been checked in caller
auto
check_bprop
=
prim
::
GetPythonOps
(
"check_bprop"
,
"mindspore.ops.functional"
)
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
check_bprop
);
check_bprop
->
set_attr
(
"prim_to_check"
,
std
::
make_shared
<
StringImm
>
(
prim_to_check
));
auto
check_bprop_class
=
prim
::
GetPythonOps
(
"CheckBprop"
,
"mindspore.ops.operations.other_ops"
);
MS_EXCEPTION_IF_NULL
(
check_bprop_class
);
auto
check_bprop
=
bprop_fg
->
NewCNode
({
NewValueNode
(
check_bprop_class
),
NewValueNode
(
std
::
make_shared
<
StringImm
>
(
prim_to_check
))});
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
emplace_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
...
...
@@ -192,7 +202,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
AnfNodePtr
params
=
bprop_fg
->
NewCNode
(
inputs
);
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
check_bprop
)
);
inputs
.
push_back
(
check_bprop
);
inputs
.
push_back
(
bprop_fg
->
output
());
inputs
.
push_back
(
params
);
AnfNodePtr
bprop_out
=
bprop_fg
->
NewCNode
(
inputs
);
...
...
mindspore/ccsrc/pipeline/init.cc
浏览文件 @
0b191615
...
...
@@ -141,7 +141,9 @@ PYBIND11_MODULE(_c_expression, m) {
.
def
(
"get_enable_profiling"
,
&
mindspore
::
MsContext
::
enable_profiling
,
"Get whether to open profiling."
)
.
def
(
"set_enable_profiling"
,
&
mindspore
::
MsContext
::
set_enable_profiling
,
"Set whether to open profiling."
)
.
def
(
"get_profiling_options"
,
&
mindspore
::
MsContext
::
profiling_options
,
"Get options to profiling."
)
.
def
(
"set_profiling_options"
,
&
mindspore
::
MsContext
::
set_profiling_options
,
"Set options to profiling."
);
.
def
(
"set_profiling_options"
,
&
mindspore
::
MsContext
::
set_profiling_options
,
"Set options to profiling."
)
.
def
(
"get_check_bprop_flag"
,
&
mindspore
::
MsContext
::
check_bprop_flag
,
"Get whether to check bprop."
)
.
def
(
"set_check_bprop_flag"
,
&
mindspore
::
MsContext
::
set_check_bprop_flag
,
"Set whether to check bprop."
);
(
void
)
py
::
class_
<
ParallelContext
,
std
::
shared_ptr
<
ParallelContext
>>
(
m
,
"AutoParallelContext"
)
.
def_static
(
"get_instance"
,
&
ParallelContext
::
GetInstance
,
"Get auto parallel context instance."
)
...
...
mindspore/ccsrc/utils/context/ms_context.h
浏览文件 @
0b191615
...
...
@@ -140,6 +140,8 @@ class MsContext {
void
set_profiling_options
(
const
std
::
string
&
options
)
{
profiling_options_
=
options
;
}
std
::
string
profiling_options
()
const
{
return
profiling_options_
;
}
bool
check_bprop_flag
()
const
{
return
check_bprop_flag_
;
}
void
set_check_bprop_flag
(
bool
check_bprop_flag
)
{
check_bprop_flag_
=
check_bprop_flag
;
}
private:
MsContext
(
const
std
::
string
&
backend_policy
,
const
std
::
string
&
target
);
...
...
@@ -179,6 +181,7 @@ class MsContext {
std
::
thread
tdt_print_
;
bool
profiling_mode_
;
std
::
string
profiling_options_
;
bool
check_bprop_flag_
;
};
}
// namespace mindspore
...
...
mindspore/context.py
浏览文件 @
0b191615
...
...
@@ -324,6 +324,13 @@ class _Context:
thread_info
=
self
.
_thread_local_info
thread_info
.
debug_runtime
=
enable
@
property
def
check_bprop
(
self
):
return
self
.
_context_handle
.
get_check_bprop_flag
()
@
check_bprop
.
setter
def
check_bprop
(
self
,
check_bprop_flag
):
self
.
_context_handle
.
set_check_bprop_flag
(
check_bprop_flag
)
def
check_input_format
(
x
):
import
re
...
...
@@ -449,7 +456,8 @@ def reset_auto_parallel_context():
@
args_type_check
(
mode
=
int
,
precompile_only
=
bool
,
device_target
=
str
,
device_id
=
int
,
save_graphs
=
bool
,
save_graphs_path
=
str
,
save_ms_model
=
bool
,
save_ms_model_path
=
str
,
enable_dump
=
bool
,
save_dump_path
=
str
,
enable_reduce_precision
=
bool
,
variable_memory_max_size
=
str
,
enable_profiling
=
bool
,
profiling_options
=
str
,
enable_auto_mixed_precision
=
bool
)
enable_profiling
=
bool
,
profiling_options
=
str
,
enable_auto_mixed_precision
=
bool
,
check_bprop
=
bool
)
def
set_context
(
**
kwargs
):
"""
Sets context for running environment.
...
...
@@ -500,6 +508,7 @@ def set_context(**kwargs):
The profiling can choose training_trace, task_trace, training_trace and task_trace combination and
separated by colons; single operator can choose op_trace, op_trace cannot be combined with
training_trace and task_trace. Default: "training_trace".
check_bprop (bool): Whether to check bprop. Default: False.
Raises:
ValueError: If input key is not an attribute in context.
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
0b191615
...
...
@@ -323,8 +323,9 @@ class CheckBprop(PrimitiveWithInfer):
"""
@
prim_attr_register
def
__init__
(
self
):
def
__init__
(
self
,
prim_to_check
=
""
):
"""init CheckBprop"""
self
.
prim_to_check
=
prim_to_check
def
infer_shape
(
self
,
xshapes
,
yshapes
):
tips
=
f
'Bprop of
{
self
.
prim_to_check
}
'
...
...
tests/ut/python/pynative_mode/test_cell_bprop.py
浏览文件 @
0b191615
...
...
@@ -353,6 +353,7 @@ class MulAddWithWrongOutputNum(nn.Cell):
def
test_grad_mul_add_with_wrong_output_num
():
context
.
set_context
(
check_bprop
=
True
)
mul_add
=
MulAddWithWrongOutputNum
()
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
mul_add
)(
1
,
2
)
...
...
@@ -370,6 +371,7 @@ class MulAddWithWrongOutputType(nn.Cell):
def
test_grad_mul_add_with_wrong_output_type
():
context
.
set_context
(
check_bprop
=
True
)
mul_add
=
MulAddWithWrongOutputType
()
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
mul_add
)(
1
,
Tensor
(
np
.
ones
([
2
,
2
])))
...
...
@@ -388,6 +390,7 @@ class MulAddWithWrongOutputShape(nn.Cell):
def
test_grad_mul_add_with_wrong_output_shape
():
context
.
set_context
(
check_bprop
=
True
)
mul_add
=
MulAddWithWrongOutputShape
()
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
mul_add
)(
1
,
Tensor
(
np
.
ones
([
2
,
2
])))
tests/ut/python/pynative_mode/test_framstruct.py
浏览文件 @
0b191615
...
...
@@ -893,6 +893,7 @@ def test_grad_if_defer_inline():
def
test_bprop_with_wrong_output_num
():
context
.
set_context
(
check_bprop
=
True
)
class
BpropWithWrongOutputNum
(
PrimitiveWithInfer
):
@
prim_attr_register
def
__init__
(
self
):
...
...
@@ -926,8 +927,8 @@ def test_bprop_with_wrong_output_num():
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
BpropWithWrongOutputNumCell
())(
1
,
2
)
def
test_bprop_with_wrong_output_type
():
context
.
set_context
(
check_bprop
=
True
)
class
BpropWithWrongOutputType
(
PrimitiveWithInfer
):
@
prim_attr_register
def
__init__
(
self
):
...
...
@@ -963,6 +964,7 @@ def test_bprop_with_wrong_output_type():
def
test_bprop_with_wrong_output_shape
():
context
.
set_context
(
check_bprop
=
True
)
class
BpropWithWrongOutputShape
(
PrimitiveWithInfer
):
@
prim_attr_register
def
__init__
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录