diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc index bc34ae141333f7532c29184aeb973077f1c6b24b..dc47757e5d147c0d99d4c7b33751f1f3a19b16e2 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc @@ -170,6 +170,9 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu MS_EXCEPTION_IF_NULL(output_idx); AnfNodePtr input1 = x_cnode->input(1); MS_EXCEPTION_IF_NULL(input1); + if (!input1->isa<CNode>()) { + return false; + } *prior_op = input1->cast<CNodePtr>(); MS_EXCEPTION_IF_NULL(*prior_op); AnfNodePtr input2 = x_cnode->input(2); diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 0937ac3f7be2c246e83805202d4f394c771c679c..e0563a05fab608d5c276840d622f4dcec8bb8d59 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -762,5 +762,5 @@ class Cell: Args: fn (function): Specifies the hook function with grad as input. """ - self._backward_hook = HookBackward(fn, str(id(self))) + self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") self._enable_hook = True diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 0a9a51e714298e028d5e175031235255db724bc4..957e7a4666f8bb5e3f648cb0e0c8c3b6827c2d58 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -14,7 +14,7 @@ # ============================================================================ """debug_ops""" -from types import FunctionType +from types import FunctionType, MethodType from ..._checkparam import Validator as validator from ...common import dtype as mstype from ..primitive import prim_attr_register, PrimitiveWithInfer @@ -279,7 +279,7 @@ class HookBackward(PrimitiveWithInfer): super(HookBackward, self).__init__(self.__class__.__name__) self.add_prim_attr("cell_id", cell_id) self.init_attrs["cell_id"] = cell_id - if not isinstance(hook_fn, FunctionType): + if not isinstance(hook_fn, (FunctionType, MethodType)): raise TypeError("Hook function should be python function type.") self.register_hook(hook_fn) self.cell_id = cell_id