未验证 提交 853731b8 编写于 作者: J Jason 提交者: GitHub

Merge pull request #294 from mamingjie-China/develop

remove infer when converting TF model
...@@ -236,26 +236,18 @@ class TFOptimizer(object): ...@@ -236,26 +236,18 @@ class TFOptimizer(object):
def remove_transpose(self): def remove_transpose(self):
graph_copy = cp.deepcopy(self.graph) graph_copy = cp.deepcopy(self.graph)
nhwc_insensitive_ops = [
'Relu', 'Relu6', 'Abs', 'Sigmoid', 'Exp', 'Rsqrt', 'swish_f32',
'LeakyRelu', 'Cast', 'Tanh'
]
elementwise_ops = [ elementwise_ops = [
'Sub', 'Add', 'RealDiv', 'Maximum', 'Mul', 'FloorDiv', 'Sub', 'Add', 'RealDiv', 'Maximum', 'Mul', 'FloorDiv',
'GreaterEqual' 'GreateerEqual'
]
optimize_ops = [
'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative',
'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor',
'ResizeBilinear', "Placeholder"
] ]
can_be_optimized_ops = [ can_be_optimized_ops = [
'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative', 'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative',
'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor', 'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor',
'ResizeBilinear', "Placeholder", 'Relu', 'Relu6', 'Abs', 'Sigmoid', 'Placeholder', 'Relu', 'Relu6', 'Abs', 'Sigmoid', 'Exp', 'Rsqrt',
'Exp', 'Rsqrt', 'swish_f32', 'LeakyRelu', 'Cast', 'Tanh' 'swish_f32', 'LeakyRelu', 'Cast', 'Tanh'
] ]
# These ops may have one more Variable input
can_be_optimized_special_ops = ['ResizeBilinear']
for node_name in self.graph.topo_sort: for node_name in self.graph.topo_sort:
node = graph_copy.get_node(node_name) node = graph_copy.get_node(node_name)
if node is None: if node is None:
...@@ -278,9 +270,10 @@ class TFOptimizer(object): ...@@ -278,9 +270,10 @@ class TFOptimizer(object):
0].param_attr["perm"] != [0, 3, 1, 2]: 0].param_attr["perm"] != [0, 3, 1, 2]:
can_be_removed = False can_be_removed = False
break break
elif out_node.layer_type in elementwise_ops: elif out_node.layer_type in elementwise_ops or out_node.layer_type in can_be_optimized_special_ops:
can_be_removed = False can_be_removed = False
break break
if can_be_removed and len(node.fluid_code.layers) > 1: if can_be_removed and len(node.fluid_code.layers) > 1:
true_node = self.graph.get_node(node_name) true_node = self.graph.get_node(node_name)
if true_node.layer_type == "Placeholder": if true_node.layer_type == "Placeholder":
...@@ -298,6 +291,7 @@ class TFOptimizer(object): ...@@ -298,6 +291,7 @@ class TFOptimizer(object):
-2].output = true_node.fluid_code.layers[-1].output -2].output = true_node.fluid_code.layers[-1].output
node.removed = True node.removed = True
del true_node.fluid_code.layers[-1] del true_node.fluid_code.layers[-1]
for out_name in output_names: for out_name in output_names:
out_node = self.graph.get_node(out_name) out_node = self.graph.get_node(out_name)
out_node.fluid_code.layers[ out_node.fluid_code.layers[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册