提交 b434b7db 编写于 作者: M mamingjie-China

remove infer when converting TF model

上级 58e1668e
......@@ -236,26 +236,18 @@ class TFOptimizer(object):
def remove_transpose(self):
graph_copy = cp.deepcopy(self.graph)
nhwc_insensitive_ops = [
'Relu', 'Relu6', 'Abs', 'Sigmoid', 'Exp', 'Rsqrt', 'swish_f32',
'LeakyRelu', 'Cast', 'Tanh'
]
elementwise_ops = [
'Sub', 'Add', 'RealDiv', 'Maximum', 'Mul', 'FloorDiv',
'GreaterEqual'
]
optimize_ops = [
'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative',
'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor',
'ResizeBilinear', "Placeholder"
'GreateerEqual'
]
can_be_optimized_ops = [
'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative',
'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor',
'ResizeBilinear', "Placeholder", 'Relu', 'Relu6', 'Abs', 'Sigmoid',
'Exp', 'Rsqrt', 'swish_f32', 'LeakyRelu', 'Cast', 'Tanh'
'Placeholder', 'Relu', 'Relu6', 'Abs', 'Sigmoid', 'Exp', 'Rsqrt',
'swish_f32', 'LeakyRelu', 'Cast', 'Tanh'
]
# These ops may have one more Variable input
can_be_optimized_special_ops = ['ResizeBilinear']
for node_name in self.graph.topo_sort:
node = graph_copy.get_node(node_name)
if node is None:
......@@ -278,9 +270,10 @@ class TFOptimizer(object):
0].param_attr["perm"] != [0, 3, 1, 2]:
can_be_removed = False
break
elif out_node.layer_type in elementwise_ops:
elif out_node.layer_type in elementwise_ops or out_node.layer_type in can_be_optimized_special_ops:
can_be_removed = False
break
if can_be_removed and len(node.fluid_code.layers) > 1:
true_node = self.graph.get_node(node_name)
if true_node.layer_type == "Placeholder":
......@@ -298,6 +291,7 @@ class TFOptimizer(object):
-2].output = true_node.fluid_code.layers[-1].output
node.removed = True
del true_node.fluid_code.layers[-1]
for out_name in output_names:
out_node = self.graph.get_node(out_name)
out_node.fluid_code.layers[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册