diff --git a/x2paddle/optimizer/fusion/__init__.py b/x2paddle/optimizer/fusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c21061b26753483d2cb9d0c648e250af2a3cd58c --- /dev/null +++ b/x2paddle/optimizer/fusion/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .fc_fuser import FcFuser +from .fc_fuse_pass import FcFusePass diff --git a/x2paddle/optimizer/fusion/__pycache__/__init__.cpython-37.pyc b/x2paddle/optimizer/fusion/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19ed62a203f5bb9789043e4c2e655dde2c3983d0 Binary files /dev/null and b/x2paddle/optimizer/fusion/__pycache__/__init__.cpython-37.pyc differ diff --git a/x2paddle/optimizer/fusion/__pycache__/fc_fuse_pass.cpython-37.pyc b/x2paddle/optimizer/fusion/__pycache__/fc_fuse_pass.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..894ffb39a66ec60cf7c25a055374837276151507 Binary files /dev/null and b/x2paddle/optimizer/fusion/__pycache__/fc_fuse_pass.cpython-37.pyc differ diff --git a/x2paddle/optimizer/fusion/__pycache__/fc_fuser.cpython-37.pyc b/x2paddle/optimizer/fusion/__pycache__/fc_fuser.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e28f1b43cad2b756cdc6471834fb9c0423546149 Binary files /dev/null and b/x2paddle/optimizer/fusion/__pycache__/fc_fuser.cpython-37.pyc differ diff --git a/x2paddle/optimizer/fusion/fc_fuse_pass.py b/x2paddle/optimizer/fusion/fc_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..35d11e4bbd8a5853a6444a65fa732527053bf958 --- /dev/null +++ b/x2paddle/optimizer/fusion/fc_fuse_pass.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from x2paddle.optimizer.pass_ import ProgramPass +from x2paddle.optimizer.fusion import FcFuser +from x2paddle.optimizer.pass_manager import pass_register + + +@pass_register +class FcFusePass(ProgramPass): + name = "fc_fuse_pass" + + def __init__(self): + ProgramPass.__init__(self) + + def apply(self, graph): + fuser = FcFuser() + fuser.operate(graph) + + +# 用于注册 +fc_fuse_pass = FcFusePass() diff --git a/x2paddle/optimizer/linear_pass.py b/x2paddle/optimizer/fusion/fc_fuser.py similarity index 68% rename from x2paddle/optimizer/linear_pass.py rename to x2paddle/optimizer/fusion/fc_fuser.py index e7a4a76a207ec76c3361381e4ff7c70c1fd238c9..91b51b8d704c12933413d4c90f7d263f2867e7b3 100644 --- a/x2paddle/optimizer/linear_pass.py +++ b/x2paddle/optimizer/fusion/fc_fuser.py @@ -13,17 +13,18 @@ # limitations under the License. import numpy as np +from x2paddle.optimizer.pattern_matcher import FuseBase +from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -from x2paddle.core.program import PaddleLayer, PaddleGraph -from x2paddle.optimizer.passes import Pass, Matcher, PyTorchMatcher -class LinearPass(Pass): +class FcFuser(FuseBase): def __init__(self): - super(LinearPass, self).__init__() + self.linear_index = 0 + super(FcFuser, self).__init__() def build_pattern(self): - """ 构造fc层的模式。 + """ 描述需要替换的fc图结构。 fc层模式python实现代码示例: x149 = 2 x151 = x146.shape @@ -68,8 +69,8 @@ class LinearPass(Pass): outputs=[gen_name(3)]) self.pattern.add_layer("prim.if", {'input': gen_name(3)}, [gen_name(4)]) self.pattern.outputs.append(gen_name(4)) - if_layer_a = self.pattern.layers[list(self.pattern.layers.keys())[-1]] - pattern_block0 = PaddleGraph(if_layer_a) + if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] + pattern_block0 = PaddleGraph(if_layer1) pattern_block0.add_layer( "fluid.dygraph.base.to_variable", inputs={}, @@ -93,12 +94,12 @@ class LinearPass(Pass): outputs=[gen_name(8)], beta=1, alpha=1) - if_layer_a.inputs["input-0"] = "fc-input-0" + if_layer1.inputs["input-0"] = "fc-input-0" self.pattern.inputs.append("fc-input-0") pattern_block0.add_layer( "prim.equal", inputs={'input': gen_name(8)}, outputs=[gen_name(4)]) - if_layer_a.add_block(pattern_block0) - pattern_block1 = PaddleGraph(if_layer_a) + if_layer1.add_block(pattern_block0) + pattern_block1 = PaddleGraph(if_layer1) pattern_block1.add_layer( "fluid.dygraph.base.to_variable", inputs={}, @@ -114,84 +115,75 @@ class LinearPass(Pass): inputs={"x": "fc-input-0", "y": gen_name(6)}, outputs=[gen_name(9)]) - if_layer_a.inputs["input-1"] = "fc-input-0" + if_layer1.inputs["input-1"] = "fc-input-0" pattern_block1.add_layer( "prim.constant", inputs={}, outputs=[gen_name(10)], value=True) pattern_block1.add_layer("prim.if", {'input': gen_name(10)}, [gen_name(11)]) - if_layer_b = pattern_block1.layers[list(pattern_block1.layers.keys())[ + if_layer2 = pattern_block1.layers[list(pattern_block1.layers.keys())[ -1]] - pattern_block1_block0 = PaddleGraph(if_layer_b) + pattern_block1_block0 = PaddleGraph(if_layer2) pattern_block1_block0.add_layer( "fluid.dygraph.base.to_variable", inputs={}, outputs=[gen_name(12)], value="params[{}]".format(string(gen_name(12)))) pattern_block1_block0.add_layer( - "prim.add", + "prim.add_", inputs={"x": gen_name(9), "y": gen_name(12)}, outputs=[gen_name(13)], alpha=1) - if_layer_b.inputs["input-0"] = gen_name(9) + if_layer2.inputs["input-0"] = gen_name(9) pattern_block1_block0.add_layer( "prim.equal", inputs={'input': gen_name(13)}, outputs=[gen_name(11)]) - if_layer_b.add_block(pattern_block1_block0) - pattern_block1_block1 = PaddleGraph(if_layer_b) + if_layer2.add_block(pattern_block1_block0) + pattern_block1_block1 = PaddleGraph(if_layer2) pattern_block1_block1.add_layer( "prim.equal", inputs={'input': gen_name(9)}, outputs=[gen_name(11)]) - if_layer_b.inputs["input-1"] = gen_name(9) + if_layer2.inputs["input-1"] = gen_name(9) pattern_block1.add_layer( "prim.equal", inputs={'input': gen_name(11)}, outputs=[gen_name(4)]) - if_layer_b.add_block(pattern_block1_block1) - if_layer_a.add_block(pattern_block1) + if_layer2.add_block(pattern_block1_block1) + if_layer1.add_block(pattern_block1) self.pattern.build( inputs={"input-0": "fc-input-0", "input-1": "fc-input-0"}) + def insert_new_layer(self, graph, matches): + parameters = graph.parameters + new_layer = self.gen_new_layer(parameters, matches) + new_layer_id = list(matches.keys())[0] + graph.layers[new_layer_id] = new_layer + matches.pop(new_layer_id) -class LinearMatcher(PyTorchMatcher): - def __init__(self): - self.linear_index = 0 - super(LinearMatcher, self).__init__() - - def replace_layer(self, graph, subgraph_global_layers): - subgraph_global_layers_id = list(subgraph_global_layers.keys()) - layer = subgraph_global_layers[subgraph_global_layers_id[2]] + def gen_new_layer(self, parameters, matches): + layers_id = list(matches.keys()) + layer = matches[layers_id[2]] input_name = layer.inputs["input"] - layer = subgraph_global_layers[subgraph_global_layers_id[5]] + layer = matches[layers_id[5]] output_name = layer.outputs[0] - layer = subgraph_global_layers[subgraph_global_layers_id[6]] + layer = matches[layers_id[6]] weight_name = layer.attrs["value"][8:-2] - layer = subgraph_global_layers[subgraph_global_layers_id[8]] + layer = matches[layers_id[8]] bias_name = layer.attrs["value"][8:-2] attrs = {} - attrs["input_dim"] = graph.parameters[weight_name].shape[1] - attrs["output_dim"] = graph.parameters[weight_name].shape[0] + attrs["input_dim"] = parameters[weight_name].shape[1] + attrs["output_dim"] = parameters[weight_name].shape[0] linear_name = "linear{}".format(self.linear_index) self.linear_index += 1 - graph.parameters["{}.weight".format(linear_name)] = graph.parameters[ + parameters["{}.weight".format(linear_name)] = parameters[ weight_name].transpose((1, 0)) - graph.parameters["{}.bias".format(linear_name)] = np.squeeze( - graph.parameters[bias_name]) - graph.parameters.pop(weight_name) - graph.parameters.pop(bias_name) - for i, layer_id in enumerate(subgraph_global_layers): - if layer_id in graph.layers: - layer = graph.layers[layer_id] - if i == 0: - new_layer = PaddleLayer( - layer_id, - "fluid.dygraph.Linear", - inputs={"input": input_name}, - outputs=[linear_name, output_name], - **attrs) - graph.layers[layer_id] = new_layer - else: - graph.layers.pop(layer_id) - graph.build() - return graph + parameters["{}.bias".format(linear_name)] = np.squeeze(parameters[ + bias_name]) + new_layer = PaddleLayer( + layers_id[0], + "fluid.dygraph.Linear", + inputs={"input": input_name}, + outputs=[linear_name, output_name], + **attrs) + return new_layer diff --git a/x2paddle/optimizer/optimizer.py b/x2paddle/optimizer/optimizer.py index 5b06e54fe185627a3408d5b325d5d1488f7b939d..c8d43ad4bf1c6b06b2026df0dca61f2ab6f4524f 100644 --- a/x2paddle/optimizer/optimizer.py +++ b/x2paddle/optimizer/optimizer.py @@ -12,38 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from x2paddle.optimizer.linear_pass import LinearPass, LinearMatcher +from x2paddle.optimizer.fusion import * +from x2paddle.optimizer.pass_manager import PassManager class GraphOptimizer(object): def __init__(self): - linear_pass = LinearPass() - linear_matcher = LinearMatcher() - self.passes = {linear_pass: linear_matcher} - - def run(self, graph): - is_update_graph = False - while True: - for i, (layer_id, layer) in enumerate(graph.layers.items()): - is_match = self.current_matcher.match_pattern( - self.current_pass.pattern, graph, i) - if is_match: - is_update_graph = True - graph = self.current_matcher.replace_layer(graph, is_match) - break - for j, block in enumerate(layer.blocks): - if len(block.layers) > 0: - layer.blocks[j], is_update_block = self.run(block) - if is_update_block: - break - if i + 1 == len(graph.layers): - return graph, is_update_graph + self.passes = ["fc_fuse_pass"] def optimize(self, graph): - # 开始优化 - for _pass, matcher in self.passes.items(): - self.current_pass = _pass - self.current_matcher = matcher - graph, _ = self.run(graph) - print("{} done!".format(_pass.__class__.__name__)) + for pass_name in self.passes: + pass_ = PassManager.lookup(pass_name)() + pass_.apply(graph) + print("{} done!".format(pass_name)) return graph diff --git a/x2paddle/optimizer/pass_.py b/x2paddle/optimizer/pass_.py new file mode 100644 index 0000000000000000000000000000000000000000..da74986f18f6c426e6a9b753eb62e04f22fe309f --- /dev/null +++ b/x2paddle/optimizer/pass_.py @@ -0,0 +1,44 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class Kind(Enum): + Program = 1 + Code = 2 + + +class Pass(object): + name = "pass" + + def __init__(self, kind): + self.kind = kind + + def apply(self, graph): + raise NotImplementedError("The apply function must be implemented!") + + @classmethod + def get_name(cls): + return cls.name + + +class ProgramPass(Pass): + def __init__(self): + super(ProgramPass, self).__init__(Kind.Program) + + +class CodePass(Pass): + def __init__(self): + super(CodePass, self).__init__(Kind.Code) diff --git a/x2paddle/optimizer/pass_manager.py b/x2paddle/optimizer/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..8653f62b3a415c1a4db4a95d0a185a28028d75c9 --- /dev/null +++ b/x2paddle/optimizer/pass_manager.py @@ -0,0 +1,42 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class PassManager(object): + """ pass管理器。 + """ + # pass_map存储name与其对应的pass + pass_map = dict() + + def __init__(self): + pass + + @staticmethod + def add_new_pass(name, pass_): + if name not in PassManager.pass_map: + PassManager.pass_map[name] = pass_ + + @staticmethod + def clear(): + PassManager.passes = list() + + @staticmethod + def lookup(name): + return PassManager.pass_map[name] + + +def pass_register(cls): + name = cls.get_name() + PassManager.add_new_pass(name, cls) + return cls diff --git a/x2paddle/optimizer/passes.py b/x2paddle/optimizer/passes.py deleted file mode 100644 index 2987f7934e05bf443c2531a1ac70d51dc697a98e..0000000000000000000000000000000000000000 --- a/x2paddle/optimizer/passes.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from x2paddle.core.program import PaddleGraph - - -class Pass(object): - def __init__(self): - self.pattern = PaddleGraph() - self.build_pattern() - - -class Matcher(object): - def __init__(self): - pass - - -class PyTorchMatcher(Matcher): - def __init__(self): - super(PyTorchMatcher, self).__init__() - - def match_pattern(self, pattern, graph, start_index): - pattern_index = 0 - pattern_global_layers = pattern.get_global_layers() - subgraph_global_layers = dict() - graph_layers = dict(list(graph.layers.items())[start_index:]) - for layer_id, layer in graph_layers.items(): - pattern_layer = pattern.layers[list(pattern.layers.keys())[ - pattern_index]] - if layer.kernel == pattern_layer.kernel: - subgraph_global_layers[layer_id] = layer - pattern_layer_id = pattern_layer.id - if layer.kernel == "prim.constant": - if layer.attrs["value"] != pattern_layer.attrs["value"]: - return False - elif layer.kernel == "fluid.layers.addmm": - if layer.attrs["beta"] != pattern_layer.attrs["beta"]: - return False - if layer.attrs["alpha"] != pattern_layer.attrs["alpha"]: - return False - - if layer_id in graph.edges_in: - if pattern_layer_id not in pattern.edges_in: - return False - else: - if len(graph.edges_in[layer_id]) != len( - pattern.edges_in[pattern_layer_id]): - return False - layer_in = graph.edges_in[layer_id] - pattern_layer_in = pattern.edges_in[pattern_layer_id] - for i in range(len(layer_in)): - layer_id_in = layer_in[i] - pattern_layer_id_in = pattern_layer_in[i] - if pattern_layer_id_in != -1: - pattern_global_layers_id = list( - pattern_global_layers.keys()) - subgraph_global_layers_id = list( - subgraph_global_layers.keys()) - if pattern_global_layers_id.index(pattern_layer_id_in) == \ - subgraph_global_layers_id.index(layer_id_in): - # 判断pattern输入在pattern_global_layers_id的索引 - # 和graph输入在subgraph_global_layers_id的索引一致 - continue - return False - - if layer_id in graph.edges_out: - if pattern_layer_id not in pattern.edges_out: - if not set(pattern_layer.outputs).issubset( - pattern.outputs): - # 若pattern当前layer的输出是pattern的输出,则是正确的 - return False - else: - if len(graph.edges_out[layer_id]) != len( - pattern.edges_out[pattern_layer_id]): - # 如果在每个节点edges_in相同的情况下,edges_out数目相同则说明无节点在subgraph外被用到 - if not set(pattern_layer.outputs).issubset( - pattern.outputs): - # 若pattern当前layer的输出是pattern的输出,则是正确的 - return False - - if layer.kernel == "prim.if": - res = self.match_pattern(pattern_layer.blocks[0], - layer.blocks[0], 0) - if res: - subgraph_global_layers.update(res) - else: - return False - res = self.match_pattern(pattern_layer.blocks[1], - layer.blocks[1], 0) - if res: - subgraph_global_layers.update(res) - else: - return False - pattern_index += 1 - if pattern_index == len(pattern.layers): - return subgraph_global_layers - else: - return False diff --git a/x2paddle/optimizer/pattern_matcher.py b/x2paddle/optimizer/pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a07b3980030b60a0f05f4fff0c55087b07137a61 --- /dev/null +++ b/x2paddle/optimizer/pattern_matcher.py @@ -0,0 +1,154 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from x2paddle.core.program import PaddleGraph + + +class PatternMatcher(object): + def __init__(self, pattern): + self.pattern = pattern + self.subgraphs = list() + + def operate(self, graph): + self.detect_patterns(graph) + self.remove_overlapped_match() + return self.subgraphs + + def detect_patterns(self, graph): + """ 找到与模式匹配的子图, + 并将子图的id以拓扑排序存放到subgraph_id2layers。 + """ + + def get_subgraph(pattern, graph, start_index): + pattern_index = 0 + pattern_id2layers = pattern.get_global_layers() + pattern_ids = list(pattern_id2layers.keys()) + subgraph_id2layers = dict() + graph_layers = dict(list(graph.layers.items())[start_index:]) + for layer_id, layer in graph_layers.items(): + pattern_layer = pattern.layers[list(pattern.layers.keys())[ + pattern_index]] + if layer.kernel == pattern_layer.kernel: + subgraph_id2layers[layer_id] = layer + pattern_layer_id = pattern_layer.id + # 判断输入连接是否一致 + if layer_id in graph.edges_in: + if pattern_layer_id not in pattern.edges_in: + return False + else: + if len(graph.edges_in[layer_id]) != len( + pattern.edges_in[pattern_layer_id]): + return False + layer_in = graph.edges_in[layer_id] + pattern_layer_in = pattern.edges_in[pattern_layer_id] + for i in range(len(layer_in)): + layer_id_in = layer_in[i] + pattern_layer_id_in = pattern_layer_in[i] + if pattern_layer_id_in != -1: + subgraph_ids = list(subgraph_id2layers.keys()) + if pattern_ids.index(pattern_layer_id_in) == \ + subgraph_ids.index(layer_id_in): + # 判断pattern输入在pattern_ids的索引 + # 和graph输入在subgraph_ids的索引一致 + continue + return False + # 判断subgraph中的节点是否被外部图使用到(如若被使用到则无效) + if layer_id in graph.edges_out: + if pattern_layer_id not in pattern.edges_out: + if not set(pattern_layer.outputs).issubset( + pattern.outputs): + # 若pattern当前layer的输出是pattern的输出,则是正确的 + return False + else: + if len(graph.edges_out[layer_id]) != len( + pattern.edges_out[pattern_layer_id]): + # 如果在每个节点edges_in相同的情况下,edges_out数目相同则说明无节点在subgraph外被用到 + if not set(pattern_layer.outputs).issubset( + pattern.outputs): + # 若pattern当前layer的输出是pattern的输出,则是正确的 + return False + # 当为控制流时的处理 + if layer.kernel == "prim.if": + match_info = get_subgraph(pattern_layer.blocks[0], + layer.blocks[0], 0) + if match_info: + subgraph_id2layers.update(match_info) + else: + return False + match_info = get_subgraph(pattern_layer.blocks[1], + layer.blocks[1], 0) + if match_info: + subgraph_id2layers.update(match_info) + else: + return False + pattern_index += 1 + if pattern_index == len(pattern.layers): + return subgraph_id2layers + else: + return False + + for i, (layer_id, layer) in enumerate(graph.layers.items()): + match_info = get_subgraph(self.pattern, graph, i) + if match_info: + self.subgraphs.append(match_info) + for j, block in enumerate(layer.blocks): + if len(block.layers) > 0: + self.detect_patterns(layer.blocks[j]) + + def remove_overlapped_match(self): + """ 如果2个子图有重叠,只取前一个子图。 + """ + match_ids = [] + for i, subgraph in enumerate(self.subgraphs): + is_overlapped = False + for id in subgraph.keys(): + if id in match_ids: + self.subgraphs.pop(i) + is_overlapped = True + break + if not is_overlapped: + match_ids.extend(list(subgraph.keys())) + + +class FuseBase(object): + def __init__(self): + self.pattern = PaddleGraph() + + def operate(self, graph): + self.build_pattern() + self.perform_pattern_matcher(graph) + for subgraph in self.subgraphs: + self.insert_new_layer(graph, subgraph) + self.delete_inter_layer(graph) + graph.build() + + def perform_pattern_matcher(self, graph): + """ 执行模式匹配,找到匹配的子图。 + """ + pattern_matcher = PatternMatcher(self.pattern) + self.subgraphs = pattern_matcher.operate(graph) + + def delete_inter_layer(self, graph): + """ 删除不需要的中间layer及其对应参数。 + """ + for subgraph in self.subgraphs: + for layer_id, layer in subgraph.items(): + if layer.kernel == "fluid.dygraph.base.to_variable" and \ + layer.attrs["value"].startswith("params["): + param_name = layer.attrs["value"][8:-2] + if param_name in graph.parameters: + graph.parameters.pop(param_name) + if layer_id in graph.layers: + # layer_id可能是属于子图的,此时删除父layer,即删除整个子图 + graph.layers.pop(layer_id)