From 840922e501da3b832d69b88ccc07b528ddbc34e3 Mon Sep 17 00:00:00 2001
From: kingfo
Date: Thu, 14 May 2020 17:02:27 +0800
Subject: [PATCH] add backward hook function in pynative mode

---
 mindspore/_extends/parse/parser.py            |   5 +-
 mindspore/ccsrc/ir/primitive.cc               |   1 +
 mindspore/ccsrc/ir/primitive.h                |   3 -
 mindspore/ccsrc/ir/primitive_base.h           |   7 +
 mindspore/ccsrc/operator/ops.cc               |   2 +
 mindspore/ccsrc/operator/ops.h                |   2 +
 mindspore/ccsrc/operator/prim_nn.cc           |  10 ++
 mindspore/ccsrc/optimizer/ad/dfunctor.cc      |   1 +
 mindspore/ccsrc/optimizer/ad/dfunctor.h       |   1 +
 mindspore/ccsrc/optimizer/ad/kprim.cc         |  52 ++++++-
 mindspore/ccsrc/optimizer/irpass.cc           |   7 +-
 .../optimizer/irpass/special_op_eliminate.h   |   4 +-
 .../ccsrc/pipeline/parse/data_converter.cc    |  38 ++++-
 .../ccsrc/pipeline/static_analysis/prim.cc    |   1 +
 .../ccsrc/pipeline/static_analysis/prim.h     |   2 +
 mindspore/ccsrc/vm/backend.cc                 |   3 +
 mindspore/ccsrc/vm/transform.cc               |  14 +-
 mindspore/ccsrc/vm/vm.cc                      |  56 +++++++-
 mindspore/ccsrc/vm/vm.h                       |   2 +
 mindspore/nn/cell.py                          |  36 +++++
 mindspore/ops/operations/__init__.py          |   3 +-
 mindspore/ops/operations/debug_ops.py         |  60 ++++++++
 tests/ut/python/pynative_mode/test_hook.py    | 133 ++++++++++++++++++
 23 files changed, 422 insertions(+), 21 deletions(-)
 create mode 100644 tests/ut/python/pynative_mode/test_hook.py

diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py
index 34a3a6c59..462565fd7 100644
--- a/mindspore/_extends/parse/parser.py
+++ b/mindspore/_extends/parse/parser.py
@@ -102,7 +102,10 @@ def get_parse_method_of_class(obj, parse_method=None):
         method_name = parse_method
     else:
         if isinstance(obj, nn.Cell):
-            method_name = "construct"
+            if obj.enable_hook:
+                method_name = "_hook_construct"
+            else:
+                method_name = "construct"
     if method_name is not None:
         if hasattr(obj, method_name):
             method = getattr(obj, method_name)
diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc
index 4fd000ca5..7f6080828 100644
--- a/mindspore/ccsrc/ir/primitive.cc
+++ b/mindspore/ccsrc/ir/primitive.cc
@@ -115,6 +115,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
                            .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
                            .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
                            .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
+                           .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
                            .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
                        }));
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h
index 1959b3199..1dd867fd1 100644
--- a/mindspore/ccsrc/ir/primitive.h
+++ b/mindspore/ccsrc/ir/primitive.h
@@ -23,7 +23,6 @@
 #include
 #include
-#include "pybind11/pybind11.h"
 #include "pipeline/static_analysis/abstract_value.h"
 #include "utils/misc.h"
 #include "utils/log_adapter.h"
@@ -31,8 +30,6 @@
 #include "ir/signature.h"
 #include "parallel/ops_info/operator_info.h"
 
-namespace py = pybind11;
-
 namespace mindspore {
 class PrimitivePy : public Primitive {
  public:
diff --git a/mindspore/ccsrc/ir/primitive_base.h b/mindspore/ccsrc/ir/primitive_base.h
index 28135c8e7..1b347a307 100644
--- a/mindspore/ccsrc/ir/primitive_base.h
+++ b/mindspore/ccsrc/ir/primitive_base.h
@@ -24,6 +24,9 @@
 #include
 #include "ir/dtype/type.h"
+#include "pybind11/pybind11.h"
+
+namespace py = pybind11;
 
 namespace mindspore {
 // Supported meta type
@@ -73,6 +76,9 @@ class Primitive : public Named {
     return iter == attrs_.cend() ? nullptr : iter->second;
   }
 
+  void set_hook(const py::function &hook) { hook_ = hook; }
+  py::function hook() const { return hook_; }
+
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
 
   // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
@@ -103,6 +109,7 @@
  private:
   std::string instance_name_;
+  py::function hook_;
   bool is_base_;
   bool has_signature_;
   PrimType prim_type_;
diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index ddd698f6e..459b9650c 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -211,6 +211,7 @@ const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
 const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
 const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
 const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
+const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
 
 // Other miscellaneous
 const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
@@ -224,6 +225,7 @@ const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
 const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
 const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
 const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
+const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
 const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
 const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
 const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index 1125feee7..19fac40c3 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -216,6 +216,7 @@ extern const PrimitivePtr kPrimReluV2;
 extern const PrimitivePtr kPrimActivation;
 extern const PrimitivePtr kPrimZerosLikeTensor;
 extern const PrimitivePtr kPrimFakeBprop;
+extern const PrimitivePtr kPrimBpropCut;
 
 // Other Miscellaneous
 extern const PrimitivePtr kPrimIdentity;
@@ -230,6 +231,7 @@ extern const PrimitivePtr kPrimGetRefKey;
 extern const PrimitivePtr kPrimGetRefValue;
 extern const PrimitivePtr kPrimGetRefOrigin;
 extern const PrimitivePtr kPrimInsertGradientOf;
+extern const PrimitivePtr kPrimHookBackward;
 extern const PrimitivePtr kPrimPrintShapeType;
 extern const PrimitivePtr kPrimPrint;
 extern const PrimitivePtr kPrimSameTypeShape;
diff --git a/mindspore/ccsrc/operator/prim_nn.cc b/mindspore/ccsrc/operator/prim_nn.cc
index d90c09256..d057fd925 100644
--- a/mindspore/ccsrc/operator/prim_nn.cc
+++ b/mindspore/ccsrc/operator/prim_nn.cc
@@ -285,6 +285,16 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
   return args_spec_list[0]->Broaden();
 }
 
+AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor.
+  AbstractBasePtrList args_list;
+  for (size_t i = 0; i < args_spec_list.size() - 2; i++) {
+    args_list.push_back(args_spec_list[i]->Broaden());
+  }
+  return std::make_shared<AbstractTuple>(args_list);
+}
+
 AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
   // Inputs: three tensors(x, gamma, beta).
diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/optimizer/ad/dfunctor.cc
index 6f648b572..06975b0e5 100644
--- a/mindspore/ccsrc/optimizer/ad/dfunctor.cc
+++ b/mindspore/ccsrc/optimizer/ad/dfunctor.cc
@@ -32,6 +32,7 @@
 #include "operator/ops.h"
 #include "operator/composite/composite.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
 #include "./common.h"
 
 namespace mindspore {
diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.h b/mindspore/ccsrc/optimizer/ad/dfunctor.h
index 7da90e7c3..d11926b37 100644
--- a/mindspore/ccsrc/optimizer/ad/dfunctor.h
+++ b/mindspore/ccsrc/optimizer/ad/dfunctor.h
@@ -125,6 +125,7 @@ class KPrim {
   FuncGraphPtr GetBprop(const PrimitivePtr &prim);
   FuncGraphPtr GetFprop(const PrimitivePtr &prim);
   FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
+  FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
   // Given a bprop rule, do the K mapping.
   template <typename T>
   FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);
diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc
index 6b68d1564..72f25f4b3 100644
--- a/mindspore/ccsrc/optimizer/ad/kprim.cc
+++ b/mindspore/ccsrc/optimizer/ad/kprim.cc
@@ -115,10 +115,15 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
   }
 
   bool is_faked_bprop = false;
-  auto bprop_fg = GetBprop(prim);
-  if (bprop_fg == nullptr) {
-    bprop_fg = FakeBprop(value_node, resources);
-    is_faked_bprop = true;
+  FuncGraphPtr bprop_fg = nullptr;
+  if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
+    bprop_fg = BpropCut(value_node, resources);
+  } else {
+    bprop_fg = GetBprop(prim);
+    if (bprop_fg == nullptr) {
+      bprop_fg = FakeBprop(value_node, resources);
+      is_faked_bprop = true;
+    }
   }
 
   auto expanded_fg = BpropToK(prim, bprop_fg);
@@ -206,6 +211,45 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
   return expanded_fg;
 }
 
+FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
+  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto &node_users = resources->manager()->node_users();
+
+  auto &users = node_users[value_node];
+  auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
+    return IsPrimitiveCNode(user.first, prim);
+  });
+  if (cnode == users.end()) {
+    MS_LOG(EXCEPTION) << "Fail to find cnode.";
+  }
+  auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
+
+  auto func_graph = std::make_shared<FuncGraph>();
+  std::vector<AnfNodePtr> outputs;
+
+  auto bprop_cut = std::make_shared<Primitive>("bprop_cut");
+  bprop_cut->set_hook(prim->hook());
+  auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
+  if (cell_id != "") {
+    (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
+    (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
+  }
+
+  outputs.push_back(NewValueNode(bprop_cut));
+  for (size_t i = 0; i < inputs_num; ++i) {
+    auto param = func_graph->add_parameter();
+    outputs.push_back(param);
+  }
+  auto p1 = func_graph->add_parameter();
+  auto p2 = func_graph->add_parameter();
+  outputs.push_back(p1);
+  outputs.push_back(p2);
+
+  func_graph->set_output(func_graph->NewCNode(outputs));
+  return func_graph;
+}
+
 FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
   auto prim = value_node->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(prim);
diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc
index d2f7d6035..2ac0bc21c 100644
--- a/mindspore/ccsrc/optimizer/irpass.cc
+++ b/mindspore/ccsrc/optimizer/irpass.cc
@@ -49,9 +49,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                           {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
                                            prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
-  special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
-                                           {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
-                                            prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
+  special_op_eliminate_ =
+    MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
+                     {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
+                      prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
   zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
   adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
 
diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
index 30a94c9bf..aa23441bb 100644
--- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
+++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
@@ -35,11 +35,13 @@ class SpecialOpEliminater {
  public:
   SpecialOpEliminater()
       : insert_gradient_of_(prim::kPrimInsertGradientOf),
+        hook_backward_(prim::kPrimHookBackward),
         print_shape_type_(prim::kPrimPrintShapeType),
         get_ref_value_(prim::kPrimGetRefValue),
         mirror_(prim::kPrimMirror),
         virtual_div_(prim::kPrimVirtualDiv) {
     eliminaters_.emplace_back(insert_gradient_of_);
+    eliminaters_.emplace_back(hook_backward_);
     eliminaters_.emplace_back(print_shape_type_);
     eliminaters_.emplace_back(get_ref_value_);
     eliminaters_.emplace_back(mirror_);
@@ -59,7 +61,7 @@ class SpecialOpEliminater {
   }
 
  private:
-  PrimEliminater insert_gradient_of_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
+  PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
   std::vector eliminaters_{};
 };
diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc
index 01c732518..5dbb8bc45 100644
--- a/mindspore/ccsrc/pipeline/parse/data_converter.cc
+++ b/mindspore/ccsrc/pipeline/parse/data_converter.cc
@@ -30,6 +30,7 @@
 #include "operator/composite/composite.h"
 #include "ir/func_graph_cloner.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
 #include "debug/trace.h"
 
 namespace mindspore {
@@ -207,6 +208,35 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
   return true;
 }
 
+FuncGraphPtr ConvertToBpropCut(py::object obj) {
+  std::vector<std::string> results = data_converter::GetObjKey(obj);
+  std::string obj_key = results[0];
+  py::function bprop_func = py::getattr(obj, "bprop");
+
+  FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
+  std::vector<AnfNodePtr> outputs;
+
+  auto fake_bprop = std::make_shared<Primitive>("bprop_cut");
+  fake_bprop->set_hook(bprop_func);
+  (void)fake_bprop->AddAttr("bprop", MakeValue(true));
+  outputs.push_back(NewValueNode(fake_bprop));
+
+  py::object code_obj = py::getattr(bprop_func, "__code__");
+  size_t inputs_num = py::cast<size_t>(py::getattr(code_obj, "co_argcount")) - 3;
+  for (size_t i = 0; i < inputs_num; ++i) {
+    auto param = bprop_graph->add_parameter();
+    outputs.push_back(param);
+  }
+  auto p1 = bprop_graph->add_parameter();
+  auto p2 = bprop_graph->add_parameter();
+  outputs.push_back(p1);
+  outputs.push_back(p2);
+
+  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
+  data_converter::SetObjGraphValue(obj_key, bprop_graph);
+  return bprop_graph;
+}
+
 bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
   auto obj_type = data_converter::GetObjType(obj);
   MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
@@ -238,7 +268,13 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
     }
     // if the cell object has specified bprop, it has user-defined bprop function parse and record it
     if (py::hasattr(obj, "bprop")) {
-      FuncGraphPtr bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+      FuncGraphPtr bprop_graph = nullptr;
+      bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
+      if (enable_bprop_debug) {
+        bprop_graph = ConvertToBpropCut(obj);
+      } else {
+        bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+      }
       if (bprop_graph != nullptr) {
         (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
         (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
index d37657ce4..e5abfb33c 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -108,6 +108,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimRelu, {InferImplRelu, true}},
     {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
     {prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
+    {prim::kPrimBpropCut, {InferImplBpropCut, true}},
     {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
     {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h
index 3969e76bf..783691101 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.h
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h
@@ -210,6 +210,8 @@ AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc
index a9f526418..b447bb822 100644
--- a/mindspore/ccsrc/vm/backend.cc
+++ b/mindspore/ccsrc/vm/backend.cc
@@ -64,6 +64,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
   result.outputs = outputs;
   result.graph_id = kInvalidGraphId;
   auto graph_id = sess_->CompileGraph(lst, outputs);
+  if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+    sess_->BuildGraph(graph_id);
+  }
   if (MsContext::GetInstance()->precompile_only()) {
     MS_LOG(INFO) << "PrecompileOnly, stop run graph";
     return result;
diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc
index e8b47a4bc..604661420 100644
--- a/mindspore/ccsrc/vm/transform.cc
+++ b/mindspore/ccsrc/vm/transform.cc
@@ -40,9 +40,10 @@ using MapPrimTypeFuncGraph = std::map;
 using TypedPrimitiveAbstractClosurePtr = std::shared_ptr;
 
 std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
-                                           prim::kPrimMakeTuple};
+                                           prim::kPrimMakeTuple, prim::kPrimBpropCut};
 const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
-  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch};
+  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
+                                                             prim::kPrimBpropCut};
   return ms_nonlinear_ops;
 }
 
@@ -646,8 +647,13 @@ BackendPtr CreateBackend() {
   auto backend = std::make_shared<MsBackend>(name, target, device_id);
   std::string device_target = MsContext::GetInstance()->device_target();
   if (device_target == kAscendDevice) {
-    backend->set_is_multi_graph_sink(true);
-    context_ptr->set_is_multi_graph_sink(true);
+    if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+      backend->set_is_multi_graph_sink(false);
+      context_ptr->set_is_multi_graph_sink(false);
+    } else {
+      backend->set_is_multi_graph_sink(true);
+      context_ptr->set_is_multi_graph_sink(true);
+    }
   }
   return backend;
 }
diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc
index 3a34eba18..d7784457d 100644
--- a/mindspore/ccsrc/vm/vm.cc
+++ b/mindspore/ccsrc/vm/vm.cc
@@ -587,15 +587,65 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
   VectorRef tuple;
   auto prim = utils::cast<PrimitivePtr>(args[0]);
   for (size_t i = 1; i < args.size(); ++i) {
-    auto index = utils::cast<int>(args[1]);
+    auto index = utils::cast<int>(args[i]);
     tuple.push_back(Ref(index));
   }
 
-  auto outs = RunOperation(prim, tuple);
-  Push(outs);
+  if (prim->name() == "bprop_cut") {
+    auto outs = RunHook(prim, tuple);
+    Push(outs);
+  } else {
+    auto outs = RunOperation(prim, tuple);
+    Push(outs);
+  }
   MS_LOG(DEBUG) << "End";
 }
 
+BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
+  py::tuple py_args = py::tuple(args.size());
+  MS_LOG(DEBUG) << "input for operation:";
+  size_t i = 0;
+  for (auto &arg : args) {
+    py_args[i] = BaseRefToPyData(arg);
+    MS_LOG(DEBUG) << "arg: " << i << ":";
+    i++;
+  }
+  py::object obj;
+  bool is_bprop = prim->HasAttr("bprop");
+  if (is_bprop) {
+    py::function fn_bprop = prim->hook();
+    obj = fn_bprop(*py_args);
+    return obj;
+  }
+  bool is_cell = prim->HasAttr("cell_hook");
+  if (is_cell) {
+    std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
+    if (_hook_grad.find(cell_id) != _hook_grad.end()) {
+      py::tuple hook_args = py::tuple(3);
+      hook_args[0] = cell_id;
+      hook_args[1] = _hook_grad[cell_id];
+      hook_args[2] = py_args[2];
+      py::function fn_hook = prim->hook();
+      obj = fn_hook(*hook_args);
+      if (py::isinstance<py::none>(obj)) {
+        obj = py_args[2];
+      }
+      _hook_grad.erase(cell_id);
+    } else {
+      _hook_grad[cell_id] = py_args[2];
+      obj = py_args[2];
+    }
+  } else {
+    py::function fn_hook = prim->hook();
+    obj = fn_hook(py_args[2]);
+    if (py::isinstance<py::none>(obj)) {
+      obj = py_args[2];
+    }
+  }
+  obj = py::make_tuple(obj);
+  return obj;
+}
+
 }  // namespace compile
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h
index a9832ab5e..c72737a1d 100644
--- a/mindspore/ccsrc/vm/vm.h
+++ b/mindspore/ccsrc/vm/vm.h
@@ -115,6 +115,7 @@ class FinalVM {
   void InstPushPrim(const VectorRef &args);
   void InstSwitchReturn(const VectorRef &args);
   void set_insts(const InstSet &value) { insts_ = value; }
+  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);
 
  protected:
   BaseRef Ref(int i);
@@ -156,6 +157,7 @@ class FinalVM {
     {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
     {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
   };
+  std::map<std::string, py::object> _hook_grad;
 };
 
 using FinalVMPtr = std::shared_ptr<FinalVM>;
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 9953c2431..0937ac3f7 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -24,6 +24,7 @@ from .._checkparam import _check_str_by_regular
 from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend
 from ..ops.primitive import Primitive
+from ..ops.operations import HookBackward
 from ..parallel._tensor import _load_tensor_by_layout
 from ..common.tensor import Tensor
 
@@ -75,6 +76,9 @@ class Cell:
         self._parallel_inputs_run = None
         if flags:
             self.add_flags(**flags)
+        self._backward_hook = None
+        self._enable_hook = False
+        self._bprop_debug = False
 
     @property
     def create_time(self):
@@ -91,6 +95,16 @@ class Cell:
         """
         return self._param_prefix
 
+    @property
+    def bprop_debug(self):
+        return self._bprop_debug
+
+    @bprop_debug.setter
+    def bprop_debug(self, value):
+        if not isinstance(value, bool):
+            raise TypeError("'bprop_debug' value must be a bool.")
+        self._bprop_debug = value
+
     def update_cell_prefix(self):
         """
         Update the all child cells' self.param_prefix.
@@ -728,3 +742,25 @@ class Cell:
         self._auto_parallel_mode = True
         self.add_flags(auto_parallel=True)
         self._get_construct_inputs_number_and_name()
+
+    def _hook_construct(self, inputs):
+        """Hook construct method that replaces the original construct method when the hook function is enabled."""
+        inputs = self._backward_hook(inputs)
+        inputs = self.construct(inputs)
+        outputs = self._backward_hook(inputs)
+        return outputs
+
+    @property
+    def enable_hook(self):
+        """Whether the cell has a registered backward hook function."""
+        return self._enable_hook
+
+    def register_backward_hook(self, fn):
+        """
+        Register the backward hook function for the cell.
+
+        Args:
+            fn (function): Hook function called with (cell_id, grad_input, grad_output) during the backward pass.
+        """
+        self._backward_hook = HookBackward(fn, str(id(self)))
+        self._enable_hook = True
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 5151c0b36..40fdee594 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice)
-from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary,
+from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import ScalarCast
@@ -155,6 +155,7 @@ __all__ = [
     'HistogramSummary',
     "Print",
     'InsertGradientOf',
+    'HookBackward',
     'InvertPermutation',
     'Shape',
     'DropoutDoMask',
diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py
index 6887c778e..aae2d7dcf 100644
--- a/mindspore/ops/operations/debug_ops.py
+++ b/mindspore/ops/operations/debug_ops.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """debug_ops"""
+from types import FunctionType
 from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer
@@ -193,6 +194,65 @@ class InsertGradientOf(PrimitiveWithInfer):
         return x_type
 
 
+class HookBackward(PrimitiveWithInfer):
+    """
+    Used as a tag to hook the gradient of an intermediate variable.
+
+    Note:
+        The hook function takes the gradient of the hooked variable as its single input.
+        The hook function is executed in the Python environment, while the callback of
+        InsertGradientOf is parsed and added to the graph.
+
+    Args:
+        hook_fn (Function): Python hook function.
+
+    Inputs:
+        - **inputs** (Tensor) - The variable to hook.
+
+    Examples:
+        >>> def hook_fn(grad_out):
+        >>>     print(grad_out)
+        >>>
+        >>> hook = P.HookBackward(hook_fn)
+        >>>
+        >>> def hook_test(x, y):
+        >>>     z = x * y
+        >>>     z = hook(z)
+        >>>     z = z * y
+        >>>     return z
+        >>>
+        >>> def backward(x, y):
+        >>>     return C.grad_all(hook_test)(x, y)
+        >>>
+        >>> backward(1, 2)
+    """
+
+    def __init__(self, hook_fn, cell_id=""):
+        super(HookBackward, self).__init__(self.__class__.__name__)
+        self.add_prim_attr("cell_id", cell_id)
+        self.init_attrs["cell_id"] = cell_id
+        if not isinstance(hook_fn, FunctionType):
+            raise TypeError("Hook function must be a Python function.")
+        self.register_hook(hook_fn)
+        self.cell_id = cell_id
+
+    def __call__(self, *inputs):
+        """Run in PyNative mode."""
+        if len(inputs) == 1:
+            return inputs[0]
+        return inputs
+
+    def infer_shape(self, *inputs_shape):
+        if len(inputs_shape) == 1:
+            return inputs_shape[0]
+        return inputs_shape
+
+    def infer_dtype(self, *inputs_type):
+        if len(inputs_type) == 1:
+            return inputs_type[0]
+        return inputs_type
+
+
 class Print(PrimitiveWithInfer):
     """
     Output tensor or string to stdout.
diff --git a/tests/ut/python/pynative_mode/test_hook.py b/tests/ut/python/pynative_mode/test_hook.py
new file mode 100644
index 000000000..63d712876
--- /dev/null
+++ b/tests/ut/python/pynative_mode/test_hook.py
@@ -0,0 +1,133 @@
+import numpy as np
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+from mindspore import context
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore import context, Tensor, ParameterTuple
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.nn import Dense, WithLossCell, SoftmaxCrossEntropyWithLogits, Momentum
+
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    """weight initial for conv layer"""
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+def fc_with_initialize(input_channels, out_channels):
+    """weight initial for fc layer"""
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+def weight_variable():
+    """weight initial"""
+    return TruncatedNormal(0.02)
+
+def cell_hook_function(cell_id, grad_input, grad_output):
+    print(cell_id)
+    assert(grad_output.asnumpy().shape == (32, 6, 14, 14))
+    assert(grad_input.asnumpy().shape == (32, 16, 10, 10))
+
+
+def var_hook_function(grad_out):
+    print("grad:", grad_out)
+    assert(grad_out.asnumpy().shape == (32, 120))
+
+
+class LeNet5(nn.Cell):
+    """
+    Lenet network
+    Args:
+        num_class (int): Num classes. Default: 10.
+
+    Returns:
+        Tensor, output tensor
+
+    Examples:
+        >>> LeNet(num_class=10)
+    """
+    def __init__(self, num_class=10):
+        super(LeNet5, self).__init__()
+        self.num_class = num_class
+        self.batch_size = 32
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.conv2.register_backward_hook(cell_hook_function)
+        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, self.num_class)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.reshape = P.Reshape()
+        self.hook = P.HookBackward(var_hook_function)
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.reshape(x, (self.batch_size, -1))
+        x = self.fc1(x)
+        x = self.hook(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+class GradWrap(nn.Cell):
+    """ GradWrap definition """
+    def __init__(self, network):
+        super(GradWrap, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
+
+    def construct(self, x, label):
+        weights = self.weights
+        return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)
+
+def test_hook():
+    net = LeNet5()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = GradWrap(net_with_criterion)
+    train_network.set_train()
+
+    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
+    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
+    output = net(Tensor(input_data))
+    loss_output = criterion(output, label)
+    grads = train_network(input_data, label)
+    success = optimizer(grads)
+    print(loss_output.asnumpy().shape)
+
+
+class MulAdd(nn.Cell):
+    def __init__(self):
+        super(MulAdd, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x + y
+
+    def bprop(self, x, y, out, dout):
+        assert(x == 1)
+        assert(y == 2)
+        assert(out == 4)
+        assert(dout == 1)
+        return 3 * dout, 2 * y
+
+def test_custom_bprop():
+    mul_add = MulAdd()
+    mul_add.bprop_debug = True
+    assert C.grad_all(mul_add)(1, 2) == (3, 4)
--
GitLab
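
Usage notes. Below is a rough sketch of the three hook paths this patch introduces, in PyNative mode. The network, shapes, and hook bodies are illustrative assumptions and not part of the patch; the API calls (P.HookBackward, Cell.register_backward_hook, Cell.bprop_debug, C.grad_all) are the ones added or exercised by the diff and the test above.

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context, Tensor
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

def var_hook(grad):
    # Receives the gradient flowing into the hooked variable.
    print("variable grad:", grad)

def cell_hook(cell_id, grad_input, grad_output):
    # Receives the registering cell's id and its input/output gradients.
    print("cell:", cell_id)

class Net(nn.Cell):
    # Hypothetical small network used only to show where the hooks attach.
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(4, 2)
        self.dense.register_backward_hook(cell_hook)  # cell-level hook
        self.hook = P.HookBackward(var_hook)          # variable-level hook

    def construct(self, x):
        x = self.dense(x)
        x = self.hook(x)
        return x

x = Tensor(np.ones([1, 4]).astype(np.float32))
grads = C.grad_all(Net())(x)  # both hooks fire during this backward pass

class MulAdd(nn.Cell):
    # With bprop_debug=True the custom bprop below runs as a Python callback
    # (the bprop_cut primitive) instead of being compiled into the graph.
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return 3 * dout, 2 * y

mul_add = MulAdd()
mul_add.bprop_debug = True
print(C.grad_all(mul_add)(1, 2))  # (3, 4), as asserted in test_custom_bprop above

Design note: for cell-level hooks the VM calls the hook only once both gradients are available. FinalVM::RunHook caches the output gradient under the cell id on the first bprop_cut call and invokes the hook with (cell_id, grad_output, grad_input) on the second, which is why _hook_grad is a per-VM map keyed by cell id.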