提交 843d740b 编写于 作者: J jiangjiajun

optimize for elementwise op

上级 fa6ada39
......@@ -104,6 +104,7 @@ def tf2paddle(model_path,
# neccesary optimization
optimizer.delete_redundance_code()
# optimizer below is experimental
optimizer.optimize_elementwise_op()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.optimize_sub_graph()
......
......@@ -16,6 +16,7 @@
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.core.fluid_code import Layer
from x2paddle.core.util import *
import six
import numpy
import copy as cp
......@@ -104,6 +105,59 @@ class TFOptimizer(object):
del out_node.inputs[index]
del self.graph.node_map[node_name]
def optimize_elementwise_op(self):
elementwise_ops = [
'Sub', 'Add', 'RealDiv', 'Maximum', 'Mul', 'FloorDiv',
'GreaterEqual'
]
revertable_ops = ['Add', 'Mul']
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node is None:
continue
if node.layer_type in elementwise_ops:
if len(node.fluid_code.layers) != 2:
continue
if node.fluid_code.layers[0].op != "expand":
continue
expand_out = node.fluid_code.layers[0].output
expand_in = node.fluid_code.layers[0].inputs
expand_times = node.fluid_code.layers[0].param_attr[
"expand_times"]
x = node.fluid_code.layers[1].inputs["x"]
y = node.fluid_code.layers[1].inputs["y"]
if isinstance(
x,
six.string_types) and node.layer_type in revertable_ops:
node.fluid_code.layers[1].inputs["y"] = x
node.fluid_code.layers[1].inputs["x"] = y
x = node.fluid_code.layers[1].inputs["x"]
y = expand_in
elif isinstance(y, six.string_types):
y = expand_in
else:
continue
x_shape = x.out_shapes[0]
y_shape = y.out_shapes[0]
if len(x_shape) != len(y_shape):
continue
if len(x_shape) == 4:
x_shape = [x_shape[i] for i in [0, 3, 1, 2]]
y_shape = [y_shape[i] for i in [0, 3, 1, 2]]
continue_flag = True
for i in range(len(x_shape)):
if y_shape[-1 * (i + 1)] == 1 and continue_flag:
expand_times[-1 * (i + 1)] = 1
else:
continue_flag = False
if expand_times.count(1) == len(expand_times):
node.fluid_code.layers[1].inputs["y"] = expand_in
del node.fluid_code.layers[0]
def merge_activation(self):
act_nodes = list()
for node_name in self.graph.topo_sort:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册