From 0bf809c9b39ac729d1fc1fcdc3feee73eb1028ba Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Mon, 25 Feb 2019 15:37:00 +0800 Subject: [PATCH] add set_attr for IrOpNode. test=develop --- paddle/fluid/platform/CMakeLists.txt | 2 +- python/paddle/fluid/framework.py | 30 +++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index b7e84031e..5833fee35 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -87,7 +87,7 @@ nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) cc_library(timer SRCS timer.cc) cc_test(timer_test SRCS timer_test.cc DEPS timer) -cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto device_context ${GPU_CTX_DEPS}) +cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS}) if(WITH_GPU) nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_context device_tracer) else() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 8c62d2f28..b6babf5d0 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1633,7 +1633,7 @@ class IrNode(object): """ self.node.clear_inputs() - def inputs_remove_by_id(self, node_id): + def remove_input_by_id(self, node_id): """ Remove a node from inputs by the given node id. @@ -1876,6 +1876,34 @@ class IrOpNode(IrNode): "The node operator description cannot be None." return self.node.op().set_type(new_type) + def set_attr(self, name, val): + """ + Set the value of attribute by attribute's name. + + Args: + name(str): the attribute name. + val(bool|int|str|float|list): the value of the attribute. + """ + self._update_desc_attr(name, val) + + def _update_desc_attr(self, name, val): + """ + Update the value of the op desc's attribute by attribute's name. + """ + assert self.node.op() is not None, \ + "The node operator description cannot be None." + desc = self.node.op() + if isinstance(val, Block): + desc.set_block_attr(name, val.desc) + elif isinstance(val, list) and val and \ + all(isinstance(v, Block) for v in val): + desc.set_blocks_attr(name, [v.desc for v in val]) + elif isinstance(val, core.BlockDesc) or \ + isinstance(val, core.ProgramDesc): + desc.set_serialized_attr(name, val.serialize_to_string()) + else: + desc._set_attr(name, val) + @property def inputs(self): """ -- GitLab