提交 702fcbbe 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1467 Pynative can not add cell hook

Merge pull request !1467 from JoyLvliang/r0.3
...@@ -170,6 +170,9 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu ...@@ -170,6 +170,9 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
MS_EXCEPTION_IF_NULL(output_idx); MS_EXCEPTION_IF_NULL(output_idx);
AnfNodePtr input1 = x_cnode->input(1); AnfNodePtr input1 = x_cnode->input(1);
MS_EXCEPTION_IF_NULL(input1); MS_EXCEPTION_IF_NULL(input1);
if (!input1->isa<CNode>()) {
return false;
}
*prior_op = input1->cast<CNodePtr>(); *prior_op = input1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(*prior_op); MS_EXCEPTION_IF_NULL(*prior_op);
AnfNodePtr input2 = x_cnode->input(2); AnfNodePtr input2 = x_cnode->input(2);
......
...@@ -762,5 +762,5 @@ class Cell: ...@@ -762,5 +762,5 @@ class Cell:
Args: Args:
fn (function): Specifies the hook function with grad as input. fn (function): Specifies the hook function with grad as input.
""" """
self._backward_hook = HookBackward(fn, str(id(self))) self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
self._enable_hook = True self._enable_hook = True
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
"""debug_ops""" """debug_ops"""
from types import FunctionType from types import FunctionType, MethodType
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer from ..primitive import prim_attr_register, PrimitiveWithInfer
...@@ -279,7 +279,7 @@ class HookBackward(PrimitiveWithInfer): ...@@ -279,7 +279,7 @@ class HookBackward(PrimitiveWithInfer):
super(HookBackward, self).__init__(self.__class__.__name__) super(HookBackward, self).__init__(self.__class__.__name__)
self.add_prim_attr("cell_id", cell_id) self.add_prim_attr("cell_id", cell_id)
self.init_attrs["cell_id"] = cell_id self.init_attrs["cell_id"] = cell_id
if not isinstance(hook_fn, FunctionType): if not isinstance(hook_fn, (FunctionType, MethodType)):
raise TypeError("Hook function should be python function type.") raise TypeError("Hook function should be python function type.")
self.register_hook(hook_fn) self.register_hook(hook_fn)
self.cell_id = cell_id self.cell_id = cell_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册