未验证 提交 31140398 编写于 作者: J Jason 提交者: GitHub

Merge pull request #150 from jiangjiajun/develop

support new situation
......@@ -104,6 +104,7 @@ def tf2paddle(model_path,
# neccesary optimization
optimizer.delete_redundance_code()
# optimizer below is experimental
optimizer.optimize_elementwise_op()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.optimize_sub_graph()
......
......@@ -60,6 +60,15 @@ class TFGraphNode(GraphNode):
raise Exception("Dtype[{}] not in dtype_map".format(dtype))
return self.dtype_map[dtype]
@property
def raw_dtype(self):
keys = ['dtype', 'Tidx', 'T', 'DstT']
for k in keys:
dtype = self.layer.attr[k].type
if dtype > 0:
break
return dtype
@property
def value(self):
assert self.layer_type == "Const", "Only Const node has value."
......@@ -120,6 +129,7 @@ class TFGraph(Graph):
# tensorflow graph optimize
self._remove_isolated_node()
self._remove_identity_node()
self._remove_cast_node()
def get_node(self, node_name, copy=False):
items = node_name.strip().split(':')
......@@ -190,6 +200,27 @@ class TFGraph(Graph):
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
def _remove_cast_node(self):
cast_node = list()
for node_name, node in self.node_map.items():
if node.layer_type == "Cast":
input = self.get_node(node.inputs[0])
if input.layer_type != "Placeholder" or len(input.outputs) != 1:
continue
cast_node.append(node_name)
for node_name in cast_node:
node = self.get_node(node_name)
input_node = self.get_node(node.inputs[0])
input_node.layer.attr["dtype"].type = node.raw_dtype
self.remove_node(node_name)
self.identity_map[node_name] = input_node.layer_name
if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
def data_format_propagation(self, node):
current_node = self.node_map[node.layer_name]
current_node = node.tf_data_format
......
......@@ -170,7 +170,28 @@ class TFOpMapper(OpMapper):
x_shape = y.out_shapes[0]
y_shape = x.out_shapes[0]
else:
raise Exception("Unexpected situation happend")
if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[
0] == y_shape[-1] and y_shape.count(-1) < 1:
shape = [1, x_shape[0], 1, 1]
attr = {"shape": shape}
node.fluid_code.add_layer("reshape",
inputs=x_input,
output="reshape_x",
param_attr=attr)
if y_shape[0] != 1:
attr = {"expand_times": [y_shape[0], 1, 1, 1]}
node.fluid_code.add_layer("expand",
inputs="reshape_x",
output="reshape_x",
param_attr=attr)
inputs = {"x": "reshape_x", "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
return
else:
raise Exception("Unexpected situation happend")
if len(x_shape) == 4 and len(y_shape) == 1:
if x_input.tf_data_format == "NHWC":
......
......@@ -16,6 +16,7 @@
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.core.fluid_code import Layer
from x2paddle.core.util import *
import six
import numpy
import copy as cp
......@@ -104,6 +105,59 @@ class TFOptimizer(object):
del out_node.inputs[index]
del self.graph.node_map[node_name]
def optimize_elementwise_op(self):
elementwise_ops = [
'Sub', 'Add', 'RealDiv', 'Maximum', 'Mul', 'FloorDiv',
'GreaterEqual'
]
revertable_ops = ['Add', 'Mul']
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node is None:
continue
if node.layer_type in elementwise_ops:
if len(node.fluid_code.layers) != 2:
continue
if node.fluid_code.layers[0].op != "expand":
continue
expand_out = node.fluid_code.layers[0].output
expand_in = node.fluid_code.layers[0].inputs
expand_times = node.fluid_code.layers[0].param_attr[
"expand_times"]
x = node.fluid_code.layers[1].inputs["x"]
y = node.fluid_code.layers[1].inputs["y"]
if isinstance(
x,
six.string_types) and node.layer_type in revertable_ops:
node.fluid_code.layers[1].inputs["y"] = x
node.fluid_code.layers[1].inputs["x"] = y
x = node.fluid_code.layers[1].inputs["x"]
y = expand_in
elif isinstance(y, six.string_types):
y = expand_in
else:
continue
x_shape = x.out_shapes[0]
y_shape = y.out_shapes[0]
if len(x_shape) != len(y_shape):
continue
if len(x_shape) == 4:
x_shape = [x_shape[i] for i in [0, 3, 1, 2]]
y_shape = [y_shape[i] for i in [0, 3, 1, 2]]
continue_flag = True
for i in range(len(x_shape)):
if y_shape[-1 * (i + 1)] == 1 and continue_flag:
expand_times[-1 * (i + 1)] = 1
else:
continue_flag = False
if expand_times.count(1) == len(expand_times):
node.fluid_code.layers[1].inputs["y"] = expand_in
del node.fluid_code.layers[0]
def merge_activation(self):
act_nodes = list()
for node_name in self.graph.topo_sort:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册