diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4c1c8443a641cde40c392f1c647bc78d6cd3c13c..eee64fc050908d5ee6ac4c8e46be004377cb98e8 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -559,19 +559,8 @@ class Operator(object): self.attrs[attr_name] is None): continue attr_val = self.attrs[attr_name] - if isinstance(attr_val, Block): - self.desc.set_block_attr(attr_name, - self.attrs[attr_name].desc) - elif isinstance(attr_val, list) and attr_val and \ - all(isinstance(v, Block) for v in attr_val): - self.desc.set_blocks_attr(attr_name, - [v.desc for v in attr_val]) - elif isinstance(attr_val, core.BlockDesc) or \ - isinstance(attr_val, core.ProgramDesc): - self.desc.set_serialized_attr( - attr_name, attr_val.serialize_to_string()) - else: - self.desc.set_attr(attr_name, attr_val) + self._update_desc_attr(attr_name, attr_val) + self.desc.check_attrs() if self.has_kernel(type): self.desc.infer_var_type(self.block.desc) @@ -718,6 +707,19 @@ class Operator(object): ValueError: If the type of value doesn't match with desc.attr_type(name). """ self.attrs[name] = val + self._update_desc_attr(name, val) + + def _update_desc_attr(self, name, val): + """ + Update the value of desc's attribute by attribute's name. + + Args: + name(str): the attribute name. + val(bool|int|str|float|list): the value of the attribute. + + Raises: + ValueError: If the type of value doesn't match with desc.attr_type(name). + """ if isinstance(val, Block): self.desc.set_block_attr(name, val.desc) elif isinstance(val, list) and val and all(