未验证 提交 2be69d05 编写于 作者: Y YuanRisheng 提交者: GitHub

[Save/Load]Fix backward op's error when use jit.load (#50744)

* perfect translated layer

* perfect code according comment
上级 31e465e1
...@@ -2773,6 +2773,9 @@ class OpProtoHolder: ...@@ -2773,6 +2773,9 @@ class OpProtoHolder:
return custom_op_names return custom_op_names
def has_op_proto(self, type):
return type in self.op_proto_map
@staticmethod @staticmethod
def generated_op_attr_names(): def generated_op_attr_names():
return { return {
......
...@@ -563,6 +563,11 @@ class _ProgramHolder: ...@@ -563,6 +563,11 @@ class _ProgramHolder:
op.desc.set_output("ReserveSpace", [reserve_space.name]) op.desc.set_output("ReserveSpace", [reserve_space.name])
continue continue
# There are some situations that users will add backward op in Forward
# function of Layer. And because backward op doesn't have proto. So, we
# should skip it when we meet it.
if not OpProtoHolder.instance().has_op_proto(op.type):
continue
proto = OpProtoHolder.instance().get_op_proto(op.type) proto = OpProtoHolder.instance().get_op_proto(op.type)
has_create_intermediate_out = False has_create_intermediate_out = False
for output_proto in proto.outputs: for output_proto in proto.outputs:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册