提交 c5f36434 编写于 作者: S SunAhong1993

Fix the BatchNorm2d fuser in the optimizer: replace prim.shape/constant pattern nodes with fluid.layers.shape and externalized bn-input-* pattern inputs, and silence a stray debug print in PatternMatcher

上级 eb48eac9
......@@ -25,26 +25,19 @@ class BatchNorm2dFuser(FuseBase):
def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。
batchnorm2d层模式python实现代码示例:
x2209 = 1
x2212 = 'Exception'
x2213 = 4
x2214 = x2207.shape
x2214 = fluid.layers.shape(x2207)
x2214 = len(x2214)
x2215 = x2214 != x2213
if x2215 :
raise RaiseException(x2212)
x2218 = False
if x2218 :
x2220 = self.x2220
x2221 = x2220 + x2209
self.x2220 = x2221
x2227 = False
if x2227 :
x2230 = x2207.shape
x2230 = fluid.layers.shape(x2207.shape)
x2231 = 'Exception'
x2233 = 0
x2234 = 2
x2235 = 1
x2236 = x2230[x2233]
x2237 = len(x2230)
x2238 = x2237 - x2234
......@@ -63,43 +56,45 @@ class BatchNorm2dFuser(FuseBase):
def gen_name(id):
return "x" + str(id)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(0)], value=1)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(1)], value=0.1)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(2)], value=0.001)
# self.pattern.add_layer(
# "prim.constant",
# inputs={},
# outputs=[gen_name(3)],
# value="Exception")
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(4)], value=4)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(0)], value=1)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(1)], value=0.1)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(2)], value=0.001)
self.pattern.add_layer(
"prim.constant",
inputs={},
outputs=[gen_name(3)],
value="Exception")
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(4)], value=4)
self.pattern.add_layer(
"prim.shape", inputs={'input': "bn-input-0"},
"fluid.layers.shape",
inputs={'input': "bn-input-0"},
outputs=[gen_name(5)])
self.pattern.add_layer(
"prim.len", inputs={'input': gen_name(5)}, outputs=[gen_name(5)])
self.pattern.add_layer(
"prim.ne",
inputs={"x": gen_name(5),
"y": gen_name(4)},
"y": "bn-input-9"},
outputs=[gen_name(6)])
self.pattern.add_layer("prim.if", {'input': gen_name(6)}, [gen_name(7)])
if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer1, graph_type="dygraph")
pattern_block0.add_layer(
"prim.exception",
inputs={"input": gen_name(3)},
inputs={"input": "bn-input-1"},
outputs=[gen_name(8)])
if_layer1.inputs["input-0"] = gen_name(3)
if_layer1.inputs["input-0"] = "bn-input-1"
if_layer1.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer1, graph_type="dygraph")
if_layer1.add_block(pattern_block1)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(9)], value=False)
self.pattern.add_layer("prim.if", {'input': gen_name(9)},
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(9)], value=False)
self.pattern.add_layer("prim.if", {'input': "bn-input-2"},
[gen_name(10)])
if_layer2 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer2, graph_type="dygraph")
......@@ -111,52 +106,52 @@ class BatchNorm2dFuser(FuseBase):
pattern_block0.add_layer(
"prim.add",
inputs={"x": gen_name(11),
"y": gen_name(0)},
"y": "bn-input-3"},
outputs=[gen_name(12)])
pattern_block0.add_layer(
"prim.set_attr",
inputs={"input": gen_name(12)},
outputs=["self." + gen_name(11)])
if_layer2.inputs["input-0"] = gen_name(0)
if_layer2.inputs["input-0"] = "bn-input-3"
if_layer2.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer2, graph_type="dygraph")
if_layer2.add_block(pattern_block1)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(13)], value=True)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(14)], value=False)
self.pattern.add_layer("prim.if", {'input': gen_name(14)},
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(13)], value=True)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(14)], value=False)
self.pattern.add_layer("prim.if", {'input': "bn-input-4"},
[gen_name(15)])
if_layer3 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer3, graph_type="dygraph")
pattern_block0.add_layer(
"prim.shape",
"fluid.layers.shape",
inputs={'input': "bn-input-0"},
outputs=[gen_name(16)])
pattern_block0.add_layer(
"prim.constant",
inputs={},
outputs=[gen_name(17)],
value="Exception")
pattern_block0.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(18)], value=True)
pattern_block0.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(19)], value=0)
pattern_block0.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(20)], value=2)
pattern_block0.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(21)], value=1)
# pattern_block0.add_layer(
# "prim.constant",
# inputs={},
# outputs=[gen_name(17)],
# value="Exception")
# pattern_block0.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(18)], value=True)
# pattern_block0.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(19)], value=0)
# pattern_block0.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(20)], value=2)
# pattern_block0.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(21)], value=1)
pattern_block0.add_layer(
"prim.getitem",
inputs={"list": gen_name(16),
"index": gen_name(19)},
"index": "bn-input-6"},
outputs=[gen_name(22)])
pattern_block0.add_layer(
"prim.len", inputs={"input": gen_name(16)}, outputs=[gen_name(23)])
pattern_block0.add_layer(
"prim.sub",
inputs={"x": gen_name(23),
"y": gen_name(20)},
"y": "bn-input-7"},
outputs=[gen_name(24)])
pattern_block0.add_layer(
"prim.equal",
......@@ -172,7 +167,7 @@ class BatchNorm2dFuser(FuseBase):
pattern_block0_block0.add_layer(
"prim.add",
inputs={"x": gen_name(27),
"y": gen_name(20)},
"y": "bn-input-7"},
outputs=[gen_name(28)])
pattern_block0_block0.add_layer(
"prim.getitem",
......@@ -188,14 +183,14 @@ class BatchNorm2dFuser(FuseBase):
"prim.equal",
inputs={"input": gen_name(30)},
outputs=[gen_name(26)])
loop_layer.inputs["input-1"] = gen_name(20)
loop_layer.inputs["input-1"] = "bn-input-7"
loop_layer.inputs["input-2"] = gen_name(16)
loop_layer.inputs["input-3"] = gen_name(25)
loop_layer.add_block(pattern_block0_block0)
pattern_block0.add_layer(
"prim.eq",
inputs={"x": gen_name(26),
"y": gen_name(21)},
"y": "bn-input-8"},
outputs=[gen_name(31)])
pattern_block0.add_layer(
"prim.if", inputs={"input": gen_name(31)}, outputs=[gen_name(32)])
......@@ -204,16 +199,21 @@ class BatchNorm2dFuser(FuseBase):
pattern_block0_block0 = PaddleGraph(if_layer31, graph_type="dygraph")
pattern_block0_block0.add_layer(
"prim.exception",
inputs={"input": gen_name(17)},
inputs={"input": "bn-input-5"},
outputs=[gen_name(33)])
if_layer31.inputs["input-0"] = gen_name(17)
if_layer31.inputs["input-0"] = "bn-input-5"
if_layer31.add_block(pattern_block0_block0)
pattern_block0_block1 = PaddleGraph(if_layer31, graph_type="dygraph")
if_layer31.add_block(pattern_block0_block1)
if_layer3.inputs["input-0"] = "bn-input-0"
if_layer3.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer3, graph_type="dygraph")
if_layer3.add_block(pattern_block1)
if_layer3.inputs["input-0"] = "bn-input-5"
if_layer3.inputs["input-1"] = "bn-input-6"
if_layer3.inputs["input-2"] = "bn-input-7"
if_layer3.inputs["input-3"] = "bn-input-7"
if_layer3.inputs["input-4"] = "bn-input-8"
if_layer3.inputs["input-5"] = "bn-input-0"
self.pattern.add_layer(
"paddle.nn.BatchNorm",
inputs={"input": "bn-input-0"},
......@@ -222,7 +222,18 @@ class BatchNorm2dFuser(FuseBase):
num_channels=160,
momentum=0.1,
epsilon=0.001)
self.pattern.build(inputs={"input-0": "bn-input-0"})
self.pattern.build(inputs={
"input-0": "bn-input-0",
"input-1": "bn-input-1",
"input-2": "bn-input-2",
"input-3": "bn-input-3",
"input-4": "bn-input-4",
"input-5": "bn-input-5",
"input-6": "bn-input-6",
"input-7": "bn-input-7",
"input-8": "bn-input-8",
"input-9": "bn-input-9"
})
def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches)
......
......@@ -87,7 +87,7 @@ class PatternMatcher(object):
if not set(pattern_layer.outputs).issubset(
pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的
print("5--")
# print("5--")
return False
# 当为控制流时的处理
if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册