提交 d1516a29 编写于 作者: J jiangjiajun

add merge prelu

上级 9ecd568e
......@@ -107,6 +107,7 @@ def tf2paddle(model_path,
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.merge_batch_norm()
optimizer.merge_prelu()
else:
mapper = TFOpMapperNHWC(model)
optimizer = TFOptimizer(mapper)
......
......@@ -16,6 +16,7 @@
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.core.fluid_code import Layer
from x2paddle.core.util import *
import numpy
import copy as cp
......@@ -418,6 +419,40 @@ class TFOptimizer(object):
is_batch_norm = False
continue
if len(in_nodes0[0].outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes0[1].outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes1[1].outputs) != 2:
is_batch_norm = False
continue
if len(in_nodes2[0].outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes2[1].outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes3[0].outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes3[1].outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes4[0].outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes5.outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes6[0].outputs) != 1:
is_batch_norm = False
continue
if len(in_nodes6[1].outputs) != 1:
is_batch_norm = False
continue
conv_shape = in_nodes1[0].out_shapes[0]
if conv_shape[3] < 0:
is_batch_norm = False
......@@ -466,8 +501,9 @@ class TFOptimizer(object):
"act": act
}
node.fluid_code.add_layer("batch_norm",
inputs=cp.copy(in_nodes1[0]),
node.fluid_code.add_layer(
"batch_norm",
inputs=in_nodes1[0].fluid_code.layers[-1].output,
output=node,
param_attr=attr)
......@@ -478,3 +514,149 @@ class TFOptimizer(object):
del self.graph.node_map[in_nodes3[0].layer_name]
del self.graph.node_map[in_nodes4[0].layer_name]
del self.graph.node_map[in_nodes5.layer_name]
def merge_prelu(self):
    """Fuse TensorFlow's expanded PReLU subgraph into one Paddle ``prelu`` layer.

    TensorFlow commonly lowers PReLU as::

        out = Relu(x) + alpha * 0.5 * (x - Abs(x))

    This pass walks the graph in topological order, pattern-matches that
    subgraph rooted at each "Add" node, and on a match rewrites the "Add"
    node into a "Prelu" layer fed directly by ``x``, deleting the fused
    intermediate nodes from ``self.graph.node_map``.
    """
    for i, name in enumerate(self.graph.topo_sort):
        node = self.graph.get_node(name)
        if node is None:
            continue
        is_prelu = True
        if node.layer_type == "Add":
            # Pattern root: Add(Relu(x), Mul(...)). Each fused intermediate
            # must have exactly one consumer, otherwise it is shared and
            # cannot be deleted.
            in_nodes0 = [
                self.graph.get_node(in_name) for in_name in node.inputs
            ]
            if in_nodes0[0].layer_type != "Relu" or in_nodes0[
                    1].layer_type != "Mul":
                is_prelu = False
                continue
            if len(in_nodes0[0].outputs) != 1 or len(
                    in_nodes0[1].outputs) != 1:
                is_prelu = False
                continue

            # ``in_nodes1`` is x, the common input of Relu/Sub/Abs.
            in_nodes1 = self.graph.get_node(in_nodes0[0].inputs[0])
            # Outer Mul must be Mul(inner_mul, 0.5) with a scalar 0.5 const.
            in_nodes2 = [
                self.graph.get_node(in_name)
                for in_name in in_nodes0[1].inputs
            ]
            if in_nodes2[1].layer_type != "Const" or numpy.fabs(
                    in_nodes2[1].value - 0.5) > 1e-06:
                is_prelu = False
                continue
            if in_nodes2[0].layer_type != "Mul":
                is_prelu = False
                continue
            if len(in_nodes2[1].outputs) != 1 or len(
                    in_nodes2[0].outputs) != 1:
                is_prelu = False
                continue

            # Inner Mul must be Mul(alpha_const, Sub(x, Abs(x))).
            in_nodes3 = [
                self.graph.get_node(in_name)
                for in_name in in_nodes2[0].inputs
            ]
            if in_nodes3[0].layer_type != "Const" or in_nodes3[
                    1].layer_type != "Sub":
                is_prelu = False
                continue
            if len(in_nodes3[0].outputs) != 1 or len(
                    in_nodes3[1].outputs) != 1:
                is_prelu = False
                continue

            # Sub inputs must be (x, Abs(x)), with the same x the Relu reads.
            in_nodes4 = [
                self.graph.get_node(in_name)
                for in_name in in_nodes3[1].inputs
            ]
            if in_nodes4[0].layer_name != in_nodes1.layer_name or in_nodes4[
                    1].layer_type != "Abs":
                is_prelu = False
                continue
            if len(in_nodes4[1].outputs) != 1:
                is_prelu = False
                continue

            in_nodes5 = self.graph.get_node(in_nodes4[1].inputs[0])
            if in_nodes5.layer_name != in_nodes1.layer_name:
                is_prelu = False
                continue

            # x feeds Relu, Sub and Abs, so it needs at least 3 consumers.
            # (The original code re-checked several fan-outs here a second
            # time, three of them with a `false` typo that would raise
            # NameError; the duplicates are redundant with the checks above
            # and have been removed.)
            if len(in_nodes1.outputs) < 3:
                is_prelu = False
                continue

            # Choose the Paddle prelu mode from the alpha constant's shape.
            mode = None
            in_shape = in_nodes1.out_shapes[0]
            if in_shape == list(in_nodes3[0].value.shape):
                mode = "element"
            elif len(in_nodes3[0].value.shape) == 0:
                mode = "all"
            elif len(in_nodes3[0].value.shape
                     ) == 1 and in_nodes3[0].value.shape[0] == 1:
                mode = "all"
            elif len(in_shape) == 4 and len(
                    in_nodes3[0].value.shape
            ) == 1 and in_nodes3[0].value.shape[0] == in_shape[-1]:
                mode = "channel"
                # Reshape per-channel alpha [C] to NCHW-broadcastable
                # [1, C, 1, 1] and update the const layer's declared shape.
                weight = self.op_mapper.weights[in_nodes3[0].layer_name]
                weight = numpy.expand_dims(weight, 0)
                weight = numpy.expand_dims(weight, 2)
                weight = numpy.expand_dims(weight, 3)
                self.op_mapper.weights[in_nodes3[0].layer_name] = weight
                in_nodes3[0].fluid_code.layers[0].param_attr["shape"] = [
                    1, in_shape[-1], 1, 1
                ]
            else:
                is_prelu = False
                continue

            if is_prelu:
                # Detach the three fused consumers (Relu, Sub, Abs) from
                # x's output list; x now feeds the Prelu node directly.
                index = in_nodes1.outputs.index(in_nodes0[0].layer_name)
                del in_nodes1.outputs[index]
                index = in_nodes1.outputs.index(in_nodes3[1].layer_name)
                del in_nodes1.outputs[index]
                index = in_nodes1.outputs.index(in_nodes4[1].layer_name)
                del in_nodes1.outputs[index]

                node.layer_type = "Prelu"
                node.inputs = [in_nodes1.layer_name]
                # NOTE(review): `act` is read but never forwarded into the
                # prelu attrs (same as the original code) — confirm whether
                # the fused Add's activation should be preserved.
                act = node.fluid_code.layers[-1].param_attr.get("act", None)
                node.fluid_code.clear()
                attr = {
                    "mode": string(mode),
                    "param_attr": string(in_nodes3[0].layer_name)
                }
                node.fluid_code.add_layer(
                    "prelu",
                    inputs=in_nodes1.fluid_code.layers[-1].output,
                    output=node,
                    param_attr=attr)
                # Drop the fused intermediates from the graph.
                del self.graph.node_map[in_nodes0[0].layer_name]
                del self.graph.node_map[in_nodes0[1].layer_name]
                del self.graph.node_map[in_nodes2[0].layer_name]
                del self.graph.node_map[in_nodes2[1].layer_name]
                del self.graph.node_map[in_nodes3[1].layer_name]
                del self.graph.node_map[in_nodes4[1].layer_name]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册