未验证 提交 8fecf767 编写于 作者: J Jason 提交者: GitHub

Merge pull request #68 from jiangjiajun/develop

add optimizer for tf2fluid
1. 增加逻辑,处理NHWC格式模型
2. bias和激活函数合入前置layer
...@@ -67,10 +67,17 @@ def tf2paddle(model_path, save_dir): ...@@ -67,10 +67,17 @@ def tf2paddle(model_path, save_dir):
from x2paddle.decoder.tf_decoder import TFDecoder from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tf_optimizer import TFOptimizer
print("Now translating model from tensorflow to paddle.") print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path) model = TFDecoder(model_path)
mapper = TFOpMapper(model) mapper = TFOpMapper(model)
optimizer = TFOptimizer(mapper)
# neccesary optimization
optimizer.delete_redundance_code()
# optimizer below is experimental
optimizer.merge_activation()
optimizer.merge_bias()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir)
......
...@@ -99,29 +99,6 @@ class Graph(object): ...@@ -99,29 +99,6 @@ class Graph(object):
self.node_map[dst].inputs.append(src) self.node_map[dst].inputs.append(src)
self.node_map[src].outputs.append(dst) self.node_map[src].outputs.append(dst)
def remove_node(self, node_name):
if node_name not in self.node_map:
raise Exception("Node[{}] not in graph".format(node_name))
inputs = self.node_map[node_name].inputs
outputs = self.node_map[node_name].outputs
for input in inputs:
idx = self.node_map[input].outputs.index(node_name)
del self.node_map[input].outputs[idx]
for output in outputs:
idx = self.node_map[input].inputs.index(node_name)
del self.node_map[input].inputs[idx]
del self.node_map[node_name]
idx = self.topo_sort.index(node_name)
del self.topo_sort[idx]
if node_name in self.input_nodes:
idx = self.input_nodes.index(node_name)
del self.input_nodes[idx]
if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name)
del self.output_nodes[idx]
def print(self): def print(self):
for i, tmp in enumerate(self.topo_sort): for i, tmp in enumerate(self.topo_sort):
print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs, print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs,
......
...@@ -116,7 +116,7 @@ class OpMapper(object): ...@@ -116,7 +116,7 @@ class OpMapper(object):
feeded_var_names=input_names, feeded_var_names=input_names,
target_vars=outputs, target_vars=outputs,
executor=exe, executor=exe,
params_filename="__params__") params_filename=None)
except: except:
raise Exception( raise Exception(
"Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually." "Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
...@@ -142,9 +142,9 @@ class OpMapper(object): ...@@ -142,9 +142,9 @@ class OpMapper(object):
self.add_codes("\ndef x2paddle_net():", 0) self.add_codes("\ndef x2paddle_net():", 0)
for i in range(len(self.graph.topo_sort)): for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i] node_name = self.graph.topo_sort[i]
if hasattr(self, "omit_nodes") and node_name in self.omit_nodes:
continue
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
if len(node.fluid_code.layers) == 0:
continue
self.add_codes(node.fluid_code.gen_codes(), 1) self.add_codes(node.fluid_code.gen_codes(), 1)
self.add_codes("", 0) self.add_codes("", 0)
......
...@@ -24,7 +24,7 @@ import sys ...@@ -24,7 +24,7 @@ import sys
class TFGraphNode(GraphNode): class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None): def __init__(self, layer, layer_name=None, data_format="NHWC"):
if layer_name is None: if layer_name is None:
super(TFGraphNode, super(TFGraphNode,
self).__init__(layer, self).__init__(layer,
...@@ -35,6 +35,8 @@ class TFGraphNode(GraphNode): ...@@ -35,6 +35,8 @@ class TFGraphNode(GraphNode):
layer_name.replace('/', '_').replace('-', '_')) layer_name.replace('/', '_').replace('-', '_'))
self.layer_type = layer.op self.layer_type = layer.op
self.tf_data_format = data_format
self.pd_data_format = "NCHW"
self.fluid_code = FluidCode() self.fluid_code = FluidCode()
self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"} self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"}
...@@ -86,15 +88,16 @@ class TFGraphNode(GraphNode): ...@@ -86,15 +88,16 @@ class TFGraphNode(GraphNode):
class TFGraph(Graph): class TFGraph(Graph):
def __init__(self, model): def __init__(self, model, data_format="NHWC"):
super(TFGraph, self).__init__(model) super(TFGraph, self).__init__(model)
self.identity_map = dict() self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV'] self.multi_out_ops = ['Split', 'SplitV']
self.tf_data_format = data_format
def build(self): def build(self):
for layer in self.model.node: for layer in self.model.node:
self.node_map[layer.name.replace('/', '_').replace( self.node_map[layer.name.replace('/', '_').replace(
'-', '_')] = TFGraphNode(layer) '-', '_')] = TFGraphNode(layer, data_format=self.tf_data_format)
for layer_name, node in self.node_map.items(): for layer_name, node in self.node_map.items():
for in_node in node.layer.input: for in_node in node.layer.input:
...@@ -126,6 +129,26 @@ class TFGraph(Graph): ...@@ -126,6 +129,26 @@ class TFGraph(Graph):
node.index = 0 node.index = 0
return node return node
def remove_node(self, node_name):
if node_name not in self.node_map:
raise Exception("Node[{}] not in graph".format(node_name))
inputs = self.node_map[node_name].inputs
outputs = self.node_map[node_name].outputs
assert len(inputs) == 1
input_node = self.node_map[inputs[0]]
idx = input_node.outputs.index(node_name)
del input_node.outputs[idx]
for output in outputs:
node = self.node_map[output]
idx = node.inputs.index(node_name)
node.inputs[idx] = inputs[0]
input_node.outputs.append(output)
del self.node_map[node_name]
idx = self.topo_sort.index(node_name)
del self.topo_sort[idx]
def _remove_isolated_node(self): def _remove_isolated_node(self):
# delete isolated nodes # delete isolated nodes
isolated_nodes = list() isolated_nodes = list()
...@@ -135,7 +158,15 @@ class TFGraph(Graph): ...@@ -135,7 +158,15 @@ class TFGraph(Graph):
isolated_nodes.append(node_name) isolated_nodes.append(node_name)
for node_name in isolated_nodes: for node_name in isolated_nodes:
self.remove_node(node_name) del self.node_map[node_name]
if node_name in self.input_nodes:
idx = self.input_nodes.index(node_name)
del self.input_nodes[idx]
if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name)
del self.output_nodes[idx]
idx = self.topo_sort.index(node_name)
del self.topo_sort[idx]
def _remove_identity_node(self): def _remove_identity_node(self):
identity_node = list() identity_node = list()
...@@ -145,30 +176,47 @@ class TFGraph(Graph): ...@@ -145,30 +176,47 @@ class TFGraph(Graph):
for node_name in identity_node: for node_name in identity_node:
node = self.get_node(node_name) node = self.get_node(node_name)
# Remind: Only 1 input for Identity node
input_node = self.get_node(node.inputs[0]) input_node = self.get_node(node.inputs[0])
self.remove_node(node_name)
# remove identity node from graph
self.identity_map[node_name] = input_node.layer_name self.identity_map[node_name] = input_node.layer_name
idx = input_node.outputs.index(node_name)
del input_node.outputs[idx]
output_names = node.outputs # node = self.get_node(node_name)
for output_name in output_names: # # Remind: Only 1 input for Identity node
output_node = self.get_node(output_name) # input_node = self.get_node(node.inputs[0])
idx = output_node.inputs.index(node_name) #
output_node.inputs[idx] = input_node.layer_name # # remove identity node from graph
# self.identity_map[node_name] = input_node.layer_name
idx = self.topo_sort.index(node_name) # idx = input_node.outputs.index(node_name)
del self.topo_sort[idx] # del input_node.outputs[idx]
#
# output_names = node.outputs
# for output_name in output_names:
# output_node = self.get_node(output_name)
# idx = output_node.inputs.index(node_name)
# output_node.inputs[idx] = input_node.layer_name
#
# idx = self.topo_sort.index(node_name)
# del self.topo_sort[idx]
if node_name in self.output_nodes: if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name) idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name self.output_nodes[idx] = input_node.layer_name
def data_format_propagation(self, node):
current_node = self.node_map[node.layer_name]
current_node = node.tf_data_format
outputs = current_node.outputs
if len(outputs) == 0:
return
for out in outputs:
next_node = self.node_map[out]
next_node.tf_data_format = node.tf_data_format
self.data_format_propagation(next_node)
class TFDecoder(object): class TFDecoder(object):
def __init__(self, pb_model): def __init__(self, pb_model, data_format="NHWC"):
self.sess = tf.Session() self.sess = tf.Session()
self.input_info = dict() self.input_info = dict()
with gfile.FastGFile(pb_model, 'rb') as f: with gfile.FastGFile(pb_model, 'rb') as f:
...@@ -186,7 +234,7 @@ class TFDecoder(object): ...@@ -186,7 +234,7 @@ class TFDecoder(object):
self.sess.run(tf.global_variables_initializer()) self.sess.run(tf.global_variables_initializer())
self.tf_graph = TFGraph( self.tf_graph = TFGraph(
self.sess.graph._as_graph_def(add_shapes=True)[0]) self.sess.graph._as_graph_def(add_shapes=True)[0], data_format)
self.tf_graph.build() self.tf_graph.build()
def _fix_output_shape(self, graph): def _fix_output_shape(self, graph):
......
此差异已折叠。
...@@ -13,10 +13,95 @@ ...@@ -13,10 +13,95 @@
# limitations under the License. # limitations under the License.
# TODO useless node remove # TODO useless node remove
from x2paddle.decoder.tf_decoder import TFGraph from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.core.util import *
# TODO bn merge
# TODO activation merge class TFOptimizer(object):
activation_ops = {
'Relu': 'relu',
'Sigmoid': 'sigmoid',
'Relu6': 'relu6',
'swish_f32': 'swish'
}
layers_with_act = [
'Conv2D', 'BiasAdd', 'DepthwiseConv2dNative', 'Conv2DBackpropInput',
'FusedBatchNorm'
]
layers_with_bias = [
'Conv2D', 'DepthwiseConv2dNative', 'Conv2DBackpropInput'
]
# TODO biasadd merge def __init__(self, op_mapper):
self.op_mapper = op_mapper
self.graph = op_mapper.graph
def delete_redundance_code(self):
for node_name in self.graph.topo_sort:
if node_name in self.op_mapper.omit_nodes:
node = self.graph.get_node(node_name)
omit_freq = self.op_mapper.omit_nodes.count(node_name)
if len(node.outputs) <= omit_freq:
node.fluid_code.clear()
# TODO activation merge
def merge_activation(self):
act_nodes = list()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node.layer_type in self.activation_ops:
act_nodes.append(node_name)
for act_node_name in act_nodes:
node = self.graph.get_node(act_node_name)
input = self.graph.get_node(node.inputs[0])
if input.layer_type not in self.layers_with_act:
continue
if len(input.fluid_code.layers) == 0:
continue
if 'act' in input.fluid_code.layers[
-1].param_attr and input.fluid_code.layers[-1].param_attr[
'act'] is not None:
continue
if len(input.outputs) != 1:
continue
input.fluid_code.layers[-1].param_attr['act'] = string(
self.activation_ops[node.layer_type])
input.fluid_code.layers[-1].output = node.fluid_code.layers[
0].output
self.graph.remove_node(act_node_name)
# TODO bias merge
def merge_bias(self):
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node.layer_type == "BiasAdd":
input = self.graph.get_node(node.inputs[0])
if input.layer_type not in self.layers_with_bias:
continue
if len(input.outputs) != 1:
continue
if len(input.fluid_code.layers) == 0:
continue
bias_with_act = False
if 'act' in node.fluid_code.layers[-1].param_attr:
bias_with_act = True
layer_with_act = False
if 'act' in input.fluid_code.layers[
-1].param_attr and input.fluid_code.layers[
-1].param_attr['act'] is not None:
layer_with_act = True
if bias_with_act and layer_with_act:
continue
if not input.fluid_code.layers[-1].param_attr['bias_attr']:
bias_name = node.inputs[1]
input.fluid_code.layers[-1].param_attr[
'bias_attr'] = string(bias_name)
input.fluid_code.layers[-1].output = node.fluid_code.layers[
0].output
if bias_with_act:
input.fluid_code.layers[-1].param_attr[
'act'] = node.fluid_code.layers[-1].param_attr[
'act']
node.fluid_code.clear()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册