Commit 0c09c458 authored by: J jiangjiajun

fix some bugs: make Graph.get_node return None for missing nodes (and guard all callers), replace bare omit_nodes.append with a new add_omit_nodes helper that also detaches graph edges, add TFOptimizer.strip_graph to prune nodes unreachable from the outputs, handle NHWC Reshape with mismatched channel dims, and default the Softmax axis to the last dimension

Parent 1b472393
......@@ -108,6 +108,7 @@ def tf2paddle(model_path,
mapper = TFOpMapperNHWC(model)
optimizer = TFOptimizer(mapper)
optimizer.delete_redundance_code()
optimizer.strip_graph()
mapper.save_inference_model(save_dir)
......
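With `strip_graph()` added after `delete_redundance_code()`, the pipeline now prunes every node the graph outputs no longer reach before `save_inference_model` runs. A minimal sketch of the reachability idea behind `strip_graph`, assuming a toy graph stored as a dict of node name to input names (not the x2paddle API):

```python
# Hedged sketch: keep only nodes reachable backwards from the outputs,
# mirroring the DFS in TFOptimizer.strip_graph further down this diff.
def strip_unreachable(node_inputs, output_nodes):
    visited = set()

    def visit(name):
        if name in visited:
            return
        visited.add(name)
        for in_name in node_inputs.get(name, []):
            visit(in_name)

    for name in output_nodes:
        visit(name)
    return {name: ins for name, ins in node_inputs.items() if name in visited}

graph = {"x": [], "w": [], "conv": ["x", "w"], "dead": ["w"]}
print(strip_unreachable(graph, ["conv"]))  # 'dead' is dropped
```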
......@@ -85,7 +85,7 @@ class Graph(object):
node.index = int(idx)
return node
else:
raise Exception("Graph doesn't have node [%s]." % name)
return None
else:
if copy:
node = cp.copy(self.node_map[name])
......
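Note the behavior change here: `Graph.get_node` now reports a missing node by returning None instead of raising, which is what lets the optimizer delete entries from `node_map` safely; every caller touched by this commit adds a None guard. A toy illustration of the new contract, with made-up stand-in classes rather than the real x2paddle ones:

```python
class Graph:
    def __init__(self, node_map):
        self.node_map = node_map

    def get_node(self, name):
        # previously: raise Exception("Graph doesn't have node [%s]." % name)
        return self.node_map.get(name)  # now: None for deleted/unknown nodes

g = Graph({"conv": "conv-node"})
for name in ["conv", "removed_by_optimizer"]:
    node = g.get_node(name)
    if node is None:  # the guard pattern added throughout this commit
        continue
    print(name, "->", node)
```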
......@@ -148,6 +148,8 @@ class OpMapper(object):
for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i]
node = self.graph.get_node(node_name)
if node is None:
continue
if len(node.fluid_code.layers) == 0:
continue
self.add_codes(node.fluid_code.gen_codes(), 1)
......
......@@ -129,6 +129,8 @@ class TFGraph(Graph):
items[0] = self.identity_map[items[0]]
new_node_name = ":".join(items)
node = super(TFGraph, self).get_node(new_node_name, copy)
if node is None:
return None
if len(items) == 1 and node.layer_type in self.multi_out_ops:
node.index = 0
return node
......
......@@ -113,6 +113,15 @@ class TFOpMapper(OpMapper):
print("========== {} ==========".format(op))
sys.exit(-1)
def add_omit_nodes(self, in_node_name, out_node_name):
in_node = self.graph.get_node(in_node_name)
out_node = self.graph.get_node(out_node_name)
index = in_node.outputs.index(out_node_name)
del in_node.outputs[index]
index = out_node.inputs.index(in_node_name)
del out_node.inputs[index]
self.omit_nodes.append(in_node.layer_name)
def directly_map(self, node):
assert node.layer_type in self.directly_map_ops
op_info = self.directly_map_ops[node.layer_type]
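The new `add_omit_nodes` helper does more than the old `omit_nodes.append`: it also detaches the producer/consumer edge, so later passes see accurate `inputs`/`outputs` fan-out. A hedged sketch with toy node objects (the real method uses `index()` plus `del`, which is equivalent to `remove` here):

```python
class Node:
    def __init__(self, name):
        self.layer_name = name
        self.inputs, self.outputs = [], []

nodes = {n: Node(n) for n in ("kernel", "conv")}
nodes["kernel"].outputs.append("conv")
nodes["conv"].inputs.append("kernel")
omit_nodes = []

def add_omit_nodes(in_name, out_name):
    nodes[in_name].outputs.remove(out_name)  # drop producer -> consumer edge
    nodes[out_name].inputs.remove(in_name)   # drop consumer -> producer edge
    omit_nodes.append(nodes[in_name].layer_name)

add_omit_nodes("kernel", "conv")
print(omit_nodes, nodes["kernel"].outputs)   # ['kernel'] []
```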
......@@ -363,10 +372,7 @@ class TFOpMapper(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
self.omit_nodes.append(kernel.layer_name)
node.fluid_code.add_note("#{} : {}".format(node.layer.name,
node.layer_name))
self.add_omit_nodes(kernel.layer_name, node.layer_name)
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
......@@ -444,11 +450,10 @@ class TFOpMapper(OpMapper):
assert beta.layer_type == "Const"
assert moving_mean.layer_type == "Const"
assert moving_var.layer_type == "Const"
self.omit_nodes.append(gamma.layer_name)
self.omit_nodes.append(beta.layer_name)
self.omit_nodes.append(moving_mean.layer_name)
self.omit_nodes.append(moving_var.layer_name)
self.add_omit_nodes(gamma.layer_name, node.layer_name)
self.add_omit_nodes(beta.layer_name, node.layer_name)
self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
self.add_omit_nodes(moving_var.layer_name, node.layer_name)
if channel_first:
self.data_format_propagation(node)
......@@ -470,10 +475,7 @@ class TFOpMapper(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
self.omit_nodes.append(kernel.layer_name)
node.fluid_code.add_note("#{} : {}".format(node.layer.name,
node.layer_name))
self.add_omit_nodes(kernel.layer_name, node.layer_name)
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
......@@ -533,13 +535,17 @@ class TFOpMapper(OpMapper):
param = self.graph.get_node(node.layer.input[1], copy=True)
if param.layer_type == "Const":
attr = {"shape": param.value.tolist()}
self.omit_nodes.append(param.layer_name)
self.add_omit_nodes(param.layer_name, node.layer_name)
else:
# Here is a trick to handle a tensor-valued shape parameter in TensorFlow
shape = self.decoder.infer_shape_tensor(param, node.out_shapes[0])
if shape.count(-1) <= 1:
attr = {"shape": shape}
self.omit_nodes.append(param.layer_name)
self.add_omit_nodes(param.layer_name, node.layer_name)
elif shape.count(-1) == 2 and shape[0] == -1:
shape[0] = 0
attr = {"shape": shape}
self.add_omit_nodes(param.layer_name, node.layer_name)
else:
assert len(param.out_shapes[0]
) == 1, "Unexpected situation of shape parameter"
......@@ -559,7 +565,25 @@ class TFOpMapper(OpMapper):
new_param = new_param.strip(", ") + "]"
attr = {"shape": new_param}
if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC":
attr["shape"] = [attr["shape"][i] for i in [0, 3, 1, 2]]
input_shape = self.decoder.infer_tensor(input).shape
if input_shape[1] == attr["shape"][1]:
attr["shape"] = [attr["shape"][i] for i in [0, 3, 1, 2]]
else:
perm = {"perm": [0, 2, 3, 1]}
node.fluid_code.add_layer("transpose",
inputs=input,
output=node,
param_attr=perm)
node.fluid_code.add_layer("reshape",
inputs=node,
output=node,
param_attr=attr)
perm = {"perm": [0, 3, 1, 2]}
node.fluid_code.add_layer("transpose",
inputs=node,
output=node,
param_attr=perm)
return
node.fluid_code.add_layer("reshape",
inputs=input,
output=node,
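This new branch fixes Reshape for NHWC models that the converter keeps internally in NCHW: when the shape check shows that merely permuting the target shape would scramble data, the generated code transposes back to NHWC, reshapes there, and transposes to NCHW again. A numpy illustration of why the old single permuted reshape was wrong (numpy stands in for the emitted transpose/reshape layers; the shapes are made up):

```python
import numpy as np

nhwc = np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3)      # TF-side tensor, NHWC
target = [2, 2, 2, 12]                                   # NHWC target from the TF graph
reference = nhwc.reshape(target).transpose(0, 3, 1, 2)   # what Paddle should hold (NCHW)

nchw = nhwc.transpose(0, 3, 1, 2)                        # converter's internal layout
naive = nchw.reshape([target[i] for i in [0, 3, 1, 2]])  # old path: permute shape only
fixed = nchw.transpose(0, 2, 3, 1).reshape(target).transpose(0, 3, 1, 2)

print(np.array_equal(reference, fixed))   # True
print(np.array_equal(reference, naive))   # False
```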
......@@ -607,8 +631,8 @@ class TFOpMapper(OpMapper):
dim = self.graph.get_node(node.layer.input[2], copy=True)
assert num_sections.layer_type == "Const"
assert dim.layer_type == "Const"
self.omit_nodes.append(num_sections.layer_name)
self.omit_nodes.append(dim.layer_name)
self.add_omit_nodes(num_sections.layer_name, node.layer_name)
self.add_omit_nodes(dim.layer_name, node.layer_name)
dim = dim.value
if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
dim = nhwc_dim_to_nchw(input, dim)
......@@ -628,7 +652,7 @@ class TFOpMapper(OpMapper):
]
axis = self.graph.get_node(node.layer.input[-1], copy=True)
assert axis.layer_type == "Const"
self.omit_nodes.append(axis.layer_name)
self.add_omit_nodes(axis.layer_name, node.layer_name)
axis = axis.value
if inputs[0].tf_data_format == "NHWC" and len(
inputs[0].out_shapes[0]) == 4:
......@@ -642,7 +666,7 @@ class TFOpMapper(OpMapper):
def Tile(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
expand_times = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(expand_times.layer_name)
self.add_omit_nodes(expand_times.layer_name, node.layer_name)
if expand_times.layer_type == "Const":
expand_times = expand_times.value.tolist()
else:
......@@ -687,7 +711,7 @@ class TFOpMapper(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
paddings = self.graph.get_node(node.layer.input[1], copy=True)
assert paddings.layer_type == "Const", "Padding should be Const"
self.omit_nodes.append(paddings.layer_name)
self.add_omit_nodes(paddings.layer_name, node.layer_name)
paddings = paddings.value.flatten().tolist()
if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]]
......@@ -719,9 +743,9 @@ class TFOpMapper(OpMapper):
delta = delta.value
else:
delta = self.decoder.infer_tensor(delta)
self.omit_nodes.append(start.layer_name)
self.omit_nodes.append(limit.layer_name)
limit = self.decoder.infer_tensor(limit)
self.add_omit_nodes(start.layer_name, node.layer_name)
self.add_omit_nodes(limit.layer_name, node.layer_name)
self.add_omit_nodes(delta.layer_name, node.layer_name)
inputs = {"start": start, "end": limit, "step": delta}
attr = {"dtype": string(node.dtype)}
......@@ -773,7 +797,7 @@ class TFOpMapper(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
axis = self.graph.get_node(node.layer.input[1], copy=True)
assert axis.layer_type == "Const", "ArgMax only support Const parameter"
self.omit_nodes.append(axis.layer_name)
self.add_omit_nodes(axis.layer_name, node.layer_name)
axis = axis.value
if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
axis = nhwc_dim_to_nchw(input, axis)
......@@ -791,9 +815,9 @@ class TFOpMapper(OpMapper):
assert begin.layer_type == "Const"
assert end.layer_type == "Const"
assert strides.layer_type == "Const"
self.omit_nodes.append(begin.layer_name)
self.omit_nodes.append(end.layer_name)
self.omit_nodes.append(strides.layer_name)
self.add_omit_nodes(begin.layer_name, node.layer_name)
self.add_omit_nodes(end.layer_name, node.layer_name)
self.add_omit_nodes(strides.layer_name, node.layer_name)
strides = strides.value.tolist()
assert len(set(strides)) == 1 and strides[0] == 1
......@@ -821,10 +845,8 @@ class TFOpMapper(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
begin = self.graph.get_node(node.layer.input[1], copy=True)
size = self.graph.get_node(node.layer.input[2], copy=True)
# assert begin.layer_type == "Const"
# assert size.layer_type == "Const"
self.omit_nodes.append(begin.layer_name)
self.omit_nodes.append(size.layer_name)
self.add_omit_nodes(begin.layer_name, node.layer_name)
self.add_omit_nodes(size.layer_name, node.layer_name)
if begin.layer_type == "Const":
begin = begin.value.tolist()
else:
......@@ -848,11 +870,7 @@ class TFOpMapper(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
self.omit_nodes.append(kernel.layer_name)
node.fluid_code.add_note("#{} : {}".format(node.layer.name,
node.layer_name))
self.add_omit_nodes(kernel.layer_name, node.layer_name)
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
in_shape = self.decoder.infer_tensor(input).shape
......@@ -958,8 +976,7 @@ class TFOpMapper(OpMapper):
def Split(self, node):
dim = self.graph.get_node(node.layer.input[0], copy=True)
input = self.graph.get_node(node.layer.input[1], copy=True)
assert dim.layer_type == "Const"
self.omit_nodes.append(dim.layer_name)
self.add_omit_nodes(dim.layer_name, node.layer_name)
num_split = node.get_attr('num_split')
dim = dim.value
if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
......@@ -986,6 +1003,8 @@ class TFOpMapper(OpMapper):
def Softmax(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
axis = node.get_attr("axis")
if axis is None:
axis = -1 + len(input.out_shapes[0])
if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
axis = nhwc_dim_to_nchw(input, axis)
attr = {"axis": axis}
......@@ -997,7 +1016,7 @@ class TFOpMapper(OpMapper):
def ResizeNearestNeighbor(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(resize_shape.layer_name)
self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
......@@ -1012,7 +1031,7 @@ class TFOpMapper(OpMapper):
def ResizeBilinear(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(resize_shape.layer_name)
self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
......@@ -1031,7 +1050,7 @@ class TFOpMapper(OpMapper):
def ResizeNearestNeighbor(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(resize_shape.layer_name)
self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
......@@ -1047,7 +1066,7 @@ class TFOpMapper(OpMapper):
def ResizeBilinear(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(resize_shape.layer_name)
self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
......
......@@ -95,6 +95,15 @@ class TFOpMapperNHWC(OpMapper):
print("========== {} ============".format(op))
sys.exit(-1)
def add_omit_nodes(self, in_node_name, out_node_name):
in_node = self.graph.get_node(in_node_name)
out_node = self.graph.get_node(out_node_name)
index = in_node.outputs.index(out_node_name)
del in_node.outputs[index]
index = out_node.inputs.index(in_node_name)
del out_node.inputs[index]
self.omit_nodes.append(in_node.layer_name)
def directly_map(self, node):
assert node.layer_type in self.directly_map_ops
op_info = self.directly_map_ops[node.layer_type]
......@@ -289,10 +298,7 @@ class TFOpMapperNHWC(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
self.omit_nodes.append(kernel.layer_name)
node.fluid_code.add_note("#{} : {}".format(node.layer.name,
node.layer_name))
self.add_omit_nodes(kernel.layer_name, node.layer_name)
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
......@@ -376,10 +382,10 @@ class TFOpMapperNHWC(OpMapper):
assert beta.layer_type == "Const"
assert moving_mean.layer_type == "Const"
assert moving_var.layer_type == "Const"
self.omit_nodes.append(gamma.layer_name)
self.omit_nodes.append(beta.layer_name)
self.omit_nodes.append(moving_mean.layer_name)
self.omit_nodes.append(moving_var.layer_name)
self.add_omit_nodes(gamma.layer_name, node.layer_name)
self.add_omit_nodes(beta.layer_name, node.layer_name)
self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
self.add_omit_nodes(moving_var.layer_name, node.layer_name)
if not channel_first:
attr = {"perm": [0, 3, 1, 2]}
......@@ -414,7 +420,7 @@ class TFOpMapperNHWC(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
self.omit_nodes.append(kernel.layer_name)
self.add_omit_nodes(kernel.layer_name, node.layer_name)
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
......@@ -485,13 +491,13 @@ class TFOpMapperNHWC(OpMapper):
param = self.graph.get_node(node.layer.input[1], copy=True)
if param.layer_type == "Const":
attr = {"shape": param.value.tolist()}
self.omit_nodes.append(param.layer_name)
self.add_omit_nodes(param.layer_name, node.layer_name)
else:
# Here is a trick to handle a tensor-valued shape parameter in TensorFlow
shape = self.decoder.infer_shape_tensor(param, node.out_shapes[0])
if shape.count(-1) <= 1:
attr = {"shape": shape}
self.omit_nodes.append(param.layer_name)
self.add_omit_nodes(param.layer_name, node.layer_name)
else:
assert len(param.out_shapes[0]
) == 1, "Unexpected situation of shape parameter"
......@@ -568,8 +574,8 @@ class TFOpMapperNHWC(OpMapper):
dim = self.graph.get_node(node.layer.input[2], copy=True)
assert num_sections.layer_type == "Const"
assert dim.layer_type == "Const"
self.omit_nodes.append(num_sections.layer_name)
self.omit_nodes.append(dim.layer_name)
self.add_omit_nodes(num_sections.layer_name, node.layer_name)
self.add_omit_nodes(dim.layer_name, node.layer_name)
dim = dim.value
attr = {
"num_or_sections": num_sections.value.tolist(),
......@@ -587,7 +593,7 @@ class TFOpMapperNHWC(OpMapper):
]
axis = self.graph.get_node(node.layer.input[-1], copy=True)
assert axis.layer_type == "Const"
self.omit_nodes.append(axis.layer_name)
self.add_omit_nodes(axis.layer_name, node.layer_name)
axis = axis.value
if axis < 0:
axis += len(inputs[0].out_shapes[0])
......@@ -601,7 +607,7 @@ class TFOpMapperNHWC(OpMapper):
def Tile(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
expand_times = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(expand_times.layer_name)
self.add_omit_nodes(expand_times.layer_name, node.layer_name)
if expand_times.layer_type == "Const":
expand_times = expand_times.value.tolist()
else:
......@@ -630,7 +636,7 @@ class TFOpMapperNHWC(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
paddings = self.graph.get_node(node.layer.input[1], copy=True)
assert paddings.layer_type == "Const", "Padding should be Const"
self.omit_nodes.append(paddings.layer_name)
self.add_omit_nodes(paddings.layer_name, node.layer_name)
paddings = paddings.value.flatten().tolist()
data_format = input.tf_data_format
......@@ -674,9 +680,9 @@ class TFOpMapperNHWC(OpMapper):
start = self.graph.get_node(node.layer.input[0], copy=True)
limit = self.graph.get_node(node.layer.input[1], copy=True)
delta = self.graph.get_node(node.layer.input[2], copy=True)
self.omit_nodes.append(start.layer_name)
self.omit_nodes.append(limit.layer_name)
self.omit_nodes.append(delta.layer_name)
self.add_omit_nodes(start.layer_name, node.layer_name)
self.add_omit_nodes(limit.layer_name, node.layer_name)
self.add_omit_nodes(delta.layer_name, node.layer_name)
if start.layer_type == "Const":
start = start.value
else:
......@@ -741,7 +747,7 @@ class TFOpMapperNHWC(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
axis = self.graph.get_node(node.layer.input[1], copy=True)
assert axis.layer_type == "Const", "ArgMax only support Const parameter"
self.omit_nodes.append(axis.layer_name)
self.add_omit_nodes(axis.layer_name, node.layer_name)
axis = axis.value
attr = {"axis": axis}
node.fluid_code.add_layer("argmax",
......@@ -757,9 +763,9 @@ class TFOpMapperNHWC(OpMapper):
assert begin.layer_type == "Const"
assert end.layer_type == "Const"
assert strides.layer_type == "Const"
self.omit_nodes.append(begin.layer_name)
self.omit_nodes.append(end.layer_name)
self.omit_nodes.append(strides.layer_name)
self.add_omit_nodes(begin.layer_name, node.layer_name)
self.add_omit_nodes(end.layer_name, node.layer_name)
self.add_omit_nodes(strides.layer_name, node.layer_name)
strides = strides.value.tolist()
assert len(set(strides)) == 1 and strides[
0] == 1, "Only support strides be 1 in StridedSlice OP"
......@@ -840,8 +846,8 @@ class TFOpMapperNHWC(OpMapper):
size = self.graph.get_node(node.layer.input[2], copy=True)
# assert begin.layer_type == "Const"
# assert size.layer_type == "Const"
self.omit_nodes.append(begin.layer_name)
self.omit_nodes.append(size.layer_name)
self.add_omit_nodes(begin.layer_name, node.layer_name)
self.add_omit_nodes(size.layer_name, node.layer_name)
if begin.layer_type == "Const":
begin = begin.value.tolist()
else:
......@@ -861,7 +867,7 @@ class TFOpMapperNHWC(OpMapper):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
self.omit_nodes.append(kernel.layer_name)
self.add_omit_nodes(kernel.layer_name, node.layer_name)
in_shape = input.out_shapes[0]
if in_shape.count(-1) > 2:
......@@ -973,7 +979,7 @@ class TFOpMapperNHWC(OpMapper):
dim = self.graph.get_node(node.layer.input[0], copy=True)
input = self.graph.get_node(node.layer.input[1], copy=True)
assert dim.layer_type == "Const"
self.omit_nodes.append(dim.layer_name)
self.add_omit_nodes(dim.layer_name, node.layer_name)
num_split = node.get_attr('num_split')
dim = dim.value
......@@ -1004,7 +1010,7 @@ class TFOpMapperNHWC(OpMapper):
def ResizeNearestNeighbor(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(resize_shape.layer_name)
self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
......@@ -1030,7 +1036,7 @@ class TFOpMapperNHWC(OpMapper):
def ResizeBilinear(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
self.omit_nodes.append(resize_shape.layer_name)
self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist()
else:
......
......@@ -37,18 +37,82 @@ class TFOptimizer(object):
self.graph = op_mapper.graph
def delete_redundance_code(self):
# print("==========omit_nodes============")
# for node_name in set(self.op_mapper.omit_nodes):
# node = self.graph.get_node(node_name)
# print(node.layer_name, self.op_mapper.omit_nodes.count(node.layer_name), len(node.outputs), node.outputs)
# print("================================")
for node_name in self.graph.topo_sort:
if node_name in self.op_mapper.omit_nodes:
node = self.graph.get_node(node_name)
if node is None:
continue
omit_freq = self.op_mapper.omit_nodes.count(node_name)
if len(node.outputs) <= omit_freq:
node.fluid_code.clear()
# remove node from graph
input_names = node.inputs
output_names = node.outputs
for in_name in input_names:
in_node = self.graph.get_node(in_name)
index = in_node.outputs.index(node_name)
del in_node.outputs[index]
for out_name in output_names:
out_node = self.graph.get_node(out_name)
index = out_node.inputs.index(node_name)
del out_node.inputs[index]
del self.graph.node_map[node_name]
def strip_graph(self):
# print("=============")
# for i, node_name in enumerate(self.graph.topo_sort):
# node = self.graph.get_node(node_name)
# if node is None:
# continue
# print(node.layer_name, node.inputs)
# print("================")
visited_nodes = set()
def visit(node_name):
if node_name in visited_nodes:
return
visited_nodes.add(node_name)
input_names = self.graph.get_node(node_name).inputs
for in_name in input_names:
visit(in_name)
for node_name in self.graph.output_nodes:
visit(node_name)
# print("=============visited nodes++++++++++++")
# for name in visited_nodes:
# print(name)
# print("===================================")
for i, node_name in enumerate(self.graph.topo_sort):
if node_name not in visited_nodes:
node = self.graph.get_node(node_name)
if node is None:
continue
input_names = node.inputs
output_names = node.outputs
for in_name in input_names:
in_node = self.graph.get_node(in_name)
index = in_node.outputs.index(node_name)
del in_node.outputs[index]
for out_name in output_names:
out_node = self.graph.get_node(out_name)
index = out_node.inputs.index(node_name)
del out_node.inputs[index]
del self.graph.node_map[node_name]
# TODO activation merge
def merge_activation(self):
act_nodes = list()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node is None:
continue
if node.layer_type in self.activation_ops:
act_nodes.append(node_name)
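Both `delete_redundance_code` and `strip_graph` above end with the same unlink-and-delete dance: detach the node from every producer and consumer, then drop it from `node_map`. A minimal sketch of that shared pattern, with toy node objects rather than x2paddle's:

```python
class Node:
    def __init__(self):
        self.inputs, self.outputs = [], []

def remove_node(node_map, name):
    node = node_map[name]
    for in_name in node.inputs:
        node_map[in_name].outputs.remove(name)   # detach from producers
    for out_name in node.outputs:
        node_map[out_name].inputs.remove(name)   # detach from consumers
    del node_map[name]

node_map = {"a": Node(), "b": Node(), "c": Node()}
node_map["a"].outputs.append("b"); node_map["b"].inputs.append("a")
node_map["b"].outputs.append("c"); node_map["c"].inputs.append("b")
remove_node(node_map, "b")
print(node_map["a"].outputs, node_map["c"].inputs)  # [] []
```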
......@@ -75,6 +139,8 @@ class TFOptimizer(object):
def merge_bias(self):
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node is None:
continue
if node.layer_type == "BiasAdd":
input = self.graph.get_node(node.inputs[0])
if input.layer_type not in self.layers_with_bias:
......@@ -105,3 +171,27 @@ class TFOptimizer(object):
'act'] = node.fluid_code.layers[-1].param_attr[
'act']
node.fluid_code.clear()
# def remove_transpose(self):
# optimize_ops = ['Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative', 'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor', 'ResizeBilinear']
# for node_name in self.graph.topo_sort:
# node = self.graph.get_node(node_name)
# if node.layer_type not in optimize_ops:
# continue
# if node.fluid_code.layers[-1].op != "transpose" or node.fluid_code.layers[-1].param_attr["perm"] != [0, 2, 3, 1]:
# continue
# output_names = node.outputs
# can_be_removed = True
# for out_name in output_names:
# out_node = self.graph.get_node(out_name)
# if out_node.fluid_code.layers[0].op != "transpose" or out_node.fluid_code.layers[0].param_attr["perm"] != [0, 3, 1, 2]:
# can_be_removed = False
# break
# if can_be_removed and len(output_names) > 0:
# last_out = node.fluid_code.layers[-1].inputs
# del node.fluid_code.layers[-1]
# for out_name in output_names:
# out_node = self.graph.get_node(out_name)
# del out_node.fluid_code.layers[0]
# out_node.fluid_code.layers[0].inputs = last_out