提交 c5f36434 作者: SunAhong1993

fix the bn optimizer

上级 eb48eac9
...@@ -25,26 +25,19 @@ class BatchNorm2dFuser(FuseBase): ...@@ -25,26 +25,19 @@ class BatchNorm2dFuser(FuseBase):
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。 """ 描述需要替换的batchnorm2d图结构。
batchnorm2d层模式python实现代码示例: batchnorm2d层模式python实现代码示例:
x2209 = 1 x2214 = fluid.layers.shape(x2207)
x2212 = 'Exception'
x2213 = 4
x2214 = x2207.shape
x2214 = len(x2214) x2214 = len(x2214)
x2215 = x2214 != x2213 x2215 = x2214 != x2213
if x2215 : if x2215 :
raise RaiseException(x2212) raise RaiseException(x2212)
x2218 = False
if x2218 : if x2218 :
x2220 = self.x2220 x2220 = self.x2220
x2221 = x2220 + x2209 x2221 = x2220 + x2209
self.x2220 = x2221 self.x2220 = x2221
x2227 = False x2227 = False
if x2227 : if x2227 :
x2230 = x2207.shape x2230 = fluid.layers.shape(x2207)
x2231 = 'Exception' x2231 = 'Exception'
x2233 = 0
x2234 = 2
x2235 = 1
x2236 = x2230[x2233] x2236 = x2230[x2233]
x2237 = len(x2230) x2237 = len(x2230)
x2238 = x2237 - x2234 x2238 = x2237 - x2234
...@@ -63,43 +56,45 @@ class BatchNorm2dFuser(FuseBase): ...@@ -63,43 +56,45 @@ class BatchNorm2dFuser(FuseBase):
def gen_name(id): def gen_name(id):
return "x" + str(id) return "x" + str(id)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(0)], value=1)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(1)], value=0.1)
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(2)], value=0.001)
# self.pattern.add_layer(
# "prim.constant",
# inputs={},
# outputs=[gen_name(3)],
# value="Exception")
# self.pattern.add_layer(
# "prim.constant", inputs={}, outputs=[gen_name(4)], value=4)
self.pattern.add_layer( self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(0)], value=1) "fluid.layers.shape",
self.pattern.add_layer( inputs={'input': "bn-input-0"},
"prim.constant", inputs={}, outputs=[gen_name(1)], value=0.1)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(2)], value=0.001)
self.pattern.add_layer(
"prim.constant",
inputs={},
outputs=[gen_name(3)],
value="Exception")
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(4)], value=4)
self.pattern.add_layer(
"prim.shape", inputs={'input': "bn-input-0"},
outputs=[gen_name(5)]) outputs=[gen_name(5)])
self.pattern.add_layer( self.pattern.add_layer(
"prim.len", inputs={'input': gen_name(5)}, outputs=[gen_name(5)]) "prim.len", inputs={'input': gen_name(5)}, outputs=[gen_name(5)])
self.pattern.add_layer( self.pattern.add_layer(
"prim.ne", "prim.ne",
inputs={"x": gen_name(5), inputs={"x": gen_name(5),
"y": gen_name(4)}, "y": "bn-input-9"},
outputs=[gen_name(6)]) outputs=[gen_name(6)])
self.pattern.add_layer("prim.if", {'input': gen_name(6)}, [gen_name(7)]) self.pattern.add_layer("prim.if", {'input': gen_name(6)}, [gen_name(7)])
if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer1, graph_type="dygraph") pattern_block0 = PaddleGraph(if_layer1, graph_type="dygraph")
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.exception", "prim.exception",
inputs={"input": gen_name(3)}, inputs={"input": "bn-input-1"},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
if_layer1.inputs["input-0"] = gen_name(3) if_layer1.inputs["input-0"] = "bn-input-1"
if_layer1.add_block(pattern_block0) if_layer1.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer1, graph_type="dygraph") pattern_block1 = PaddleGraph(if_layer1, graph_type="dygraph")
if_layer1.add_block(pattern_block1) if_layer1.add_block(pattern_block1)
self.pattern.add_layer( # self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(9)], value=False) # "prim.constant", inputs={}, outputs=[gen_name(9)], value=False)
self.pattern.add_layer("prim.if", {'input': gen_name(9)}, self.pattern.add_layer("prim.if", {'input': "bn-input-2"},
[gen_name(10)]) [gen_name(10)])
if_layer2 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] if_layer2 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer2, graph_type="dygraph") pattern_block0 = PaddleGraph(if_layer2, graph_type="dygraph")
...@@ -111,52 +106,52 @@ class BatchNorm2dFuser(FuseBase): ...@@ -111,52 +106,52 @@ class BatchNorm2dFuser(FuseBase):
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.add", "prim.add",
inputs={"x": gen_name(11), inputs={"x": gen_name(11),
"y": gen_name(0)}, "y": "bn-input-3"},
outputs=[gen_name(12)]) outputs=[gen_name(12)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.set_attr", "prim.set_attr",
inputs={"input": gen_name(12)}, inputs={"input": gen_name(12)},
outputs=["self." + gen_name(11)]) outputs=["self." + gen_name(11)])
if_layer2.inputs["input-0"] = gen_name(0) if_layer2.inputs["input-0"] = "bn-input-3"
if_layer2.add_block(pattern_block0) if_layer2.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer2, graph_type="dygraph") pattern_block1 = PaddleGraph(if_layer2, graph_type="dygraph")
if_layer2.add_block(pattern_block1) if_layer2.add_block(pattern_block1)
self.pattern.add_layer( # self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(13)], value=True) # "prim.constant", inputs={}, outputs=[gen_name(13)], value=True)
self.pattern.add_layer( # self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(14)], value=False) # "prim.constant", inputs={}, outputs=[gen_name(14)], value=False)
self.pattern.add_layer("prim.if", {'input': gen_name(14)}, self.pattern.add_layer("prim.if", {'input': "bn-input-4"},
[gen_name(15)]) [gen_name(15)])
if_layer3 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] if_layer3 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer3, graph_type="dygraph") pattern_block0 = PaddleGraph(if_layer3, graph_type="dygraph")
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.shape", "fluid.layers.shape",
inputs={'input': "bn-input-0"}, inputs={'input': "bn-input-0"},
outputs=[gen_name(16)]) outputs=[gen_name(16)])
pattern_block0.add_layer( # pattern_block0.add_layer(
"prim.constant", # "prim.constant",
inputs={}, # inputs={},
outputs=[gen_name(17)], # outputs=[gen_name(17)],
value="Exception") # value="Exception")
pattern_block0.add_layer( # pattern_block0.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(18)], value=True) # "prim.constant", inputs={}, outputs=[gen_name(18)], value=True)
pattern_block0.add_layer( # pattern_block0.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(19)], value=0) # "prim.constant", inputs={}, outputs=[gen_name(19)], value=0)
pattern_block0.add_layer( # pattern_block0.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(20)], value=2) # "prim.constant", inputs={}, outputs=[gen_name(20)], value=2)
pattern_block0.add_layer( # pattern_block0.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(21)], value=1) # "prim.constant", inputs={}, outputs=[gen_name(21)], value=1)
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.getitem", "prim.getitem",
inputs={"list": gen_name(16), inputs={"list": gen_name(16),
"index": gen_name(19)}, "index": "bn-input-6"},
outputs=[gen_name(22)]) outputs=[gen_name(22)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.len", inputs={"input": gen_name(16)}, outputs=[gen_name(23)]) "prim.len", inputs={"input": gen_name(16)}, outputs=[gen_name(23)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.sub", "prim.sub",
inputs={"x": gen_name(23), inputs={"x": gen_name(23),
"y": gen_name(20)}, "y": "bn-input-7"},
outputs=[gen_name(24)]) outputs=[gen_name(24)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.equal", "prim.equal",
...@@ -172,7 +167,7 @@ class BatchNorm2dFuser(FuseBase): ...@@ -172,7 +167,7 @@ class BatchNorm2dFuser(FuseBase):
pattern_block0_block0.add_layer( pattern_block0_block0.add_layer(
"prim.add", "prim.add",
inputs={"x": gen_name(27), inputs={"x": gen_name(27),
"y": gen_name(20)}, "y": "bn-input-7"},
outputs=[gen_name(28)]) outputs=[gen_name(28)])
pattern_block0_block0.add_layer( pattern_block0_block0.add_layer(
"prim.getitem", "prim.getitem",
...@@ -188,14 +183,14 @@ class BatchNorm2dFuser(FuseBase): ...@@ -188,14 +183,14 @@ class BatchNorm2dFuser(FuseBase):
"prim.equal", "prim.equal",
inputs={"input": gen_name(30)}, inputs={"input": gen_name(30)},
outputs=[gen_name(26)]) outputs=[gen_name(26)])
loop_layer.inputs["input-1"] = gen_name(20) loop_layer.inputs["input-1"] = "bn-input-7"
loop_layer.inputs["input-2"] = gen_name(16) loop_layer.inputs["input-2"] = gen_name(16)
loop_layer.inputs["input-3"] = gen_name(25) loop_layer.inputs["input-3"] = gen_name(25)
loop_layer.add_block(pattern_block0_block0) loop_layer.add_block(pattern_block0_block0)
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.eq", "prim.eq",
inputs={"x": gen_name(26), inputs={"x": gen_name(26),
"y": gen_name(21)}, "y": "bn-input-8"},
outputs=[gen_name(31)]) outputs=[gen_name(31)])
pattern_block0.add_layer( pattern_block0.add_layer(
"prim.if", inputs={"input": gen_name(31)}, outputs=[gen_name(32)]) "prim.if", inputs={"input": gen_name(31)}, outputs=[gen_name(32)])
...@@ -204,16 +199,21 @@ class BatchNorm2dFuser(FuseBase): ...@@ -204,16 +199,21 @@ class BatchNorm2dFuser(FuseBase):
pattern_block0_block0 = PaddleGraph(if_layer31, graph_type="dygraph") pattern_block0_block0 = PaddleGraph(if_layer31, graph_type="dygraph")
pattern_block0_block0.add_layer( pattern_block0_block0.add_layer(
"prim.exception", "prim.exception",
inputs={"input": gen_name(17)}, inputs={"input": "bn-input-5"},
outputs=[gen_name(33)]) outputs=[gen_name(33)])
if_layer31.inputs["input-0"] = gen_name(17) if_layer31.inputs["input-0"] = "bn-input-5"
if_layer31.add_block(pattern_block0_block0) if_layer31.add_block(pattern_block0_block0)
pattern_block0_block1 = PaddleGraph(if_layer31, graph_type="dygraph") pattern_block0_block1 = PaddleGraph(if_layer31, graph_type="dygraph")
if_layer31.add_block(pattern_block0_block1) if_layer31.add_block(pattern_block0_block1)
if_layer3.inputs["input-0"] = "bn-input-0"
if_layer3.add_block(pattern_block0) if_layer3.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer3, graph_type="dygraph") pattern_block1 = PaddleGraph(if_layer3, graph_type="dygraph")
if_layer3.add_block(pattern_block1) if_layer3.add_block(pattern_block1)
if_layer3.inputs["input-0"] = "bn-input-5"
if_layer3.inputs["input-1"] = "bn-input-6"
if_layer3.inputs["input-2"] = "bn-input-7"
if_layer3.inputs["input-3"] = "bn-input-7"
if_layer3.inputs["input-4"] = "bn-input-8"
if_layer3.inputs["input-5"] = "bn-input-0"
self.pattern.add_layer( self.pattern.add_layer(
"paddle.nn.BatchNorm", "paddle.nn.BatchNorm",
inputs={"input": "bn-input-0"}, inputs={"input": "bn-input-0"},
...@@ -222,7 +222,18 @@ class BatchNorm2dFuser(FuseBase): ...@@ -222,7 +222,18 @@ class BatchNorm2dFuser(FuseBase):
num_channels=160, num_channels=160,
momentum=0.1, momentum=0.1,
epsilon=0.001) epsilon=0.001)
self.pattern.build(inputs={"input-0": "bn-input-0"}) self.pattern.build(inputs={
"input-0": "bn-input-0",
"input-1": "bn-input-1",
"input-2": "bn-input-2",
"input-3": "bn-input-3",
"input-4": "bn-input-4",
"input-5": "bn-input-5",
"input-6": "bn-input-6",
"input-7": "bn-input-7",
"input-8": "bn-input-8",
"input-9": "bn-input-9"
})
def insert_new_layer(self, graph, parameters, matches): def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches) new_layer = self.gen_new_layer(parameters, matches)
......
...@@ -87,7 +87,7 @@ class PatternMatcher(object): ...@@ -87,7 +87,7 @@ class PatternMatcher(object):
if not set(pattern_layer.outputs).issubset( if not set(pattern_layer.outputs).issubset(
pattern.outputs): pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的 # 若pattern当前layer的输出是pattern的输出,则是正确的
print("5--") # print("5--")
return False return False
# 当为控制流时的处理 # 当为控制流时的处理
if layer.kernel == "prim.if" or layer.kernel == "prim.loop": if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册