未验证 提交 e221a600 编写于 作者: A Aurelius84 提交者: GitHub

[OpAtttr]Add attribute var interface for Operator class (#45525)

* [OpAtttr]Add attribute var interface for Operator class

* fix unittest

* fix unittest
上级 fe321f9a
......@@ -3138,7 +3138,7 @@ class Operator(object):
Returns:
core.AttrType: the attribute type.
"""
return self.desc.attr_type(name)
return self.desc.attr_type(name, True)
def _set_attr(self, name, val):
"""
......@@ -3290,6 +3290,41 @@ class Operator(object):
return self.desc._blocks_attr_ids(name)
def _var_attr(self, name):
"""
Get the Variable attribute by name.
Args:
name(str): the attribute name.
Returns:
Variable: the Variable attribute.
"""
attr_type = self.desc.attr_type(name, True)
assert attr_type == core.AttrType.VAR, "Required type attr({}) is Variable, but received {}".format(
name, attr_type)
attr_var_name = self.desc.attr(name, True).name()
return self.block._var_recursive(attr_var_name)
def _vars_attr(self, name):
"""
Get the Variables attribute by name.
Args:
name(str): the attribute name.
Returns:
Variables: the Variables attribute.
"""
attr_type = self.desc.attr_type(name, True)
assert attr_type == core.AttrType.VARS, "Required type attr({}) is list[Variable], but received {}".format(
name, attr_type)
attr_vars = [
self.block._var_recursive(var.name())
for var in self.desc.attr(name, True)
]
return attr_vars
def all_attrs(self):
"""
Get the attribute dict.
......@@ -3300,16 +3335,17 @@ class Operator(object):
attr_names = self.attr_names
attr_map = {}
for n in attr_names:
attr_type = self.desc.attr_type(n)
attr_type = self.desc.attr_type(n, True)
if attr_type == core.AttrType.BLOCK:
attr_map[n] = self._block_attr(n)
continue
if attr_type == core.AttrType.BLOCKS:
elif attr_type == core.AttrType.BLOCKS:
attr_map[n] = self._blocks_attr(n)
continue
attr_map[n] = self.attr(n)
elif attr_type == core.AttrType.VAR:
attr_map[n] = self._var_attr(n)
elif attr_type == core.AttrType.VARS:
attr_map[n] = self._vars_attr(n)
else:
attr_map[n] = self.attr(n)
return attr_map
......
......@@ -96,6 +96,10 @@ class TestDropout(UnittestBase):
infer_out = self.infer_prog()
self.assertEqual(infer_out.shape, (10, 10))
self.assertEqual(
main_prog.block(0).ops[4].all_attrs()['dropout_prob'].name,
p.name)
class TestTileTensorList(UnittestBase):
......
......@@ -258,6 +258,12 @@ class TestReverseAxisListTensor(TestReverseAxisTensor):
# axes is a List[Variable]
axes = [paddle.assign([0]), paddle.assign([2])]
out = paddle.fluid.layers.reverse(x, axes)
# check attrs
axis_attrs = paddle.static.default_main_program().block(
0).ops[-1].all_attrs()["axis"]
self.assertTrue(axis_attrs[0].name, axes[0].name)
self.assertTrue(axis_attrs[1].name, axes[1].name)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册