diff --git a/x2paddle/optimizer/fusion/batchnorm2d_fuser.py b/x2paddle/optimizer/fusion/batchnorm2d_fuser.py index f22db98216b0aaad69cb0679eb517e79a9ff9d8d..1a21e260208f9c262fe7ab83d3f3b034949f7c25 100644 --- a/x2paddle/optimizer/fusion/batchnorm2d_fuser.py +++ b/x2paddle/optimizer/fusion/batchnorm2d_fuser.py @@ -25,26 +25,19 @@ class BatchNorm2dFuser(FuseBase): def build_pattern(self): """ 描述需要替换的batchnorm2d图结构。 batchnorm2d层模式python实现代码示例: - x2209 = 1 - x2212 = 'Exception' - x2213 = 4 - x2214 = x2207.shape + x2214 = fluid.layers.shape(x2207) x2214 = len(x2214) x2215 = x2214 != x2213 if x2215 : raise RaiseException(x2212) - x2218 = False if x2218 : x2220 = self.x2220 x2221 = x2220 + x2209 self.x2220 = x2221 x2227 = False if x2227 : - x2230 = x2207.shape + x2230 = fluid.layers.shape(x2207.shape) x2231 = 'Exception' - x2233 = 0 - x2234 = 2 - x2235 = 1 x2236 = x2230[x2233] x2237 = len(x2230) x2238 = x2237 - x2234 @@ -63,43 +56,45 @@ class BatchNorm2dFuser(FuseBase): def gen_name(id): return "x" + str(id) +# self.pattern.add_layer( +# "prim.constant", inputs={}, outputs=[gen_name(0)], value=1) +# self.pattern.add_layer( +# "prim.constant", inputs={}, outputs=[gen_name(1)], value=0.1) +# self.pattern.add_layer( +# "prim.constant", inputs={}, outputs=[gen_name(2)], value=0.001) +# self.pattern.add_layer( +# "prim.constant", +# inputs={}, +# outputs=[gen_name(3)], +# value="Exception") +# self.pattern.add_layer( +# "prim.constant", inputs={}, outputs=[gen_name(4)], value=4) + self.pattern.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(0)], value=1) - self.pattern.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(1)], value=0.1) - self.pattern.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(2)], value=0.001) - self.pattern.add_layer( - "prim.constant", - inputs={}, - outputs=[gen_name(3)], - value="Exception") - self.pattern.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(4)], value=4) - self.pattern.add_layer( - "prim.shape", inputs={'input': "bn-input-0"}, + "fluid.layers.shape", + inputs={'input': "bn-input-0"}, outputs=[gen_name(5)]) self.pattern.add_layer( "prim.len", inputs={'input': gen_name(5)}, outputs=[gen_name(5)]) self.pattern.add_layer( "prim.ne", inputs={"x": gen_name(5), - "y": gen_name(4)}, + "y": "bn-input-9"}, outputs=[gen_name(6)]) self.pattern.add_layer("prim.if", {'input': gen_name(6)}, [gen_name(7)]) if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] pattern_block0 = PaddleGraph(if_layer1, graph_type="dygraph") pattern_block0.add_layer( "prim.exception", - inputs={"input": gen_name(3)}, + inputs={"input": "bn-input-1"}, outputs=[gen_name(8)]) - if_layer1.inputs["input-0"] = gen_name(3) + if_layer1.inputs["input-0"] = "bn-input-1" if_layer1.add_block(pattern_block0) pattern_block1 = PaddleGraph(if_layer1, graph_type="dygraph") if_layer1.add_block(pattern_block1) - self.pattern.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(9)], value=False) - self.pattern.add_layer("prim.if", {'input': gen_name(9)}, + # self.pattern.add_layer( + # "prim.constant", inputs={}, outputs=[gen_name(9)], value=False) + self.pattern.add_layer("prim.if", {'input': "bn-input-2"}, [gen_name(10)]) if_layer2 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] pattern_block0 = PaddleGraph(if_layer2, graph_type="dygraph") @@ -111,52 +106,52 @@ class BatchNorm2dFuser(FuseBase): pattern_block0.add_layer( "prim.add", inputs={"x": gen_name(11), - "y": gen_name(0)}, + "y": "bn-input-3"}, outputs=[gen_name(12)]) pattern_block0.add_layer( "prim.set_attr", inputs={"input": gen_name(12)}, outputs=["self." + gen_name(11)]) - if_layer2.inputs["input-0"] = gen_name(0) + if_layer2.inputs["input-0"] = "bn-input-3" if_layer2.add_block(pattern_block0) pattern_block1 = PaddleGraph(if_layer2, graph_type="dygraph") if_layer2.add_block(pattern_block1) - self.pattern.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(13)], value=True) - self.pattern.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(14)], value=False) - self.pattern.add_layer("prim.if", {'input': gen_name(14)}, + # self.pattern.add_layer( + # "prim.constant", inputs={}, outputs=[gen_name(13)], value=True) + # self.pattern.add_layer( + # "prim.constant", inputs={}, outputs=[gen_name(14)], value=False) + self.pattern.add_layer("prim.if", {'input': "bn-input-4"}, [gen_name(15)]) if_layer3 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] pattern_block0 = PaddleGraph(if_layer3, graph_type="dygraph") pattern_block0.add_layer( - "prim.shape", + "fluid.layers.shape", inputs={'input': "bn-input-0"}, outputs=[gen_name(16)]) - pattern_block0.add_layer( - "prim.constant", - inputs={}, - outputs=[gen_name(17)], - value="Exception") - pattern_block0.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(18)], value=True) - pattern_block0.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(19)], value=0) - pattern_block0.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(20)], value=2) - pattern_block0.add_layer( - "prim.constant", inputs={}, outputs=[gen_name(21)], value=1) + # pattern_block0.add_layer( + # "prim.constant", + # inputs={}, + # outputs=[gen_name(17)], + # value="Exception") + # pattern_block0.add_layer( + # "prim.constant", inputs={}, outputs=[gen_name(18)], value=True) + # pattern_block0.add_layer( + # "prim.constant", inputs={}, outputs=[gen_name(19)], value=0) + # pattern_block0.add_layer( + # "prim.constant", inputs={}, outputs=[gen_name(20)], value=2) + # pattern_block0.add_layer( + # "prim.constant", inputs={}, outputs=[gen_name(21)], value=1) pattern_block0.add_layer( "prim.getitem", inputs={"list": gen_name(16), - "index": gen_name(19)}, + "index": "bn-input-6"}, outputs=[gen_name(22)]) pattern_block0.add_layer( "prim.len", inputs={"input": gen_name(16)}, outputs=[gen_name(23)]) pattern_block0.add_layer( "prim.sub", inputs={"x": gen_name(23), - "y": gen_name(20)}, + "y": "bn-input-7"}, outputs=[gen_name(24)]) pattern_block0.add_layer( "prim.equal", @@ -172,7 +167,7 @@ class BatchNorm2dFuser(FuseBase): pattern_block0_block0.add_layer( "prim.add", inputs={"x": gen_name(27), - "y": gen_name(20)}, + "y": "bn-input-7"}, outputs=[gen_name(28)]) pattern_block0_block0.add_layer( "prim.getitem", @@ -188,14 +183,14 @@ class BatchNorm2dFuser(FuseBase): "prim.equal", inputs={"input": gen_name(30)}, outputs=[gen_name(26)]) - loop_layer.inputs["input-1"] = gen_name(20) + loop_layer.inputs["input-1"] = "bn-input-7" loop_layer.inputs["input-2"] = gen_name(16) loop_layer.inputs["input-3"] = gen_name(25) loop_layer.add_block(pattern_block0_block0) pattern_block0.add_layer( "prim.eq", inputs={"x": gen_name(26), - "y": gen_name(21)}, + "y": "bn-input-8"}, outputs=[gen_name(31)]) pattern_block0.add_layer( "prim.if", inputs={"input": gen_name(31)}, outputs=[gen_name(32)]) @@ -204,16 +199,21 @@ class BatchNorm2dFuser(FuseBase): pattern_block0_block0 = PaddleGraph(if_layer31, graph_type="dygraph") pattern_block0_block0.add_layer( "prim.exception", - inputs={"input": gen_name(17)}, + inputs={"input": "bn-input-5"}, outputs=[gen_name(33)]) - if_layer31.inputs["input-0"] = gen_name(17) + if_layer31.inputs["input-0"] = "bn-input-5" if_layer31.add_block(pattern_block0_block0) pattern_block0_block1 = PaddleGraph(if_layer31, graph_type="dygraph") if_layer31.add_block(pattern_block0_block1) - if_layer3.inputs["input-0"] = "bn-input-0" if_layer3.add_block(pattern_block0) pattern_block1 = PaddleGraph(if_layer3, graph_type="dygraph") if_layer3.add_block(pattern_block1) + if_layer3.inputs["input-0"] = "bn-input-5" + if_layer3.inputs["input-1"] = "bn-input-6" + if_layer3.inputs["input-2"] = "bn-input-7" + if_layer3.inputs["input-3"] = "bn-input-7" + if_layer3.inputs["input-4"] = "bn-input-8" + if_layer3.inputs["input-5"] = "bn-input-0" self.pattern.add_layer( "paddle.nn.BatchNorm", inputs={"input": "bn-input-0"}, @@ -222,7 +222,18 @@ class BatchNorm2dFuser(FuseBase): num_channels=160, momentum=0.1, epsilon=0.001) - self.pattern.build(inputs={"input-0": "bn-input-0"}) + self.pattern.build(inputs={ + "input-0": "bn-input-0", + "input-1": "bn-input-1", + "input-2": "bn-input-2", + "input-3": "bn-input-3", + "input-4": "bn-input-4", + "input-5": "bn-input-5", + "input-6": "bn-input-6", + "input-7": "bn-input-7", + "input-8": "bn-input-8", + "input-9": "bn-input-9" + }) def insert_new_layer(self, graph, parameters, matches): new_layer = self.gen_new_layer(parameters, matches) diff --git a/x2paddle/optimizer/pattern_matcher.py b/x2paddle/optimizer/pattern_matcher.py index 8a1d57d01427aecc799a122856927ce615107a09..6b7a046eeb94c867483eb4f81e864fa0cc6a5a76 100644 --- a/x2paddle/optimizer/pattern_matcher.py +++ b/x2paddle/optimizer/pattern_matcher.py @@ -87,7 +87,7 @@ class PatternMatcher(object): if not set(pattern_layer.outputs).issubset( pattern.outputs): # 若pattern当前layer的输出是pattern的输出,则是正确的 - print("5--") + # print("5--") return False # 当为控制流时的处理 if layer.kernel == "prim.if" or layer.kernel == "prim.loop":