From 5f0f476f80d6db5c06f6c56826a53b0e3ac7b930 Mon Sep 17 00:00:00 2001
From: SunAhong1993
Date: Wed, 19 Aug 2020 16:29:19 +0800
Subject: [PATCH] modify optimizer

---
 x2paddle/optimizer/fusion/__init__.py       |  16 ++
 x2paddle/optimizer/fusion/fc_fuse_pass.py   |  33 ++++
 .../{linear_pass.py => fusion/fc_fuser.py}  |  98 +++++------
 x2paddle/optimizer/optimizer.py             |  35 +---
 x2paddle/optimizer/pass_.py                 |  44 +++++
 x2paddle/optimizer/pass_manager.py          |  42 +++++
 x2paddle/optimizer/passes.py                | 109 -------------
 x2paddle/optimizer/pattern_matcher.py       | 154 ++++++++++++++++++
 8 files changed, 341 insertions(+), 190 deletions(-)
 create mode 100644 x2paddle/optimizer/fusion/__init__.py
 create mode 100644 x2paddle/optimizer/fusion/fc_fuse_pass.py
 rename x2paddle/optimizer/{linear_pass.py => fusion/fc_fuser.py} (68%)
 create mode 100644 x2paddle/optimizer/pass_.py
 create mode 100644 x2paddle/optimizer/pass_manager.py
 delete mode 100644 x2paddle/optimizer/passes.py
 create mode 100644 x2paddle/optimizer/pattern_matcher.py

diff --git a/x2paddle/optimizer/fusion/__init__.py b/x2paddle/optimizer/fusion/__init__.py
new file mode 100644
index 0000000..c21061b
--- /dev/null
+++ b/x2paddle/optimizer/fusion/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .fc_fuser import FcFuser
+from .fc_fuse_pass import FcFusePass
diff --git a/x2paddle/optimizer/optimizer.py b/x2paddle/optimizer/optimizer.py
-                if len(block.layers) > 0:
-                    layer.blocks[j], is_update_block = self.run(block)
-                    if is_update_block:
-                        break
-            if i + 1 == len(graph.layers):
-                return graph, is_update_graph
+        self.passes = ["fc_fuse_pass"]
 
     def optimize(self, graph):
-        # 开始优化
-        for _pass, matcher in self.passes.items():
-            self.current_pass = _pass
-            self.current_matcher = matcher
-            graph, _ = self.run(graph)
-            print("{} done!".format(_pass.__class__.__name__))
+        for pass_name in self.passes:
+            pass_ = PassManager.lookup(pass_name)()
+            pass_.apply(graph)
+            print("{} done!".format(pass_name))
         return graph
diff --git a/x2paddle/optimizer/pass_.py b/x2paddle/optimizer/pass_.py
new file mode 100644
index 0000000..da74986
--- /dev/null
+++ b/x2paddle/optimizer/pass_.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+class Kind(Enum):
+    Program = 1
+    Code = 2
+
+
+class Pass(object):
+    name = "pass"
+
+    def __init__(self, kind):
+        self.kind = kind
+
+    def apply(self, graph):
+        raise NotImplementedError("The apply function must be implemented!")
+
+    @classmethod
+    def get_name(cls):
+        return cls.name
+
+
+class ProgramPass(Pass):
+    def __init__(self):
+        super(ProgramPass, self).__init__(Kind.Program)
+
+
+class CodePass(Pass):
+    def __init__(self):
+        super(CodePass, self).__init__(Kind.Code)
diff --git a/x2paddle/optimizer/pass_manager.py b/x2paddle/optimizer/pass_manager.py
new file mode 100644
index 0000000..8653f62
--- /dev/null
+++ b/x2paddle/optimizer/pass_manager.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class PassManager(object):
+    """ Pass manager.
+    """
+    # pass_map maps each pass name to its corresponding pass class.
+    pass_map = dict()
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def add_new_pass(name, pass_):
+        if name not in PassManager.pass_map:
+            PassManager.pass_map[name] = pass_
+
+    @staticmethod
+    def clear():
+        PassManager.pass_map = dict()
+
+    @staticmethod
+    def lookup(name):
+        return PassManager.pass_map[name]
+
+
+def pass_register(cls):
+    name = cls.get_name()
+    PassManager.add_new_pass(name, cls)
+    return cls
diff --git a/x2paddle/optimizer/passes.py b/x2paddle/optimizer/passes.py
deleted file mode 100644
index 2987f79..0000000
--- a/x2paddle/optimizer/passes.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from x2paddle.core.program import PaddleGraph
-
-
-class Pass(object):
-    def __init__(self):
-        self.pattern = PaddleGraph()
-        self.build_pattern()
-
-
-class Matcher(object):
-    def __init__(self):
-        pass
-
-
-class PyTorchMatcher(Matcher):
-    def __init__(self):
-        super(PyTorchMatcher, self).__init__()
-
-    def match_pattern(self, pattern, graph, start_index):
-        pattern_index = 0
-        pattern_global_layers = pattern.get_global_layers()
-        subgraph_global_layers = dict()
-        graph_layers = dict(list(graph.layers.items())[start_index:])
-        for layer_id, layer in graph_layers.items():
-            pattern_layer = pattern.layers[list(pattern.layers.keys())[
-                pattern_index]]
-            if layer.kernel == pattern_layer.kernel:
-                subgraph_global_layers[layer_id] = layer
-                pattern_layer_id = pattern_layer.id
-                if layer.kernel == "prim.constant":
-                    if layer.attrs["value"] != pattern_layer.attrs["value"]:
-                        return False
-                elif layer.kernel == "fluid.layers.addmm":
-                    if layer.attrs["beta"] != pattern_layer.attrs["beta"]:
-                        return False
-                    if layer.attrs["alpha"] != pattern_layer.attrs["alpha"]:
-                        return False
-
-                if layer_id in graph.edges_in:
-                    if pattern_layer_id not in pattern.edges_in:
-                        return False
-                    else:
-                        if len(graph.edges_in[layer_id]) != len(
-                                pattern.edges_in[pattern_layer_id]):
-                            return False
-                        layer_in = graph.edges_in[layer_id]
-                        pattern_layer_in = pattern.edges_in[pattern_layer_id]
-                        for i in range(len(layer_in)):
-                            layer_id_in = layer_in[i]
-                            pattern_layer_id_in = pattern_layer_in[i]
-                            if pattern_layer_id_in != -1:
-                                pattern_global_layers_id = list(
-                                    pattern_global_layers.keys())
-                                subgraph_global_layers_id = list(
-                                    subgraph_global_layers.keys())
-                                if pattern_global_layers_id.index(pattern_layer_id_in) == \
-                                        subgraph_global_layers_id.index(layer_id_in):
-                                    # 判断pattern输入在pattern_global_layers_id的索引
-                                    # 和graph输入在subgraph_global_layers_id的索引一致
-                                    continue
-                                return False
-
-                if layer_id in graph.edges_out:
-                    if pattern_layer_id not in pattern.edges_out:
-                        if not set(pattern_layer.outputs).issubset(
-                                pattern.outputs):
-                            # 若pattern当前layer的输出是pattern的输出,则是正确的
-                            return False
-                    else:
-                        if len(graph.edges_out[layer_id]) != len(
-                                pattern.edges_out[pattern_layer_id]):
-                            # 如果在每个节点edges_in相同的情况下,edges_out数目相同则说明无节点在subgraph外被用到
-                            if not set(pattern_layer.outputs).issubset(
-                                    pattern.outputs):
-                                # 若pattern当前layer的输出是pattern的输出,则是正确的
-                                return False
-
-                if layer.kernel == "prim.if":
-                    res = self.match_pattern(pattern_layer.blocks[0],
-                                             layer.blocks[0], 0)
-                    if res:
-                        subgraph_global_layers.update(res)
-                    else:
-                        return False
-                    res = self.match_pattern(pattern_layer.blocks[1],
-                                             layer.blocks[1], 0)
-                    if res:
-                        subgraph_global_layers.update(res)
-                    else:
-                        return False
-                pattern_index += 1
-                if pattern_index == len(pattern.layers):
-                    return subgraph_global_layers
-            else:
-                return False
diff --git a/x2paddle/optimizer/pattern_matcher.py b/x2paddle/optimizer/pattern_matcher.py
new file mode 100644
index 0000000..a07b398
--- /dev/null
+++ b/x2paddle/optimizer/pattern_matcher.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from x2paddle.core.program import PaddleGraph
+
+
+class PatternMatcher(object):
+    def __init__(self, pattern):
+        self.pattern = pattern
+        self.subgraphs = list()
+
+    def operate(self, graph):
+        self.detect_patterns(graph)
+        self.remove_overlapped_match()
+        return self.subgraphs
+
+    def detect_patterns(self, graph):
+        """ Find the subgraphs that match the pattern and store their
+            layer ids in subgraph_id2layers in topological order.
+        """
+
+        def get_subgraph(pattern, graph, start_index):
+            pattern_index = 0
+            pattern_id2layers = pattern.get_global_layers()
+            pattern_ids = list(pattern_id2layers.keys())
+            subgraph_id2layers = dict()
+            graph_layers = dict(list(graph.layers.items())[start_index:])
+            for layer_id, layer in graph_layers.items():
+                pattern_layer = pattern.layers[list(pattern.layers.keys())[
+                    pattern_index]]
+                if layer.kernel == pattern_layer.kernel:
+                    subgraph_id2layers[layer_id] = layer
+                    pattern_layer_id = pattern_layer.id
+                    # Check that the input connections are consistent.
+                    if layer_id in graph.edges_in:
+                        if pattern_layer_id not in pattern.edges_in:
+                            return False
+                        else:
+                            if len(graph.edges_in[layer_id]) != len(
+                                    pattern.edges_in[pattern_layer_id]):
+                                return False
+                            layer_in = graph.edges_in[layer_id]
+                            pattern_layer_in = pattern.edges_in[pattern_layer_id]
+                            for i in range(len(layer_in)):
+                                layer_id_in = layer_in[i]
+                                pattern_layer_id_in = pattern_layer_in[i]
+                                if pattern_layer_id_in != -1:
+                                    subgraph_ids = list(subgraph_id2layers.keys())
+                                    if pattern_ids.index(pattern_layer_id_in) == \
+                                            subgraph_ids.index(layer_id_in):
+                                        # The index of the pattern input in pattern_ids must
+                                        # match the index of the graph input in subgraph_ids.
+                                        continue
+                                    return False
+                    # Check whether any node of the subgraph is used outside of it (invalid if so).
+                    if layer_id in graph.edges_out:
+                        if pattern_layer_id not in pattern.edges_out:
+                            if not set(pattern_layer.outputs).issubset(
+                                    pattern.outputs):
+                                # Still valid if this pattern layer's outputs are outputs of the pattern itself.
+                                return False
+                        else:
+                            if len(graph.edges_out[layer_id]) != len(
+                                    pattern.edges_out[pattern_layer_id]):
+                                # With identical edges_in per node, an equal number of edges_out means no node is used outside the subgraph.
+                                if not set(pattern_layer.outputs).issubset(
+                                        pattern.outputs):
+                                    # Still valid if this pattern layer's outputs are outputs of the pattern itself.
+                                    return False
+                    # Handle control flow layers.
+                    if layer.kernel == "prim.if":
+                        match_info = get_subgraph(pattern_layer.blocks[0],
+                                                  layer.blocks[0], 0)
+                        if match_info:
+                            subgraph_id2layers.update(match_info)
+                        else:
+                            return False
+                        match_info = get_subgraph(pattern_layer.blocks[1],
+                                                  layer.blocks[1], 0)
+                        if match_info:
+                            subgraph_id2layers.update(match_info)
+                        else:
+                            return False
+                    pattern_index += 1
+                    if pattern_index == len(pattern.layers):
+                        return subgraph_id2layers
+                else:
+                    return False
+
+        for i, (layer_id, layer) in enumerate(graph.layers.items()):
+            match_info = get_subgraph(self.pattern, graph, i)
+            if match_info:
+                self.subgraphs.append(match_info)
+            for j, block in enumerate(layer.blocks):
+                if len(block.layers) > 0:
+                    self.detect_patterns(layer.blocks[j])
+
+    def remove_overlapped_match(self):
+        """ If two subgraphs overlap, keep only the first one.
+        """
+        match_ids = []
+        kept_subgraphs = []
+        for subgraph in self.subgraphs:
+            # Skip a subgraph that shares any layer with an earlier match;
+            # rebuild the list rather than popping while iterating over it.
+            if any(id in match_ids for id in subgraph.keys()):
+                continue
+            match_ids.extend(list(subgraph.keys()))
+            kept_subgraphs.append(subgraph)
+        self.subgraphs = kept_subgraphs
+
+
+class FuseBase(object):
+    def __init__(self):
+        self.pattern = PaddleGraph()
+
+    def operate(self, graph):
+        self.build_pattern()
+        self.perform_pattern_matcher(graph)
+        for subgraph in self.subgraphs:
+            self.insert_new_layer(graph, subgraph)
+        self.delete_inter_layer(graph)
+        graph.build()
+
+    def perform_pattern_matcher(self, graph):
+        """ Run pattern matching to find the matching subgraphs.
+        """
+        pattern_matcher = PatternMatcher(self.pattern)
+        self.subgraphs = pattern_matcher.operate(graph)
+
+    def delete_inter_layer(self, graph):
+        """ Delete the unneeded intermediate layers and their corresponding parameters.
+        """
+        for subgraph in self.subgraphs:
+            for layer_id, layer in subgraph.items():
+                if layer.kernel == "fluid.dygraph.base.to_variable" and \
+                        layer.attrs["value"].startswith("params["):
+                    param_name = layer.attrs["value"][8:-2]
+                    if param_name in graph.parameters:
+                        graph.parameters.pop(param_name)
+                if layer_id in graph.layers:
+                    # layer_id may belong to a sub-block; in that case its parent layer is deleted, which removes the whole sub-block.
+                    graph.layers.pop(layer_id)
--
GitLab
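The hunks for x2paddle/optimizer/fusion/fc_fuse_pass.py and fc_fuser.py are not reproduced above, so the following is only a minimal sketch of how a fuse pass is expected to plug into the new framework. It is inferred from the pieces the patch does show (ProgramPass, pass_register, PassManager.lookup, FcFuser, and the "fc_fuse_pass" name used by GraphOptimizer); the class body itself is an assumption, not the patch's actual code.

# Hypothetical sketch only -- the real fc_fuse_pass.py may differ.
from x2paddle.optimizer.pass_ import ProgramPass
from x2paddle.optimizer.pass_manager import pass_register
from x2paddle.optimizer.fusion.fc_fuser import FcFuser


@pass_register
class FcFusePass(ProgramPass):
    # "name" is the key stored in PassManager.pass_map and the string
    # listed in GraphOptimizer.passes.
    name = "fc_fuse_pass"

    def __init__(self):
        super(FcFusePass, self).__init__()

    def apply(self, graph):
        # FcFuser is assumed to be a FuseBase subclass: it builds the fc
        # pattern, runs the PatternMatcher and rewrites the matched subgraphs.
        fuser = FcFuser()
        fuser.operate(graph)


# GraphOptimizer then only needs the registered name:
#     pass_ = PassManager.lookup("fc_fuse_pass")()
#     pass_.apply(graph)

Registering passes by name keeps GraphOptimizer decoupled from the concrete pass classes: adding a new fuser only requires importing its module so that the @pass_register decorator runs.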