提交 b03ff6ea 编写于 作者: J jiangjiajun

more op support for tensorflow

上级 8a93b96d
...@@ -25,16 +25,18 @@ import sys ...@@ -25,16 +25,18 @@ import sys
class TFGraphNode(GraphNode): class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None): def __init__(self, layer, layer_name=None):
if layer_name is None: if layer_name is None:
super(TFGraphNode, self).__init__(layer, super(TFGraphNode,
layer.name.replace('/', '_').replace('-', '_')) self).__init__(layer,
layer.name.replace('/', '_').replace('-', '_'))
else: else:
super(TFGraphNode, self).__init__(layer, super(TFGraphNode,
layer_name.replace('/', '_').replace('-', '_')) self).__init__(layer,
layer_name.replace('/', '_').replace('-', '_'))
self.layer_type = layer.op self.layer_type = layer.op
self.fluid_code = FluidCode() self.fluid_code = FluidCode()
self.dtype_map = {1: "float32", 3: "int32", 9: "int64"} self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"}
@property @property
def out_shapes(self): def out_shapes(self):
...@@ -89,7 +91,8 @@ class TFGraph(Graph): ...@@ -89,7 +91,8 @@ class TFGraph(Graph):
def build(self): def build(self):
for layer in self.model.node: for layer in self.model.node:
self.node_map[layer.name.replace('/', '_').replace('-', '_')] = TFGraphNode(layer) self.node_map[layer.name.replace('/', '_').replace(
'-', '_')] = TFGraphNode(layer)
for layer_name, node in self.node_map.items(): for layer_name, node in self.node_map.items():
for in_node in node.layer.input: for in_node in node.layer.input:
...@@ -164,9 +167,11 @@ def check_input_shape(graph_def): ...@@ -164,9 +167,11 @@ def check_input_shape(graph_def):
graph_node = TFGraphNode(layer) graph_node = TFGraphNode(layer)
dtype = graph_node.dtype dtype = graph_node.dtype
if not graph_node.get_attr("shape"): if not graph_node.get_attr("shape"):
sys.stderr.write("\nUnknown shape for input tensor[tensor name: \"{}\"]\n".format( sys.stderr.write(
layer.name)) "\nUnknown shape for input tensor[tensor name: \"{}\"]\n".
shape = input("Please define shape of input here(e.g. None,224,224,3): ") format(layer.name))
shape = input(
"Please define shape of input here(e.g. None,224,224,3): ")
shape = [ shape = [
None if dim == "None" else int(dim) None if dim == "None" else int(dim)
for dim in shape.strip().split(',') for dim in shape.strip().split(',')
...@@ -186,6 +191,7 @@ class TFDecoder(object): ...@@ -186,6 +191,7 @@ class TFDecoder(object):
graph_def = tf.GraphDef() graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
input_map = check_input_shape(graph_def) input_map = check_input_shape(graph_def)
self._fix_output_shape(graph_def)
sess.graph.as_default() sess.graph.as_default()
tf.import_graph_def(graph_def, name='', input_map=input_map) tf.import_graph_def(graph_def, name='', input_map=input_map)
...@@ -193,3 +199,9 @@ class TFDecoder(object): ...@@ -193,3 +199,9 @@ class TFDecoder(object):
self.tf_graph = TFGraph(sess.graph._as_graph_def(add_shapes=True)[0]) self.tf_graph = TFGraph(sess.graph._as_graph_def(add_shapes=True)[0])
self.tf_graph.build() self.tf_graph.build()
def _fix_output_shape(self, graph):
for i in range(len(graph.node)):
node = graph.node[i]
if node.op == "swish_f32":
graph.node[i].attr['_disable_call_shape_inference'].b = False
...@@ -54,7 +54,8 @@ class TFOpMapper(OpMapper): ...@@ -54,7 +54,8 @@ class TFOpMapper(OpMapper):
attr = { attr = {
'dtype': string(dtype), 'dtype': string(dtype),
'shape': shape, 'shape': shape,
'name': string(node.layer_name) 'name': string(node.layer_name),
'append_batch_size': False
} }
node.fluid_code.add_layer("data", node.fluid_code.add_layer("data",
inputs=None, inputs=None,
...@@ -429,7 +430,10 @@ class TFOpMapper(OpMapper): ...@@ -429,7 +430,10 @@ class TFOpMapper(OpMapper):
def Sigmoid(self, node): def Sigmoid(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("sigmoid", inputs=input, output=node, param_attr=None) node.fluid_code.add_layer("sigmoid",
inputs=input,
output=node,
param_attr=None)
def Maximum(self, node): def Maximum(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True) x = self.graph.get_node(node.layer.input[0], copy=True)
...@@ -448,20 +452,35 @@ class TFOpMapper(OpMapper): ...@@ -448,20 +452,35 @@ class TFOpMapper(OpMapper):
assert dim.layer_type == "Const" assert dim.layer_type == "Const"
self.omit_nodes.append(num_sections.layer_name) self.omit_nodes.append(num_sections.layer_name)
self.omit_nodes.append(dim.layer_name) self.omit_nodes.append(dim.layer_name)
attr = {"num_or_sections":num_sections.value.tolist(), "dim":dim.value} attr = {
node.fluid_code.add_layer("split", inputs=input, output=node, param_attr=attr) "num_or_sections": num_sections.value.tolist(),
"dim": dim.value
}
node.fluid_code.add_layer("split",
inputs=input,
output=node,
param_attr=attr)
def Exp(self, node): def Exp(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("exp", inputs=input, output=node, param_attr=None) node.fluid_code.add_layer("exp",
inputs=input,
output=node,
param_attr=None)
def ConcatV2(self, node): def ConcatV2(self, node):
inputs = [self.graph.get_node(name, copy=True) for name in node.layer.input[:-1]] inputs = [
self.graph.get_node(name, copy=True)
for name in node.layer.input[:-1]
]
axis = self.graph.get_node(node.layer.input[-1], copy=True) axis = self.graph.get_node(node.layer.input[-1], copy=True)
assert axis.layer_type == "Const" assert axis.layer_type == "Const"
self.omit_nodes.append(axis.layer_name) self.omit_nodes.append(axis.layer_name)
attr = {"axis": axis.value} attr = {"axis": axis.value}
node.fluid_code.add_layer("concat", inputs=inputs, output=node, param_attr=attr) node.fluid_code.add_layer("concat",
inputs=inputs,
output=node,
param_attr=attr)
def Tile(self, node): def Tile(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
...@@ -469,11 +488,19 @@ class TFOpMapper(OpMapper): ...@@ -469,11 +488,19 @@ class TFOpMapper(OpMapper):
assert expand_times.layer_type == "Const" assert expand_times.layer_type == "Const"
self.omit_nodes.append(expand_times.layer_name) self.omit_nodes.append(expand_times.layer_name)
attr = {"expand_times": expand_times.value.tolist()} attr = {"expand_times": expand_times.value.tolist()}
node.fluid_code.add_layer("expand", inputs=input, output=node, param_attr=attr) node.fluid_code.add_layer("expand",
inputs=input,
output=node,
param_attr=attr)
def Pack(self, node): def Pack(self, node):
inputs = [self.graph.get_node(name, copy=True) for name in node.layer.input[:-1]] inputs = [
node.fluid_code.add_layer("stack", inputs=inputs, output=node, param_attr=None) self.graph.get_node(name, copy=True) for name in node.layer.input
]
node.fluid_code.add_layer("stack",
inputs=inputs,
output=node,
param_attr=None)
def Pad(self, node): def Pad(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
...@@ -481,7 +508,10 @@ class TFOpMapper(OpMapper): ...@@ -481,7 +508,10 @@ class TFOpMapper(OpMapper):
assert paddings.layer_type == "Const", "Padding should be Const" assert paddings.layer_type == "Const", "Padding should be Const"
self.omit_nodes.append(paddings.layer_name) self.omit_nodes.append(paddings.layer_name)
attr = {"paddings": paddings.value.tolist()} attr = {"paddings": paddings.value.tolist()}
node.fluid_code.add_layer("pad", inputs=input, output=node, param_attr=attr) node.fluid_code.add_layer("pad",
inputs=input,
output=node,
param_attr=attr)
# def ResizeNearestNeighbor(self, node): # def ResizeNearestNeighbor(self, node):
# pass # pass
...@@ -499,9 +529,13 @@ class TFOpMapper(OpMapper): ...@@ -499,9 +529,13 @@ class TFOpMapper(OpMapper):
if delta.layer_type == "Const": if delta.layer_type == "Const":
self.omit_nodes.append(delta.layer_name) self.omit_nodes.append(delta.layer_name)
delta = delta.value delta = delta.value
inputs = {"start": start, "end":limit, "step":delta} inputs = {"start": start, "end": limit, "step": delta}
attr = {"dtype": string(node.dtype)} attr = {"dtype": string(node.dtype)}
node.fluid_code.append("range", inputs=inputs, output=node, param_attr=None) node.fluid_code.append("range",
inputs=inputs,
output=node,
param_attr=None)
# def Fill(self, node): # def Fill(self, node):
# shape = self.graph.get_node(node.layer # shape = self.graph.get_node(node.layer
...@@ -524,3 +558,75 @@ class TFOpMapper(OpMapper): ...@@ -524,3 +558,75 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=None) param_attr=None)
def Rsqrt(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("rsqrt",
inputs=input,
output=node,
param_attr=None)
def swish_f32(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("sigmoid",
inputs=input,
output=node,
param_attr=None)
inputs = {"x": input, "y": node}
node.fluid_code.add_layer("elementwise_mul",
inputs=inputs,
output=node,
param_attr=None)
def Mean(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
keep_dims = node.get_attr("keep_dims")
attr = {"dim": reduce_idx.value.tolist(), "keep_dim": keep_dims}
node.fluid_code.add_layer("reduce_mean",
inputs=input,
output=node,
param_attr=attr)
def MatMul(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
transpose_a = node.get_attr('transpose_a')
transpose_b = node.get_attr('transpose_b')
inputs = {"x": x, "y": y}
attr = {"transpose_x": transpose_a, "transpose_y": transpose_b}
node.fluid_code.add_layer("matmul",
inputs=inputs,
output=node,
param_attr=attr)
def ArgMax(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
axis = self.graph.get_node(node.layer.input[1], copy=True)
assert axis.layer_type == "Const", "ArgMax only support Const parameter"
self.omit_nodes.append(axis.layer_name)
attr = {"axis": axis.value}
node.fluid_code.add_layer("argmax",
inputs=input,
output=node,
param_attr=attr)
def StridedSlice(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
begin = self.graph.get_node(node.layer.input[1], copy=True)
end = self.graph.get_node(node.layer.input[2], copy=True)
strides = self.graph.get_node(node.layer.input[3], copy=True)
assert begin.layer_type == "Const"
assert end.layer_type == "Const"
assert strides.layer_type == "Const"
self.omit_nodes.append(begin.layer_name)
self.omit_nodes.append(end.layer_name)
self.omit_nodes.append(strides.layer_name)
strides = strides.value.tolist()
assert len(set(strides)) == 1 and strides[0] == 1
attr = {"starts": begin.value.tolist(), "ends": end.value.tolist()}
node.fluid_code.add_layer("slice",
inputs=input,
output=node,
param_attr=attr)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册