未验证 提交 e3b4b14d 编写于 作者: S SunAhong1993 提交者: GitHub

Merge pull request #2 from PaddlePaddle/develop

Develop
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from x2paddle.core.graph import GraphNode from x2paddle.core.graph import GraphNode
import collections
class Layer(object): class Layer(object):
...@@ -36,25 +37,34 @@ class Layer(object): ...@@ -36,25 +37,34 @@ class Layer(object):
if isinstance(self.inputs, list): if isinstance(self.inputs, list):
in_list = "[" in_list = "["
for input in self.inputs: for input in self.inputs:
assert isinstance( if isinstance(input, GraphNode):
input, GraphNode), "Type of input should be GraphNode"
if hasattr(input, "index"): if hasattr(input, "index"):
in_list += (input.layer_name + "[{}]".format(input.index) + in_list += (input.layer_name +
", ") "[{}]".format(input.index) + ", ")
else: else:
in_list += (input.layer_name + ", ") in_list += (input.layer_name + ", ")
elif isinstance(input, str):
in_list += (input + ", ")
else:
raise Exception(
"Element of inputs should GraphNode or String")
in_list = in_list.strip(", ") + "], " in_list = in_list.strip(", ") + "], "
layer_code += in_list layer_code += in_list
elif isinstance(self.inputs, dict): elif isinstance(self.inputs, dict):
for key, input in self.inputs.items(): inputs = collections.OrderedDict(self.inputs)
assert isinstance( for key, input in inputs.items():
input, GraphNode), "Type of input should be GraphNode" if isinstance(input, GraphNode):
if hasattr(input, "index"): if hasattr(input, "index"):
layer_code = layer_code + key + "={}, ".format( layer_code = layer_code + key + "={}, ".format(
input.layer_name + "[{}]".format(input.index)) input.layer_name + "[{}]".format(input.index))
else: else:
layer_code = layer_code + key + "={}, ".format( layer_code = layer_code + key + "={}, ".format(
input.layer_name) input.layer_name)
elif isinstance(input, str):
layer_code = layer_code + key + "={}, ".format(input)
else:
raise Exception(
"Element of inputs should GraphNode or String")
elif isinstance(self.inputs, GraphNode): elif isinstance(self.inputs, GraphNode):
if hasattr(self.inputs, "index"): if hasattr(self.inputs, "index"):
layer_code += (self.inputs.layer_name + layer_code += (self.inputs.layer_name +
...@@ -66,7 +76,8 @@ class Layer(object): ...@@ -66,7 +76,8 @@ class Layer(object):
else: else:
raise Exception("Unknown type of inputs.") raise Exception("Unknown type of inputs.")
for key, value in self.param_attr.items(): param_attr = collections.OrderedDict(self.param_attr)
for key, value in param_attr.items():
layer_code = layer_code + key + "={}, ".format(value) layer_code = layer_code + key + "={}, ".format(value)
layer_code = layer_code.strip(", ") layer_code = layer_code.strip(", ")
...@@ -97,7 +108,8 @@ class Layer(object): ...@@ -97,7 +108,8 @@ class Layer(object):
else: else:
raise Exception("Unknown type of inputs.") raise Exception("Unknown type of inputs.")
for key, value in self.param_attr.items(): param_attr = collections.OrderedDict(self.param_attr)
for key, value in param_attr.items():
layer_code = layer_code + key + "={}, ".format(value) layer_code = layer_code + key + "={}, ".format(value)
layer_code = layer_code.strip(", ") layer_code = layer_code.strip(", ")
......
...@@ -176,8 +176,9 @@ class TFDecoder(object): ...@@ -176,8 +176,9 @@ class TFDecoder(object):
self.sess.graph.as_default() self.sess.graph.as_default()
tf.import_graph_def(graph_def, name='', input_map=input_map) tf.import_graph_def(graph_def, name='', input_map=input_map)
for node in graph_def.node:
print(node.name, node.op, node.input) # for node in graph_def.node:
# print(node.name, node.op, node.input)
self.sess.run(tf.global_variables_initializer()) self.sess.run(tf.global_variables_initializer())
......
...@@ -19,6 +19,31 @@ import numpy ...@@ -19,6 +19,31 @@ import numpy
class TFOpMapper(OpMapper): class TFOpMapper(OpMapper):
directly_map_ops = {
'Relu': ['relu'],
'Relu6': ['relu6'],
'Shape': ['shape'],
'Abs': ['abs'],
'Sigmoid': ['sigmoid'],
'Exp': ['exp'],
'Rsqrt': ['rsqrt'],
'Squeeze': ['squeeze', {
'squeeze_dims': 'axes'
}],
'Softmax': ['softmax', {
'axis': 'axis'
}],
}
elementwise_ops = {
'Add': 'elementwise_add',
'RealDiv': 'elementwise_div',
'BiasAdd': 'elementwise_add',
'Sub': 'elementwise_sub',
'Maximum': 'elementwise_max',
'Mul': 'elementwise_mul'
}
def __init__(self, decoder): def __init__(self, decoder):
super(TFOpMapper, self).__init__() super(TFOpMapper, self).__init__()
self.decoder = decoder self.decoder = decoder
...@@ -30,15 +55,20 @@ class TFOpMapper(OpMapper): ...@@ -30,15 +55,20 @@ class TFOpMapper(OpMapper):
print("Total nodes: {}".format(len(self.graph.topo_sort))) print("Total nodes: {}".format(len(self.graph.topo_sort)))
# check if ops in model are all supported # check if ops in model are all supported
if not self.op_checker(): # TODO
raise Exception("Model are not supported yet.")
for node_name in self.graph.topo_sort: for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
op = node.layer_type op = node.layer_type
if hasattr(self, op): if op in self.directly_map_ops:
self.directly_map(node)
elif op in self.elementwise_ops:
self.elementwise_map(node)
elif hasattr(self, op):
func = getattr(self, op) func = getattr(self, op)
func(node) func(node)
else:
raise Exception("OP: [{}] not support yet".format(op))
for i in range(len(self.graph.topo_sort)): for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i] node_name = self.graph.topo_sort[i]
...@@ -47,7 +77,24 @@ class TFOpMapper(OpMapper): ...@@ -47,7 +77,24 @@ class TFOpMapper(OpMapper):
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
self.net_code += node.fluid_code.gen_codes() self.net_code += node.fluid_code.gen_codes()
def elementwise_operator(self, node, op_type): def directly_map(self, node):
assert node.layer_type in self.directly_map_ops
op_info = self.directly_map_ops[node.layer_type]
input = self.graph.get_node(node.layer.input[0], copy=True)
attr = dict()
for param in op_info[1:]:
tf_param_name = list(param.keys())[0]
pd_param_name = list(param.values())[0]
tf_param = node.get_attr(tf_param_name)
attr[pd_param_name] = tf_param
node.fluid_code.add_layer(op_info[0],
inputs=input,
output=node,
param_attr=attr)
def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type]
x = self.graph.get_node(node.layer.input[0], copy=True) x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True) y = self.graph.get_node(node.layer.input[1], copy=True)
x_shape = x.out_shapes[0] x_shape = x.out_shapes[0]
...@@ -161,41 +208,6 @@ class TFOpMapper(OpMapper): ...@@ -161,41 +208,6 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def RealDiv(self, node):
self.elementwise_operator(node, "elementwise_div")
def Relu(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("relu",
inputs=input,
output=node,
param_attr=None)
def Squeeze(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
squeeze_dims = node.get_attr('squeeze_dims')
attr = {'axes': squeeze_dims}
node.fluid_code.add_layer("squeeze",
inputs=input,
output=node,
param_attr=attr)
def BiasAdd(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
bias = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {'x': input, 'y': bias}
node.fluid_code.add_layer("elementwise_add",
inputs=inputs,
output=node,
param_attr=None)
def Identity(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("assign",
inputs=input,
output=node,
param_attr=None)
def MaxPool(self, node): def MaxPool(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
...@@ -314,13 +326,6 @@ class TFOpMapper(OpMapper): ...@@ -314,13 +326,6 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def Relu6(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("relu6",
inputs=input,
output=node,
param_attr=None)
def FusedBatchNorm(self, node): def FusedBatchNorm(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
gamma = self.graph.get_node(node.layer.input[1], copy=True) gamma = self.graph.get_node(node.layer.input[1], copy=True)
...@@ -433,13 +438,6 @@ class TFOpMapper(OpMapper): ...@@ -433,13 +438,6 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def Shape(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("shape",
inputs=input,
output=node,
param_attr=None)
def Reshape(self, node): def Reshape(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
param = self.graph.get_node(node.layer.input[1], copy=True) param = self.graph.get_node(node.layer.input[1], copy=True)
...@@ -474,27 +472,6 @@ class TFOpMapper(OpMapper): ...@@ -474,27 +472,6 @@ class TFOpMapper(OpMapper):
inputs=input, inputs=input,
output=node, output=node,
param_attr=attr) param_attr=attr)
# temporary shape inference fix
# if param.layer_type == "Pack":
# shape_slices = list()
# for i in range(len(param.layer.input)):
# slice = self.graph.get_node(param.layer.input[i], copy=True)
# if slice.layer_type == "Const":
# shape_slices.append(slice.value.tolist())
# else:
# shape_slices.append(0)
# if shape_slices.count(-1) == 0:
# shape_slices[shape_slices.index(0)] = -1
# attr = {"shape": shape_slices}
# node.fluid_code.add_layer("reshape",
# inputs=node,
# output=node,
# param_attr=attr)
def Add(self, node):
self.elementwise_operator(node, "elementwise_add")
def AvgPool(self, node): def AvgPool(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
...@@ -542,23 +519,6 @@ class TFOpMapper(OpMapper): ...@@ -542,23 +519,6 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def Softmax(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("softmax",
inputs=input,
output=node,
param_attr=None)
def Sigmoid(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("sigmoid",
inputs=input,
output=node,
param_attr=None)
def Maximum(self, node):
self.elementwise_operator(node, "elementwise_max")
def SplitV(self, node): def SplitV(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
num_sections = self.graph.get_node(node.layer.input[1], copy=True) num_sections = self.graph.get_node(node.layer.input[1], copy=True)
...@@ -576,13 +536,6 @@ class TFOpMapper(OpMapper): ...@@ -576,13 +536,6 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def Exp(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("exp",
inputs=input,
output=node,
param_attr=None)
def ConcatV2(self, node): def ConcatV2(self, node):
inputs = [ inputs = [
self.graph.get_node(name, copy=True) self.graph.get_node(name, copy=True)
...@@ -649,19 +602,6 @@ class TFOpMapper(OpMapper): ...@@ -649,19 +602,6 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=None) param_attr=None)
def Mul(self, node):
self.elementwise_operator(node, "elementwise_mul")
def Sub(self, node):
self.elementwise_operator(node, "elementwise_sub")
def Rsqrt(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("rsqrt",
inputs=input,
output=node,
param_attr=None)
def swish_f32(self, node): def swish_f32(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("sigmoid", node.fluid_code.add_layer("sigmoid",
...@@ -765,13 +705,6 @@ class TFOpMapper(OpMapper): ...@@ -765,13 +705,6 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def Abs(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("abs",
inputs=input,
output=node,
param_attr=None)
def Conv2DBackpropInput(self, node): def Conv2DBackpropInput(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True) kernel = self.graph.get_node(node.layer.input[1], copy=True)
......
...@@ -15,56 +15,8 @@ ...@@ -15,56 +15,8 @@
# TODO useless node remove # TODO useless node remove
from x2paddle.decoder.tf_decoder import TFGraph from x2paddle.decoder.tf_decoder import TFGraph
# TODO bn merge
class TFGraphOptimizer(object): # TODO activation merge
def __init__(self):
print("Doint Nothing")
def remove_isolated_node(self, graph): # TODO biasadd merge
# delete isolated nodes
isolated_nodes = list()
for node_name in graph.node_map.keys():
if len(graph.get_node(node_name).inputs) == 0 or len(
graph.get_node(node_name).outputs) == 0:
isolated_nodes.append(node_name)
graph.remove_node(node_name)
def remove_identity_node(self, graph):
identity_node = list()
for node_name, node in graph.node_map.items():
if node.layer_type == "Identity":
identity_node.append(node_name)
for node_name in identity_node:
node = graph.get_node(node_name)
# Remind: Only 1 input for Identity node
input_node = graph.get_node(node.inputs[0])
# remove identity node from graph
idx = input_node.outputs.index(node_name)
del input_node.outputs[idx]
output_names = node.outputs
for output_name in output_names:
output_node = graph.get_node(output_name)
idx = output_node.inputs.index(node_name)
output_node.inputs[idx] = input_node.layer_name
idx = graph.topo_sort.index(node_name)
del graph.topo_sort[idx]
def run(self, graph):
self.remove_isolated_node(graph)
self.remove_identity_node(graph)
# TODO identity node remove
# TODO subgraph optimize
# TODO compute optimize
# activation merge
# biasadd merge
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册