From e221a6002bfd93a9f56792f4632bd590890c1fc9 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 30 Aug 2022 14:30:17 +0800 Subject: [PATCH] [OpAtttr]Add attribute var interface for Operator class (#45525) * [OpAtttr]Add attribute var interface for Operator class * fix unittest * fix unittest --- python/paddle/fluid/framework.py | 52 ++++++++++++++++--- .../tests/unittests/test_attribute_var.py | 4 ++ .../fluid/tests/unittests/test_reverse_op.py | 6 +++ 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 409eb020d39..5f62e5d3ea5 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3138,7 +3138,7 @@ class Operator(object): Returns: core.AttrType: the attribute type. """ - return self.desc.attr_type(name) + return self.desc.attr_type(name, True) def _set_attr(self, name, val): """ @@ -3290,6 +3290,41 @@ class Operator(object): return self.desc._blocks_attr_ids(name) + def _var_attr(self, name): + """ + Get the Variable attribute by name. + + Args: + name(str): the attribute name. + + Returns: + Variable: the Variable attribute. + """ + attr_type = self.desc.attr_type(name, True) + assert attr_type == core.AttrType.VAR, "Required type attr({}) is Variable, but received {}".format( + name, attr_type) + attr_var_name = self.desc.attr(name, True).name() + return self.block._var_recursive(attr_var_name) + + def _vars_attr(self, name): + """ + Get the Variables attribute by name. + + Args: + name(str): the attribute name. + + Returns: + Variables: the Variables attribute. + """ + attr_type = self.desc.attr_type(name, True) + assert attr_type == core.AttrType.VARS, "Required type attr({}) is list[Variable], but received {}".format( + name, attr_type) + attr_vars = [ + self.block._var_recursive(var.name()) + for var in self.desc.attr(name, True) + ] + return attr_vars + def all_attrs(self): """ Get the attribute dict. @@ -3300,16 +3335,17 @@ class Operator(object): attr_names = self.attr_names attr_map = {} for n in attr_names: - attr_type = self.desc.attr_type(n) + attr_type = self.desc.attr_type(n, True) if attr_type == core.AttrType.BLOCK: attr_map[n] = self._block_attr(n) - continue - - if attr_type == core.AttrType.BLOCKS: + elif attr_type == core.AttrType.BLOCKS: attr_map[n] = self._blocks_attr(n) - continue - - attr_map[n] = self.attr(n) + elif attr_type == core.AttrType.VAR: + attr_map[n] = self._var_attr(n) + elif attr_type == core.AttrType.VARS: + attr_map[n] = self._vars_attr(n) + else: + attr_map[n] = self.attr(n) return attr_map diff --git a/python/paddle/fluid/tests/unittests/test_attribute_var.py b/python/paddle/fluid/tests/unittests/test_attribute_var.py index cabbfb826b5..6e8e3c66750 100644 --- a/python/paddle/fluid/tests/unittests/test_attribute_var.py +++ b/python/paddle/fluid/tests/unittests/test_attribute_var.py @@ -96,6 +96,10 @@ class TestDropout(UnittestBase): infer_out = self.infer_prog() self.assertEqual(infer_out.shape, (10, 10)) + self.assertEqual( + main_prog.block(0).ops[4].all_attrs()['dropout_prob'].name, + p.name) + class TestTileTensorList(UnittestBase): diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py index e2260082fc9..f090cf1c8de 100644 --- a/python/paddle/fluid/tests/unittests/test_reverse_op.py +++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py @@ -258,6 +258,12 @@ class TestReverseAxisListTensor(TestReverseAxisTensor): # axes is a List[Variable] axes = [paddle.assign([0]), paddle.assign([2])] out = paddle.fluid.layers.reverse(x, axes) + + # check attrs + axis_attrs = paddle.static.default_main_program().block( + 0).ops[-1].all_attrs()["axis"] + self.assertTrue(axis_attrs[0].name, axes[0].name) + self.assertTrue(axis_attrs[1].name, axes[1].name) return out -- GitLab