提交 0b191615 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1427 fix check bprop attr error

Merge pull request !1427 from panyifeng/fix_check_bprop_attr_error
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "operator/composite/composite.h" #include "operator/composite/composite.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "utils/primitive_utils.h" #include "utils/primitive_utils.h"
#include "utils/context/ms_context.h"
#include "debug/info.h" #include "debug/info.h"
#include "debug/trace.h" #include "debug/trace.h"
...@@ -181,10 +182,19 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp ...@@ -181,10 +182,19 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
} }
void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool check_bprop_flag = context->check_bprop_flag();
// Skip checking if check_bprop not set
if (!check_bprop_flag) {
return;
}
// bprop_fg has been checked in caller // bprop_fg has been checked in caller
auto check_bprop = prim::GetPythonOps("check_bprop", "mindspore.ops.functional")->cast<PrimitivePtr>(); auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
MS_EXCEPTION_IF_NULL(check_bprop); MS_EXCEPTION_IF_NULL(check_bprop_class);
check_bprop->set_attr("prim_to_check", std::make_shared<StringImm>(prim_to_check)); auto check_bprop =
bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
...@@ -192,7 +202,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check ...@@ -192,7 +202,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
AnfNodePtr params = bprop_fg->NewCNode(inputs); AnfNodePtr params = bprop_fg->NewCNode(inputs);
inputs.clear(); inputs.clear();
inputs.push_back(NewValueNode(check_bprop)); inputs.push_back(check_bprop);
inputs.push_back(bprop_fg->output()); inputs.push_back(bprop_fg->output());
inputs.push_back(params); inputs.push_back(params);
AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
......
...@@ -141,7 +141,9 @@ PYBIND11_MODULE(_c_expression, m) { ...@@ -141,7 +141,9 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.") .def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
.def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.") .def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
.def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.") .def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling."); .def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.");
(void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext") (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
.def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.") .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
......
...@@ -140,6 +140,8 @@ class MsContext { ...@@ -140,6 +140,8 @@ class MsContext {
void set_profiling_options(const std::string &options) { profiling_options_ = options; } void set_profiling_options(const std::string &options) { profiling_options_ = options; }
std::string profiling_options() const { return profiling_options_; } std::string profiling_options() const { return profiling_options_; }
bool check_bprop_flag() const { return check_bprop_flag_; }
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
private: private:
MsContext(const std::string &backend_policy, const std::string &target); MsContext(const std::string &backend_policy, const std::string &target);
...@@ -179,6 +181,7 @@ class MsContext { ...@@ -179,6 +181,7 @@ class MsContext {
std::thread tdt_print_; std::thread tdt_print_;
bool profiling_mode_; bool profiling_mode_;
std::string profiling_options_; std::string profiling_options_;
bool check_bprop_flag_;
}; };
} // namespace mindspore } // namespace mindspore
......
...@@ -324,6 +324,13 @@ class _Context: ...@@ -324,6 +324,13 @@ class _Context:
thread_info = self._thread_local_info thread_info = self._thread_local_info
thread_info.debug_runtime = enable thread_info.debug_runtime = enable
@property
def check_bprop(self):
return self._context_handle.get_check_bprop_flag()
@check_bprop.setter
def check_bprop(self, check_bprop_flag):
self._context_handle.set_check_bprop_flag(check_bprop_flag)
def check_input_format(x): def check_input_format(x):
import re import re
...@@ -449,7 +456,8 @@ def reset_auto_parallel_context(): ...@@ -449,7 +456,8 @@ def reset_auto_parallel_context():
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool, @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool) enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
check_bprop=bool)
def set_context(**kwargs): def set_context(**kwargs):
""" """
Sets context for running environment. Sets context for running environment.
...@@ -500,6 +508,7 @@ def set_context(**kwargs): ...@@ -500,6 +508,7 @@ def set_context(**kwargs):
The profiling can choose training_trace, task_trace, training_trace and task_trace combination and The profiling can choose training_trace, task_trace, training_trace and task_trace combination and
separated by colons; single operator can choose op_trace, op_trace cannot be combined with separated by colons; single operator can choose op_trace, op_trace cannot be combined with
training_trace and task_trace. Default: "training_trace". training_trace and task_trace. Default: "training_trace".
check_bprop (bool): Whether to check bprop. Default: False.
Raises: Raises:
ValueError: If input key is not an attribute in context. ValueError: If input key is not an attribute in context.
......
...@@ -323,8 +323,9 @@ class CheckBprop(PrimitiveWithInfer): ...@@ -323,8 +323,9 @@ class CheckBprop(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self, prim_to_check=""):
"""init CheckBprop""" """init CheckBprop"""
self.prim_to_check = prim_to_check
def infer_shape(self, xshapes, yshapes): def infer_shape(self, xshapes, yshapes):
tips = f'Bprop of {self.prim_to_check}' tips = f'Bprop of {self.prim_to_check}'
......
...@@ -353,6 +353,7 @@ class MulAddWithWrongOutputNum(nn.Cell): ...@@ -353,6 +353,7 @@ class MulAddWithWrongOutputNum(nn.Cell):
def test_grad_mul_add_with_wrong_output_num(): def test_grad_mul_add_with_wrong_output_num():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputNum() mul_add = MulAddWithWrongOutputNum()
with pytest.raises(TypeError): with pytest.raises(TypeError):
C.grad_all(mul_add)(1, 2) C.grad_all(mul_add)(1, 2)
...@@ -370,6 +371,7 @@ class MulAddWithWrongOutputType(nn.Cell): ...@@ -370,6 +371,7 @@ class MulAddWithWrongOutputType(nn.Cell):
def test_grad_mul_add_with_wrong_output_type(): def test_grad_mul_add_with_wrong_output_type():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputType() mul_add = MulAddWithWrongOutputType()
with pytest.raises(TypeError): with pytest.raises(TypeError):
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
...@@ -388,6 +390,7 @@ class MulAddWithWrongOutputShape(nn.Cell): ...@@ -388,6 +390,7 @@ class MulAddWithWrongOutputShape(nn.Cell):
def test_grad_mul_add_with_wrong_output_shape(): def test_grad_mul_add_with_wrong_output_shape():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputShape() mul_add = MulAddWithWrongOutputShape()
with pytest.raises(TypeError): with pytest.raises(TypeError):
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
...@@ -893,6 +893,7 @@ def test_grad_if_defer_inline(): ...@@ -893,6 +893,7 @@ def test_grad_if_defer_inline():
def test_bprop_with_wrong_output_num(): def test_bprop_with_wrong_output_num():
context.set_context(check_bprop=True)
class BpropWithWrongOutputNum(PrimitiveWithInfer): class BpropWithWrongOutputNum(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
...@@ -926,8 +927,8 @@ def test_bprop_with_wrong_output_num(): ...@@ -926,8 +927,8 @@ def test_bprop_with_wrong_output_num():
with pytest.raises(TypeError): with pytest.raises(TypeError):
C.grad_all(BpropWithWrongOutputNumCell())(1, 2) C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
def test_bprop_with_wrong_output_type(): def test_bprop_with_wrong_output_type():
context.set_context(check_bprop=True)
class BpropWithWrongOutputType(PrimitiveWithInfer): class BpropWithWrongOutputType(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
...@@ -963,6 +964,7 @@ def test_bprop_with_wrong_output_type(): ...@@ -963,6 +964,7 @@ def test_bprop_with_wrong_output_type():
def test_bprop_with_wrong_output_shape(): def test_bprop_with_wrong_output_shape():
context.set_context(check_bprop=True)
class BpropWithWrongOutputShape(PrimitiveWithInfer): class BpropWithWrongOutputShape(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册