提交 0bf809c9 编写于 作者: Z Zhen Wang

add set_attr for IrOpNode. test=develop

上级 7c8f7df2
...@@ -87,7 +87,7 @@ nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) ...@@ -87,7 +87,7 @@ nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context)
cc_library(timer SRCS timer.cc) cc_library(timer SRCS timer.cc)
cc_test(timer_test SRCS timer_test.cc DEPS timer) cc_test(timer_test SRCS timer_test.cc DEPS timer)
cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto device_context ${GPU_CTX_DEPS}) cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS})
if(WITH_GPU) if(WITH_GPU)
nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_context device_tracer) nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_context device_tracer)
else() else()
......
...@@ -1633,7 +1633,7 @@ class IrNode(object): ...@@ -1633,7 +1633,7 @@ class IrNode(object):
""" """
self.node.clear_inputs() self.node.clear_inputs()
def inputs_remove_by_id(self, node_id): def remove_input_by_id(self, node_id):
""" """
Remove a node from inputs by the given node id. Remove a node from inputs by the given node id.
...@@ -1876,6 +1876,34 @@ class IrOpNode(IrNode): ...@@ -1876,6 +1876,34 @@ class IrOpNode(IrNode):
"The node operator description cannot be None." "The node operator description cannot be None."
return self.node.op().set_type(new_type) return self.node.op().set_type(new_type)
def set_attr(self, name, val):
"""
Set the value of attribute by attribute's name.
Args:
name(str): the attribute name.
val(bool|int|str|float|list): the value of the attribute.
"""
self._update_desc_attr(name, val)
def _update_desc_attr(self, name, val):
"""
Update the value of the op desc's attribute by attribute's name.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
desc = self.node.op()
if isinstance(val, Block):
desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and val and \
all(isinstance(v, Block) for v in val):
desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc):
desc.set_serialized_attr(name, val.serialize_to_string())
else:
desc._set_attr(name, val)
@property @property
def inputs(self): def inputs(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册