Commit 0c09c458 authored by J jiangjiajun

fix some bugs: Graph.get_node now returns None instead of raising for missing nodes, Const inputs are detached from their consumers via a new add_omit_nodes helper, Reshape handles NHWC targets, Softmax gets a default axis, and TFOptimizer gains strip_graph to prune nodes unreachable from the graph outputs

Parent 1b472393
@@ -108,6 +108,7 @@ def tf2paddle(model_path,
     mapper = TFOpMapperNHWC(model)
     optimizer = TFOptimizer(mapper)
     optimizer.delete_redundance_code()
+    optimizer.strip_graph()
     mapper.save_inference_model(save_dir)
......
@@ -85,7 +85,7 @@ class Graph(object):
                 node.index = int(idx)
                 return node
             else:
-                raise Exception("Graph doesn't have node [%s]." % name)
+                return None
         else:
             if copy:
                 node = cp.copy(self.node_map[name])
......
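With this change get_node returns None for a name that is not in the graph instead of raising, so every call site must now guard explicitly, as the hunks below do. A minimal sketch of the new contract (graph and name are placeholders):

    node = graph.get_node(name)   # no longer raises for unknown names
    if node is None:
        pass                      # skip the node, or propagate None upward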
@@ -148,6 +148,8 @@ class OpMapper(object):
         for i in range(len(self.graph.topo_sort)):
             node_name = self.graph.topo_sort[i]
             node = self.graph.get_node(node_name)
+            if node is None:
+                continue
             if len(node.fluid_code.layers) == 0:
                 continue
             self.add_codes(node.fluid_code.gen_codes(), 1)
......
@@ -129,6 +129,8 @@ class TFGraph(Graph):
             items[0] = self.identity_map[items[0]]
         new_node_name = ":".join(items)
         node = super(TFGraph, self).get_node(new_node_name, copy)
+        if node is None:
+            return None
         if len(items) == 1 and node.layer_type in self.multi_out_ops:
             node.index = 0
         return node
......
@@ -113,6 +113,15 @@ class TFOpMapper(OpMapper):
                 print("========== {} ==========".format(op))
                 sys.exit(-1)
 
+    def add_omit_nodes(self, in_node_name, out_node_name):
+        in_node = self.graph.get_node(in_node_name)
+        out_node = self.graph.get_node(out_node_name)
+        index = in_node.outputs.index(out_node_name)
+        del in_node.outputs[index]
+        index = out_node.inputs.index(in_node_name)
+        del out_node.inputs[index]
+        self.omit_nodes.append(in_node.layer_name)
+
     def directly_map(self, node):
         assert node.layer_type in self.directly_map_ops
         op_info = self.directly_map_ops[node.layer_type]
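The new helper does two things a bare omit_nodes.append did not: it records the Const input for omission and unlinks the edge between that input and its consumer, so the len(node.outputs) <= omit_freq test in delete_redundance_code sees the true remaining fan-out. A toy sketch of the bookkeeping, using a stand-in Node class (not X2Paddle's):

    # Node is a hypothetical stand-in keeping only the inputs/outputs lists.
    class Node(object):
        def __init__(self, name):
            self.layer_name = name
            self.inputs = []
            self.outputs = []

    kernel, conv = Node("kernel"), Node("conv")
    kernel.outputs.append("conv")
    conv.inputs.append("kernel")

    omit_nodes = []
    kernel.outputs.remove("conv")     # same effect as the index/del pair above
    conv.inputs.remove("kernel")
    omit_nodes.append(kernel.layer_name)
    assert kernel.outputs == [] and conv.inputs == []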
@@ -363,10 +372,7 @@ class TFOpMapper(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
         assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
-        self.omit_nodes.append(kernel.layer_name)
-        node.fluid_code.add_note("#{} : {}".format(node.layer.name,
-                                                   node.layer_name))
+        self.add_omit_nodes(kernel.layer_name, node.layer_name)
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
@@ -444,11 +450,10 @@ class TFOpMapper(OpMapper):
         assert beta.layer_type == "Const"
         assert moving_mean.layer_type == "Const"
         assert moving_var.layer_type == "Const"
-        self.omit_nodes.append(gamma.layer_name)
-        self.omit_nodes.append(beta.layer_name)
-        self.omit_nodes.append(moving_mean.layer_name)
-        self.omit_nodes.append(moving_var.layer_name)
+        self.add_omit_nodes(gamma.layer_name, node.layer_name)
+        self.add_omit_nodes(beta.layer_name, node.layer_name)
+        self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
+        self.add_omit_nodes(moving_var.layer_name, node.layer_name)
         if channel_first:
             self.data_format_propagation(node)
@@ -470,10 +475,7 @@ class TFOpMapper(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
         assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
-        self.omit_nodes.append(kernel.layer_name)
-        node.fluid_code.add_note("#{} : {}".format(node.layer.name,
-                                                   node.layer_name))
+        self.add_omit_nodes(kernel.layer_name, node.layer_name)
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
@@ -533,13 +535,17 @@ class TFOpMapper(OpMapper):
         param = self.graph.get_node(node.layer.input[1], copy=True)
         if param.layer_type == "Const":
             attr = {"shape": param.value.tolist()}
-            self.omit_nodes.append(param.layer_name)
+            self.add_omit_nodes(param.layer_name, node.layer_name)
         else:
             # Here is a trick method to solve tensor parameter in tensorflow
             shape = self.decoder.infer_shape_tensor(param, node.out_shapes[0])
             if shape.count(-1) <= 1:
                 attr = {"shape": shape}
-                self.omit_nodes.append(param.layer_name)
+                self.add_omit_nodes(param.layer_name, node.layer_name)
+            elif shape.count(-1) == 2 and shape[0] == -1:
+                shape[0] = 0
+                attr = {"shape": shape}
+                self.add_omit_nodes(param.layer_name, node.layer_name)
             else:
                 assert len(param.out_shapes[0]
                            ) == 1, "Unexpected situation of shape parameter"
@@ -559,7 +565,25 @@ class TFOpMapper(OpMapper):
             new_param = new_param.strip(", ") + "]"
             attr = {"shape": new_param}
         if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC":
-            attr["shape"] = [attr["shape"][i] for i in [0, 3, 1, 2]]
+            input_shape = self.decoder.infer_tensor(input).shape
+            if input_shape[1] == attr["shape"][1]:
+                attr["shape"] = [attr["shape"][i] for i in [0, 3, 1, 2]]
+            else:
+                perm = {"perm": [0, 2, 3, 1]}
+                node.fluid_code.add_layer("transpose",
+                                          inputs=input,
+                                          output=node,
+                                          param_attr=perm)
+                node.fluid_code.add_layer("reshape",
+                                          inputs=node,
+                                          output=node,
+                                          param_attr=attr)
+                perm = {"perm": [0, 3, 1, 2]}
+                node.fluid_code.add_layer("transpose",
+                                          inputs=node,
+                                          output=node,
+                                          param_attr=perm)
+                return
         node.fluid_code.add_layer("reshape",
                                   inputs=input,
                                   output=node,
@@ -607,8 +631,8 @@ class TFOpMapper(OpMapper):
         dim = self.graph.get_node(node.layer.input[2], copy=True)
         assert num_sections.layer_type == "Const"
         assert dim.layer_type == "Const"
-        self.omit_nodes.append(num_sections.layer_name)
-        self.omit_nodes.append(dim.layer_name)
+        self.add_omit_nodes(num_sections.layer_name, node.layer_name)
+        self.add_omit_nodes(dim.layer_name, node.layer_name)
         dim = dim.value
         if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
             dim = nhwc_dim_to_nchw(input, dim)
@@ -628,7 +652,7 @@ class TFOpMapper(OpMapper):
         ]
         axis = self.graph.get_node(node.layer.input[-1], copy=True)
         assert axis.layer_type == "Const"
-        self.omit_nodes.append(axis.layer_name)
+        self.add_omit_nodes(axis.layer_name, node.layer_name)
         axis = axis.value
         if inputs[0].tf_data_format == "NHWC" and len(
                 inputs[0].out_shapes[0]) == 4:
def Tile(self, node): def Tile(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
expand_times = self.graph.get_node(node.layer.input[1], copy=True) expand_times = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(expand_times.layer_name) self.add_omit_nodes(expand_times.layer_name, node.layer_name)
if expand_times.layer_type == "Const": if expand_times.layer_type == "Const":
expand_times = expand_times.value.tolist() expand_times = expand_times.value.tolist()
else: else:
...@@ -687,7 +711,7 @@ class TFOpMapper(OpMapper): ...@@ -687,7 +711,7 @@ class TFOpMapper(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
paddings = self.graph.get_node(node.layer.input[1], copy=True) paddings = self.graph.get_node(node.layer.input[1], copy=True)
assert paddings.layer_type == "Const", "Padding should be Const" assert paddings.layer_type == "Const", "Padding should be Const"
self.omit_nodes.append(paddings.layer_name) self.add_omit_nodes(paddings.layer_name, node.layer_name)
paddings = paddings.value.flatten().tolist() paddings = paddings.value.flatten().tolist()
if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4: if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]] paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]]
@@ -719,9 +743,9 @@ class TFOpMapper(OpMapper):
             delta = delta.value
         else:
             delta = self.decoder.infer_tensor(delta)
-        self.omit_nodes.append(start.layer_name)
-        self.omit_nodes.append(limit.layer_name)
-        limit = self.decoder.infer_tensor(limit)
+        self.add_omit_nodes(start.layer_name, node.layer_name)
+        self.add_omit_nodes(limit.layer_name, node.layer_name)
+        self.add_omit_nodes(delta.layer_name, node.layer_name)
         inputs = {"start": start, "end": limit, "step": delta}
         attr = {"dtype": string(node.dtype)}
@@ -773,7 +797,7 @@ class TFOpMapper(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         axis = self.graph.get_node(node.layer.input[1], copy=True)
         assert axis.layer_type == "Const", "ArgMax only support Const parameter"
-        self.omit_nodes.append(axis.layer_name)
+        self.add_omit_nodes(axis.layer_name, node.layer_name)
         axis = axis.value
         if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
             axis = nhwc_dim_to_nchw(input, axis)
@@ -791,9 +815,9 @@ class TFOpMapper(OpMapper):
         assert begin.layer_type == "Const"
         assert end.layer_type == "Const"
         assert strides.layer_type == "Const"
-        self.omit_nodes.append(begin.layer_name)
-        self.omit_nodes.append(end.layer_name)
-        self.omit_nodes.append(strides.layer_name)
+        self.add_omit_nodes(begin.layer_name, node.layer_name)
+        self.add_omit_nodes(end.layer_name, node.layer_name)
+        self.add_omit_nodes(strides.layer_name, node.layer_name)
         strides = strides.value.tolist()
         assert len(set(strides)) == 1 and strides[0] == 1
@@ -821,10 +845,8 @@ class TFOpMapper(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         begin = self.graph.get_node(node.layer.input[1], copy=True)
         size = self.graph.get_node(node.layer.input[2], copy=True)
-        # assert begin.layer_type == "Const"
-        # assert size.layer_type == "Const"
-        self.omit_nodes.append(begin.layer_name)
-        self.omit_nodes.append(size.layer_name)
+        self.add_omit_nodes(begin.layer_name, node.layer_name)
+        self.add_omit_nodes(size.layer_name, node.layer_name)
         if begin.layer_type == "Const":
             begin = begin.value.tolist()
         else:
@@ -848,11 +870,7 @@ class TFOpMapper(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
         assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
-        self.omit_nodes.append(kernel.layer_name)
-        node.fluid_code.add_note("#{} : {}".format(node.layer.name,
-                                                   node.layer_name))
+        self.add_omit_nodes(kernel.layer_name, node.layer_name)
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
             in_shape = self.decoder.infer_tensor(input).shape
@@ -958,8 +976,7 @@ class TFOpMapper(OpMapper):
     def Split(self, node):
         dim = self.graph.get_node(node.layer.input[0], copy=True)
         input = self.graph.get_node(node.layer.input[1], copy=True)
-        assert dim.layer_type == "Const"
-        self.omit_nodes.append(dim.layer_name)
+        self.add_omit_nodes(dim.layer_name, node.layer_name)
         num_split = node.get_attr('num_split')
         dim = dim.value
         if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
@@ -986,6 +1003,8 @@ class TFOpMapper(OpMapper):
     def Softmax(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         axis = node.get_attr("axis")
+        if axis is None:
+            axis = -1 + len(input.out_shapes[0])
         if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
             axis = nhwc_dim_to_nchw(input, axis)
         attr = {"axis": axis}
@@ -997,7 +1016,7 @@ class TFOpMapper(OpMapper):
     def ResizeNearestNeighbor(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
-        self.omit_nodes.append(resize_shape.layer_name)
+        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
         if resize_shape.layer_type == "Const":
             resize_shape = resize_shape.value.tolist()
         else:
@@ -1012,7 +1031,7 @@ class TFOpMapper(OpMapper):
     def ResizeBilinear(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
-        self.omit_nodes.append(resize_shape.layer_name)
+        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
         if resize_shape.layer_type == "Const":
             resize_shape = resize_shape.value.tolist()
         else:
@@ -1031,7 +1050,7 @@ class TFOpMapper(OpMapper):
     def ResizeNearestNeighbor(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
-        self.omit_nodes.append(resize_shape.layer_name)
+        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
         if resize_shape.layer_type == "Const":
             resize_shape = resize_shape.value.tolist()
         else:
@@ -1047,7 +1066,7 @@ class TFOpMapper(OpMapper):
     def ResizeBilinear(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
-        self.omit_nodes.append(resize_shape.layer_name)
+        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
         if resize_shape.layer_type == "Const":
             resize_shape = resize_shape.value.tolist()
         else:
......
@@ -95,6 +95,15 @@ class TFOpMapperNHWC(OpMapper):
                 print("========== {} ============".format(op))
                 sys.exit(-1)
 
+    def add_omit_nodes(self, in_node_name, out_node_name):
+        in_node = self.graph.get_node(in_node_name)
+        out_node = self.graph.get_node(out_node_name)
+        index = in_node.outputs.index(out_node_name)
+        del in_node.outputs[index]
+        index = out_node.inputs.index(in_node_name)
+        del out_node.inputs[index]
+        self.omit_nodes.append(in_node.layer_name)
+
     def directly_map(self, node):
         assert node.layer_type in self.directly_map_ops
         op_info = self.directly_map_ops[node.layer_type]
@@ -289,10 +298,7 @@ class TFOpMapperNHWC(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
         assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
-        self.omit_nodes.append(kernel.layer_name)
-        node.fluid_code.add_note("#{} : {}".format(node.layer.name,
-                                                   node.layer_name))
+        self.add_omit_nodes(kernel.layer_name, node.layer_name)
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
@@ -376,10 +382,10 @@ class TFOpMapperNHWC(OpMapper):
         assert beta.layer_type == "Const"
         assert moving_mean.layer_type == "Const"
         assert moving_var.layer_type == "Const"
-        self.omit_nodes.append(gamma.layer_name)
-        self.omit_nodes.append(beta.layer_name)
-        self.omit_nodes.append(moving_mean.layer_name)
-        self.omit_nodes.append(moving_var.layer_name)
+        self.add_omit_nodes(gamma.layer_name, node.layer_name)
+        self.add_omit_nodes(beta.layer_name, node.layer_name)
+        self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
+        self.add_omit_nodes(moving_var.layer_name, node.layer_name)
         if not channel_first:
             attr = {"perm": [0, 3, 1, 2]}
@@ -414,7 +420,7 @@ class TFOpMapperNHWC(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
         assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
-        self.omit_nodes.append(kernel.layer_name)
+        self.add_omit_nodes(kernel.layer_name, node.layer_name)
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
@@ -485,13 +491,13 @@ class TFOpMapperNHWC(OpMapper):
         param = self.graph.get_node(node.layer.input[1], copy=True)
         if param.layer_type == "Const":
             attr = {"shape": param.value.tolist()}
-            self.omit_nodes.append(param.layer_name)
+            self.add_omit_nodes(param.layer_name, node.layer_name)
         else:
             # Here is a trick method to solve tensor parameter in tensorflow
             shape = self.decoder.infer_shape_tensor(param, node.out_shapes[0])
             if shape.count(-1) <= 1:
                 attr = {"shape": shape}
-                self.omit_nodes.append(param.layer_name)
+                self.add_omit_nodes(param.layer_name, node.layer_name)
             else:
                 assert len(param.out_shapes[0]
                            ) == 1, "Unexpected situation of shape parameter"
@@ -568,8 +574,8 @@ class TFOpMapperNHWC(OpMapper):
         dim = self.graph.get_node(node.layer.input[2], copy=True)
         assert num_sections.layer_type == "Const"
         assert dim.layer_type == "Const"
-        self.omit_nodes.append(num_sections.layer_name)
-        self.omit_nodes.append(dim.layer_name)
+        self.add_omit_nodes(num_sections.layer_name, node.layer_name)
+        self.add_omit_nodes(dim.layer_name, node.layer_name)
         dim = dim.value
         attr = {
             "num_or_sections": num_sections.value.tolist(),
@@ -587,7 +593,7 @@ class TFOpMapperNHWC(OpMapper):
         ]
         axis = self.graph.get_node(node.layer.input[-1], copy=True)
         assert axis.layer_type == "Const"
-        self.omit_nodes.append(axis.layer_name)
+        self.add_omit_nodes(axis.layer_name, node.layer_name)
         axis = axis.value
         if axis < 0:
             axis += len(inputs[0].out_shapes[0])
@@ -601,7 +607,7 @@ class TFOpMapperNHWC(OpMapper):
     def Tile(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         expand_times = self.graph.get_node(node.layer.input[1], copy=True)
-        self.omit_nodes.append(expand_times.layer_name)
+        self.add_omit_nodes(expand_times.layer_name, node.layer_name)
         if expand_times.layer_type == "Const":
             expand_times = expand_times.value.tolist()
         else:
@@ -630,7 +636,7 @@ class TFOpMapperNHWC(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         paddings = self.graph.get_node(node.layer.input[1], copy=True)
         assert paddings.layer_type == "Const", "Padding should be Const"
-        self.omit_nodes.append(paddings.layer_name)
+        self.add_omit_nodes(paddings.layer_name, node.layer_name)
         paddings = paddings.value.flatten().tolist()
         data_format = input.tf_data_format
@@ -674,9 +680,9 @@ class TFOpMapperNHWC(OpMapper):
         start = self.graph.get_node(node.layer.input[0], copy=True)
         limit = self.graph.get_node(node.layer.input[1], copy=True)
         delta = self.graph.get_node(node.layer.input[2], copy=True)
-        self.omit_nodes.append(start.layer_name)
-        self.omit_nodes.append(limit.layer_name)
-        self.omit_nodes.append(delta.layer_name)
+        self.add_omit_nodes(start.layer_name, node.layer_name)
+        self.add_omit_nodes(limit.layer_name, node.layer_name)
+        self.add_omit_nodes(delta.layer_name, node.layer_name)
         if start.layer_type == "Const":
             start = start.value
         else:
@@ -741,7 +747,7 @@ class TFOpMapperNHWC(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         axis = self.graph.get_node(node.layer.input[1], copy=True)
         assert axis.layer_type == "Const", "ArgMax only support Const parameter"
-        self.omit_nodes.append(axis.layer_name)
+        self.add_omit_nodes(axis.layer_name, node.layer_name)
         axis = axis.value
         attr = {"axis": axis}
         node.fluid_code.add_layer("argmax",
@@ -757,9 +763,9 @@ class TFOpMapperNHWC(OpMapper):
         assert begin.layer_type == "Const"
         assert end.layer_type == "Const"
         assert strides.layer_type == "Const"
-        self.omit_nodes.append(begin.layer_name)
-        self.omit_nodes.append(end.layer_name)
-        self.omit_nodes.append(strides.layer_name)
+        self.add_omit_nodes(begin.layer_name, node.layer_name)
+        self.add_omit_nodes(end.layer_name, node.layer_name)
+        self.add_omit_nodes(strides.layer_name, node.layer_name)
         strides = strides.value.tolist()
         assert len(set(strides)) == 1 and strides[
             0] == 1, "Only support strides be 1 in StridedSlice OP"
@@ -840,8 +846,8 @@ class TFOpMapperNHWC(OpMapper):
         size = self.graph.get_node(node.layer.input[2], copy=True)
         # assert begin.layer_type == "Const"
         # assert size.layer_type == "Const"
-        self.omit_nodes.append(begin.layer_name)
-        self.omit_nodes.append(size.layer_name)
+        self.add_omit_nodes(begin.layer_name, node.layer_name)
+        self.add_omit_nodes(size.layer_name, node.layer_name)
         if begin.layer_type == "Const":
             begin = begin.value.tolist()
         else:
@@ -861,7 +867,7 @@ class TFOpMapperNHWC(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
         assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
-        self.omit_nodes.append(kernel.layer_name)
+        self.add_omit_nodes(kernel.layer_name, node.layer_name)
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
@@ -973,7 +979,7 @@ class TFOpMapperNHWC(OpMapper):
         dim = self.graph.get_node(node.layer.input[0], copy=True)
         input = self.graph.get_node(node.layer.input[1], copy=True)
         assert dim.layer_type == "Const"
-        self.omit_nodes.append(dim.layer_name)
+        self.add_omit_nodes(dim.layer_name, node.layer_name)
         num_split = node.get_attr('num_split')
         dim = dim.value
@@ -1004,7 +1010,7 @@ class TFOpMapperNHWC(OpMapper):
     def ResizeNearestNeighbor(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
-        self.omit_nodes.append(resize_shape.layer_name)
+        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
         if resize_shape.layer_type == "Const":
             resize_shape = resize_shape.value.tolist()
         else:
@@ -1030,7 +1036,7 @@ class TFOpMapperNHWC(OpMapper):
     def ResizeBilinear(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
-        self.omit_nodes.append(resize_shape.layer_name)
+        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
         if resize_shape.layer_type == "Const":
             resize_shape = resize_shape.value.tolist()
         else:
......
@@ -37,18 +37,82 @@ class TFOptimizer(object):
         self.graph = op_mapper.graph
 
     def delete_redundance_code(self):
+        # print("==========omit_nodes============")
+        # for node_name in set(self.op_mapper.omit_nodes):
+        #     node = self.graph.get_node(node_name)
+        #     print(node.layer_name, self.op_mapper.omit_nodes.count(node.layer_name), len(node.outputs), node.outputs)
+        # print("================================")
         for node_name in self.graph.topo_sort:
             if node_name in self.op_mapper.omit_nodes:
                 node = self.graph.get_node(node_name)
+                if node is None:
+                    continue
                 omit_freq = self.op_mapper.omit_nodes.count(node_name)
                 if len(node.outputs) <= omit_freq:
                     node.fluid_code.clear()
+
+                    # remove node from graph
+                    input_names = node.inputs
+                    output_names = node.outputs
+                    for in_name in input_names:
+                        in_node = self.graph.get_node(in_name)
+                        index = in_node.outputs.index(node_name)
+                        del in_node.outputs[index]
+                    for out_name in output_names:
+                        out_node = self.graph.get_node(out_name)
+                        index = out_node.inputs.index(node_name)
+                        del out_node.inputs[index]
+                    del self.graph.node_map[node_name]
+
+    def strip_graph(self):
+        # print("=============")
+        # for i, node_name in enumerate(self.graph.topo_sort):
+        #     node = self.graph.get_node(node_name)
+        #     if node is None:
+        #         continue
+        #     print(node.layer_name, node.inputs)
+        # print("================")
+        visited_nodes = set()
+
+        def visit(node_name):
+            if node_name in visited_nodes:
+                return
+            visited_nodes.add(node_name)
+            input_names = self.graph.get_node(node_name).inputs
+            for in_name in input_names:
+                visit(in_name)
+
+        for node_name in self.graph.output_nodes:
+            visit(node_name)
+
+        # print("=============visited nodes++++++++++++")
+        # for name in visited_nodes:
+        #     print(name)
+        # print("===================================")
+        for i, node_name in enumerate(self.graph.topo_sort):
+            if node_name not in visited_nodes:
+                node = self.graph.get_node(node_name)
+                if node is None:
+                    continue
+                input_names = node.inputs
+                output_names = node.outputs
+                for in_name in input_names:
+                    in_node = self.graph.get_node(in_name)
+                    index = in_node.outputs.index(node_name)
+                    del in_node.outputs[index]
+                for out_name in output_names:
+                    out_node = self.graph.get_node(out_name)
+                    index = out_node.inputs.index(node_name)
+                    del out_node.inputs[index]
+                del self.graph.node_map[node_name]
+
     # TODO activation merge
     def merge_activation(self):
         act_nodes = list()
         for node_name in self.graph.topo_sort:
             node = self.graph.get_node(node_name)
+            if node is None:
+                continue
             if node.layer_type in self.activation_ops:
                 act_nodes.append(node_name)
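strip_graph is a reachability pass: a depth-first visit walks backwards from the graph's output nodes through their inputs, and any node never visited is unlinked from its neighbours and dropped from node_map. A standalone sketch of the same idea over a plain adjacency dict (an illustrative structure, not X2Paddle's Graph class):

    # inputs_of maps node name -> list of input names.
    def strip(inputs_of, outputs):
        visited = set()

        def visit(name):
            if name in visited:
                return
            visited.add(name)
            for in_name in inputs_of.get(name, []):
                visit(in_name)

        for name in outputs:
            visit(name)
        # Keep only nodes reachable from the outputs.
        return {n: ins for n, ins in inputs_of.items() if n in visited}

    graph = {"a": [], "b": ["a"], "dead": ["a"]}
    assert strip(graph, ["b"]) == {"a": [], "b": ["a"]}

The recursive visit mirrors the diff; for very deep graphs an explicit stack would sidestep Python's recursion limit.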
@@ -75,6 +139,8 @@ class TFOptimizer(object):
     def merge_bias(self):
         for node_name in self.graph.topo_sort:
             node = self.graph.get_node(node_name)
+            if node is None:
+                continue
             if node.layer_type == "BiasAdd":
                 input = self.graph.get_node(node.inputs[0])
                 if input.layer_type not in self.layers_with_bias:
@@ -105,3 +171,27 @@ class TFOptimizer(object):
                         'act'] = node.fluid_code.layers[-1].param_attr[
                             'act']
                     node.fluid_code.clear()
+
+    # def remove_transpose(self):
+    #     optimize_ops = ['Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative', 'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor', 'ResizeBilinear']
+    #     for node_name in self.graph.topo_sort:
+    #         node = self.graph.get_node(node_name)
+    #         if node.layer_type not in optimize_ops:
+    #             continue
+    #         if node.fluid_code.layers[-1].op != "transpose" or node.fluid_code.layers[-1].param_attr["perm"] != [0, 2, 3, 1]:
+    #             continue
+    #         output_names = node.outputs
+    #         can_be_removed = True
+    #         for out_name in output_names:
+    #             out_node = self.graph.get_node(out_name)
+    #             if out_node.fluid_code.layers[0].op != "transpose" or out_node.fluid_code.layers[-1].param_attr["perm"] != [0, 3, 1, 2]:
+    #                 can_be_removed = False
+    #                 break
+    #         if can_be_removed and len(output_names) > 0:
+    #             last_out = node.fluid_code.layers[-1].inputs
+    #             del node.fluid_code.layers[-1]
+    #             for out_name in output_names:
+    #                 out_node = self.graph.get_node(out_name)
+    #                 del out_node.fluid_code.layers[0]
+    #                 out_node.fluid_code.layers[0].inputs = last_out
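The commented-out remove_transpose draft relies on a transpose([0, 2, 3, 1]) emitted at the end of one of the listed ops cancelling against a following transpose([0, 3, 1, 2]) in every consumer. A quick NumPy check of that premise, for any 4-D tensor:

    import numpy as np

    x = np.random.rand(1, 2, 3, 4)
    # NHWC -> NCHW permutation composed with its inverse is the identity.
    assert np.array_equal(x.transpose(0, 2, 3, 1).transpose(0, 3, 1, 2), x)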