From 45de931276ab634b12709a927ec97dc52f378ff7 Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Mon, 11 Oct 2021 20:00:34 +0800 Subject: [PATCH] [cherry-pick]fix hasattr(paddle.fluid.ir.PassDesc.OP, '__name__') error (#36294) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 对于__getattr__重载后不满足条件的参数,全部抛出AttributeError异常,达到与未重载版本一致。 (cherry picked from PR #36229) --- python/paddle/fluid/ir.py | 10 ++++++---- .../fluid/tests/unittests/ir/test_ir_generate_pass.py | 3 +++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py index 17b7ea1122..7e2d3df1ce 100644 --- a/python/paddle/fluid/ir.py +++ b/python/paddle/fluid/ir.py @@ -230,9 +230,6 @@ class PassDesc(object): self._type = type def __getattr__(self, name): - if self._type is not None: - raise AttributeError( - "type object 'OpHelper' has no attribute '{}'".format(name)) op = PassDesc.OpHelper(name) op.Init() return op @@ -261,7 +258,12 @@ class PassDesc(object): self._op_idx = len(block.ops) self._op_desc = block.desc.append_op() self._op_desc.set_type(self._type) - self._op_proto = OpProtoHolder.instance().get_op_proto(self._type) + self._op_proto = OpProtoHolder.instance().op_proto_map.get( + self._type) + if self._op_proto is None: + raise AttributeError( + "type object 'OpHelper' has no attribute '{}'".format( + self._type)) block.ops.append(self) def Attr(self, name): diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py index c8b9d5e573..851ae21c38 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py @@ -123,6 +123,9 @@ class TestGeneratePass(unittest.TestCase): op_dicts[op.type] = [op] return op_dicts + def test_has_attr(self): + self.assertFalse(hasattr(ir.PassDesc.OP, '__name__')) + def test_generate_fc_fuse(self): def _check_fc_fuse_pass(pass_desc, with_relu): pattern_op_dicts = self.convert_ops_to_op_dicts( -- GitLab