diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 409eb020d39601b3a4ee81e880fafcf236c2d0ce..5f62e5d3ea5ce484b2a40695a74feaa37d30cda4 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3138,7 +3138,7 @@ class Operator(object): Returns: core.AttrType: the attribute type. """ - return self.desc.attr_type(name) + return self.desc.attr_type(name, True) def _set_attr(self, name, val): """ @@ -3290,6 +3290,41 @@ class Operator(object): return self.desc._blocks_attr_ids(name) + def _var_attr(self, name): + """ + Get the Variable attribute by name. + + Args: + name(str): the attribute name. + + Returns: + Variable: the Variable attribute. + """ + attr_type = self.desc.attr_type(name, True) + assert attr_type == core.AttrType.VAR, "Required type attr({}) is Variable, but received {}".format( + name, attr_type) + attr_var_name = self.desc.attr(name, True).name() + return self.block._var_recursive(attr_var_name) + + def _vars_attr(self, name): + """ + Get the Variables attribute by name. + + Args: + name(str): the attribute name. + + Returns: + Variables: the Variables attribute. + """ + attr_type = self.desc.attr_type(name, True) + assert attr_type == core.AttrType.VARS, "Required type attr({}) is list[Variable], but received {}".format( + name, attr_type) + attr_vars = [ + self.block._var_recursive(var.name()) + for var in self.desc.attr(name, True) + ] + return attr_vars + def all_attrs(self): """ Get the attribute dict. @@ -3300,16 +3335,17 @@ class Operator(object): attr_names = self.attr_names attr_map = {} for n in attr_names: - attr_type = self.desc.attr_type(n) + attr_type = self.desc.attr_type(n, True) if attr_type == core.AttrType.BLOCK: attr_map[n] = self._block_attr(n) - continue - - if attr_type == core.AttrType.BLOCKS: + elif attr_type == core.AttrType.BLOCKS: attr_map[n] = self._blocks_attr(n) - continue - - attr_map[n] = self.attr(n) + elif attr_type == core.AttrType.VAR: + attr_map[n] = self._var_attr(n) + elif attr_type == core.AttrType.VARS: + attr_map[n] = self._vars_attr(n) + else: + attr_map[n] = self.attr(n) return attr_map diff --git a/python/paddle/fluid/tests/unittests/test_attribute_var.py b/python/paddle/fluid/tests/unittests/test_attribute_var.py index cabbfb826b53b6b464962fb74a126420b4075481..6e8e3c6675087c7df463b304bd3834a461689e05 100644 --- a/python/paddle/fluid/tests/unittests/test_attribute_var.py +++ b/python/paddle/fluid/tests/unittests/test_attribute_var.py @@ -96,6 +96,10 @@ class TestDropout(UnittestBase): infer_out = self.infer_prog() self.assertEqual(infer_out.shape, (10, 10)) + self.assertEqual( + main_prog.block(0).ops[4].all_attrs()['dropout_prob'].name, + p.name) + class TestTileTensorList(UnittestBase): diff --git a/python/paddle/fluid/tests/unittests/test_reverse_op.py b/python/paddle/fluid/tests/unittests/test_reverse_op.py index e2260082fc9686d1e0b731df586b130266dc3723..f090cf1c8de11e48b57a5c1e3efa8dd51326c77b 100644 --- a/python/paddle/fluid/tests/unittests/test_reverse_op.py +++ b/python/paddle/fluid/tests/unittests/test_reverse_op.py @@ -258,6 +258,12 @@ class TestReverseAxisListTensor(TestReverseAxisTensor): # axes is a List[Variable] axes = [paddle.assign([0]), paddle.assign([2])] out = paddle.fluid.layers.reverse(x, axes) + + # check attrs + axis_attrs = paddle.static.default_main_program().block( + 0).ops[-1].all_attrs()["axis"] + self.assertTrue(axis_attrs[0].name, axes[0].name) + self.assertTrue(axis_attrs[1].name, axes[1].name) return out