提交 11303142 编写于 作者: L lvliang

pynative-cell-hook-grad-abnormal

上级 5157063c
......@@ -170,6 +170,9 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
MS_EXCEPTION_IF_NULL(output_idx);
AnfNodePtr input1 = x_cnode->input(1);
MS_EXCEPTION_IF_NULL(input1);
if (!input1->isa<CNode>()) {
return false;
}
*prior_op = input1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(*prior_op);
AnfNodePtr input2 = x_cnode->input(2);
......
......@@ -762,5 +762,5 @@ class Cell:
Args:
fn (function): Specifies the hook function with grad as input.
"""
self._backward_hook = HookBackward(fn, str(id(self)))
self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
self._enable_hook = True
......@@ -14,7 +14,7 @@
# ============================================================================
"""debug_ops"""
from types import FunctionType
from types import FunctionType, MethodType
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer
......@@ -279,7 +279,7 @@ class HookBackward(PrimitiveWithInfer):
super(HookBackward, self).__init__(self.__class__.__name__)
self.add_prim_attr("cell_id", cell_id)
self.init_attrs["cell_id"] = cell_id
if not isinstance(hook_fn, FunctionType):
if not isinstance(hook_fn, (FunctionType, MethodType)):
raise TypeError("Hook function should be python function type.")
self.register_hook(hook_fn)
self.cell_id = cell_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册