提交 0b191615 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1427 fix check bprop attr error

Merge pull request !1427 from panyifeng/fix_check_bprop_attr_error
......@@ -32,6 +32,7 @@
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/context/ms_context.h"
#include "debug/info.h"
#include "debug/trace.h"
......@@ -181,10 +182,19 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
}
void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool check_bprop_flag = context->check_bprop_flag();
// Skip checking if check_bprop not set
if (!check_bprop_flag) {
return;
}
// bprop_fg has been checked in caller
auto check_bprop = prim::GetPythonOps("check_bprop", "mindspore.ops.functional")->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(check_bprop);
check_bprop->set_attr("prim_to_check", std::make_shared<StringImm>(prim_to_check));
auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
MS_EXCEPTION_IF_NULL(check_bprop_class);
auto check_bprop =
bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
......@@ -192,7 +202,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
AnfNodePtr params = bprop_fg->NewCNode(inputs);
inputs.clear();
inputs.push_back(NewValueNode(check_bprop));
inputs.push_back(check_bprop);
inputs.push_back(bprop_fg->output());
inputs.push_back(params);
AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
......
......@@ -141,7 +141,9 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
.def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
.def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.");
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.");
(void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
.def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
......
......@@ -140,6 +140,8 @@ class MsContext {
void set_profiling_options(const std::string &options) { profiling_options_ = options; }
std::string profiling_options() const { return profiling_options_; }
bool check_bprop_flag() const { return check_bprop_flag_; }
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
private:
MsContext(const std::string &backend_policy, const std::string &target);
......@@ -179,6 +181,7 @@ class MsContext {
std::thread tdt_print_;
bool profiling_mode_;
std::string profiling_options_;
bool check_bprop_flag_;
};
} // namespace mindspore
......
......@@ -324,6 +324,13 @@ class _Context:
thread_info = self._thread_local_info
thread_info.debug_runtime = enable
@property
def check_bprop(self):
return self._context_handle.get_check_bprop_flag()
@check_bprop.setter
def check_bprop(self, check_bprop_flag):
self._context_handle.set_check_bprop_flag(check_bprop_flag)
def check_input_format(x):
import re
......@@ -449,7 +456,8 @@ def reset_auto_parallel_context():
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool)
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
check_bprop=bool)
def set_context(**kwargs):
"""
Sets context for running environment.
......@@ -500,6 +508,7 @@ def set_context(**kwargs):
The profiling can choose training_trace, task_trace, training_trace and task_trace combination and
separated by colons; single operator can choose op_trace, op_trace cannot be combined with
training_trace and task_trace. Default: "training_trace".
check_bprop (bool): Whether to check bprop. Default: False.
Raises:
ValueError: If input key is not an attribute in context.
......
......@@ -323,8 +323,9 @@ class CheckBprop(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self):
def __init__(self, prim_to_check=""):
"""init CheckBprop"""
self.prim_to_check = prim_to_check
def infer_shape(self, xshapes, yshapes):
tips = f'Bprop of {self.prim_to_check}'
......
......@@ -353,6 +353,7 @@ class MulAddWithWrongOutputNum(nn.Cell):
def test_grad_mul_add_with_wrong_output_num():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputNum()
with pytest.raises(TypeError):
C.grad_all(mul_add)(1, 2)
......@@ -370,6 +371,7 @@ class MulAddWithWrongOutputType(nn.Cell):
def test_grad_mul_add_with_wrong_output_type():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputType()
with pytest.raises(TypeError):
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
......@@ -388,6 +390,7 @@ class MulAddWithWrongOutputShape(nn.Cell):
def test_grad_mul_add_with_wrong_output_shape():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputShape()
with pytest.raises(TypeError):
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
......@@ -893,6 +893,7 @@ def test_grad_if_defer_inline():
def test_bprop_with_wrong_output_num():
context.set_context(check_bprop=True)
class BpropWithWrongOutputNum(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
......@@ -926,8 +927,8 @@ def test_bprop_with_wrong_output_num():
with pytest.raises(TypeError):
C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
def test_bprop_with_wrong_output_type():
context.set_context(check_bprop=True)
class BpropWithWrongOutputType(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
......@@ -963,6 +964,7 @@ def test_bprop_with_wrong_output_type():
def test_bprop_with_wrong_output_shape():
context.set_context(check_bprop=True)
class BpropWithWrongOutputShape(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册