提交 8a93b96d 编写于 作者: J jiangjiajun

more op mapper

上级 8e27ee0e
......@@ -110,7 +110,6 @@ class Graph(object):
del self.node_map[input].inputs[idx]
del self.node_map[node_name]
print("remove topo", node_name)
idx = self.topo_sort.index(node_name)
del self.topo_sort[idx]
......
......@@ -26,10 +26,10 @@ class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None):
if layer_name is None:
super(TFGraphNode, self).__init__(layer,
layer.name.replace('/', '_'))
layer.name.replace('/', '_').replace('-', '_'))
else:
super(TFGraphNode, self).__init__(layer,
layer_name.replace('/', '_'))
layer_name.replace('/', '_').replace('-', '_'))
self.layer_type = layer.op
self.fluid_code = FluidCode()
......@@ -89,11 +89,11 @@ class TFGraph(Graph):
def build(self):
for layer in self.model.node:
self.node_map[layer.name.replace('/', '_')] = TFGraphNode(layer)
self.node_map[layer.name.replace('/', '_').replace('-', '_')] = TFGraphNode(layer)
for layer_name, node in self.node_map.items():
for in_node in node.layer.input:
in_node = in_node.replace('/', '_')
in_node = in_node.replace('/', '_').replace('-', '_')
if in_node not in self.node_map:
if in_node.strip().split(':')[0] in self.node_map:
self.connect(in_node.strip().split(':')[0], layer_name)
......@@ -112,7 +112,7 @@ class TFGraph(Graph):
def get_node(self, node_name, copy=False):
items = node_name.strip().split(':')
items[0] = items[0].replace('/', '_')
items[0] = items[0].replace('/', '_').replace('-', '_')
if items[0] in self.identity_map:
items[0] = self.identity_map[items[0]]
new_node_name = ":".join(items)
......@@ -163,11 +163,10 @@ def check_input_shape(graph_def):
continue
graph_node = TFGraphNode(layer)
dtype = graph_node.dtype
# print("shape:", graph_node.out_shapes)
if not graph_node.get_attr("shape"):
sys.stderr.write("Unknown shape for input tensor[{}]\n".format(
sys.stderr.write("\nUnknown shape for input tensor[tensor name: \"{}\"]\n".format(
layer.name))
shape = input("Please define shape of input here: ")
shape = input("Please define shape of input here(e.g. None,224,224,3): ")
shape = [
None if dim == "None" else int(dim)
for dim in shape.strip().split(',')
......
......@@ -350,6 +350,7 @@ class TFOpMapper(OpMapper):
param = self.graph.get_node(node.layer.input[1], copy=True)
if param.layer_type == "Const":
attr = {"shape": param.value.tolist()}
self.omit_nodes.append(param.layer_name)
else:
# Here is a trick method to solove tensor parameter in tensorflow
assert len(param.out_shapes[0]
......@@ -425,3 +426,101 @@ class TFOpMapper(OpMapper):
inputs=input,
output=node,
param_attr=None)
def Sigmoid(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("sigmoid", inputs=input, output=node, param_attr=None)
def Maximum(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer("elementwise_max",
inputs=inputs,
output=node,
param_attr=None)
def SplitV(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
num_sections = self.graph.get_node(node.layer.input[1], copy=True)
dim = self.graph.get_node(node.layer.input[2], copy=True)
assert num_sections.layer_type == "Const"
assert dim.layer_type == "Const"
self.omit_nodes.append(num_sections.layer_name)
self.omit_nodes.append(dim.layer_name)
attr = {"num_or_sections":num_sections.value.tolist(), "dim":dim.value}
node.fluid_code.add_layer("split", inputs=input, output=node, param_attr=attr)
def Exp(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("exp", inputs=input, output=node, param_attr=None)
def ConcatV2(self, node):
inputs = [self.graph.get_node(name, copy=True) for name in node.layer.input[:-1]]
axis = self.graph.get_node(node.layer.input[-1], copy=True)
assert axis.layer_type == "Const"
self.omit_nodes.append(axis.layer_name)
attr = {"axis": axis.value}
node.fluid_code.add_layer("concat", inputs=inputs, output=node, param_attr=attr)
def Tile(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
expand_times = self.graph.get_node(node.layer.input[1], copy=True)
assert expand_times.layer_type == "Const"
self.omit_nodes.append(expand_times.layer_name)
attr = {"expand_times": expand_times.value.tolist()}
node.fluid_code.add_layer("expand", inputs=input, output=node, param_attr=attr)
def Pack(self, node):
inputs = [self.graph.get_node(name, copy=True) for name in node.layer.input[:-1]]
node.fluid_code.add_layer("stack", inputs=inputs, output=node, param_attr=None)
def Pad(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
paddings = self.graph.get_node(Node.layer.input[1], copy=True)
assert paddings.layer_type == "Const", "Padding should be Const"
self.omit_nodes.append(paddings.layer_name)
attr = {"paddings": paddings.value.tolist()}
node.fluid_code.add_layer("pad", inputs=input, output=node, param_attr=attr)
# def ResizeNearestNeighbor(self, node):
# pass
def Range(self, node):
start = self.graph.get_node(node.layer.input[0], copy=True)
limit = self.graph.get_node(node.layer.input[1], copy=True)
delta = self.graph.get_node(node.layer.input[2], copy=True)
if start.layer_type == "Const":
self.omit_nodes.append(start.layer_name)
start = start.value
if limit.layer_type == "Const":
self.omit_nodes.append(limit.layer_name)
limit = limit.value
if delta.layer_type == "Const":
self.omit_nodes.append(delta.layer_name)
delta = delta.value
inputs = {"start": start, "end":limit, "step":delta}
attr = {"dtype": string(node.dtype)}
node.fluid_code.append("range", inputs=inputs, output=node, param_attr=None)
# def Fill(self, node):
# shape = self.graph.get_node(node.layer
def Mul(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer("elementwise_mul",
inputs=inputs,
output=node,
param_attr=None)
def Sub(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer("elementwise_sub",
inputs=inputs,
output=node,
param_attr=None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册