提交 b482fc35 编写于 作者: J jiangjiajun

optimize cast op

上级 843d740b
......@@ -60,6 +60,15 @@ class TFGraphNode(GraphNode):
raise Exception("Dtype[{}] not in dtype_map".format(dtype))
return self.dtype_map[dtype]
@property
def raw_dtype(self):
keys = ['dtype', 'Tidx', 'T', 'DstT']
for k in keys:
dtype = self.layer.attr[k].type
if dtype > 0:
break
return dtype
@property
def value(self):
assert self.layer_type == "Const", "Only Const node has value."
......@@ -120,6 +129,7 @@ class TFGraph(Graph):
# tensorflow graph optimize
self._remove_isolated_node()
self._remove_identity_node()
self._remove_cast_node()
def get_node(self, node_name, copy=False):
items = node_name.strip().split(':')
......@@ -190,6 +200,27 @@ class TFGraph(Graph):
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
def _remove_cast_node(self):
cast_node = list()
for node_name, node in self.node_map.items():
if node.layer_type == "Cast":
input = self.get_node(node.inputs[0])
if input.layer_type != "Placeholder" or len(input.outputs) != 1:
continue
cast_node.append(node_name)
for node_name in cast_node:
node = self.get_node(node_name)
input_node = self.get_node(node.inputs[0])
input_node.layer.attr["dtype"].type = node.raw_dtype
self.remove_node(node_name)
self.identity_map[node_name] = input_node.layer_name
if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
def data_format_propagation(self, node):
current_node = self.node_map[node.layer_name]
current_node = node.tf_data_format
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册