提交 5f0f476f 编写于 作者: S SunAhong1993

modify optimizer

上级 cbc3efdb
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fc_fuser import FcFuser
from .fc_fuse_pass import FcFusePass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import ProgramPass
from x2paddle.optimizer.fusion import FcFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class FcFusePass(ProgramPass):
name = "fc_fuse_pass"
def __init__(self):
ProgramPass.__init__(self)
def apply(self, graph):
fuser = FcFuser()
fuser.operate(graph)
# 用于注册
fc_fuse_pass = FcFusePass()
......@@ -13,17 +13,18 @@
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
from x2paddle.core.program import PaddleLayer, PaddleGraph
from x2paddle.optimizer.passes import Pass, Matcher, PyTorchMatcher
class LinearPass(Pass):
class FcFuser(FuseBase):
def __init__(self):
super(LinearPass, self).__init__()
self.linear_index = 0
super(FcFuser, self).__init__()
def build_pattern(self):
""" 构造fc层的模式
""" 描述需要替换的fc图结构
fc层模式python实现代码示例:
x149 = 2
x151 = x146.shape
......@@ -68,8 +69,8 @@ class LinearPass(Pass):
outputs=[gen_name(3)])
self.pattern.add_layer("prim.if", {'input': gen_name(3)}, [gen_name(4)])
self.pattern.outputs.append(gen_name(4))
if_layer_a = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer_a)
if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer1)
pattern_block0.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
......@@ -93,12 +94,12 @@ class LinearPass(Pass):
outputs=[gen_name(8)],
beta=1,
alpha=1)
if_layer_a.inputs["input-0"] = "fc-input-0"
if_layer1.inputs["input-0"] = "fc-input-0"
self.pattern.inputs.append("fc-input-0")
pattern_block0.add_layer(
"prim.equal", inputs={'input': gen_name(8)}, outputs=[gen_name(4)])
if_layer_a.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer_a)
if_layer1.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer1)
pattern_block1.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
......@@ -114,84 +115,75 @@ class LinearPass(Pass):
inputs={"x": "fc-input-0",
"y": gen_name(6)},
outputs=[gen_name(9)])
if_layer_a.inputs["input-1"] = "fc-input-0"
if_layer1.inputs["input-1"] = "fc-input-0"
pattern_block1.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(10)], value=True)
pattern_block1.add_layer("prim.if", {'input': gen_name(10)},
[gen_name(11)])
if_layer_b = pattern_block1.layers[list(pattern_block1.layers.keys())[
if_layer2 = pattern_block1.layers[list(pattern_block1.layers.keys())[
-1]]
pattern_block1_block0 = PaddleGraph(if_layer_b)
pattern_block1_block0 = PaddleGraph(if_layer2)
pattern_block1_block0.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
outputs=[gen_name(12)],
value="params[{}]".format(string(gen_name(12))))
pattern_block1_block0.add_layer(
"prim.add",
"prim.add_",
inputs={"x": gen_name(9),
"y": gen_name(12)},
outputs=[gen_name(13)],
alpha=1)
if_layer_b.inputs["input-0"] = gen_name(9)
if_layer2.inputs["input-0"] = gen_name(9)
pattern_block1_block0.add_layer(
"prim.equal",
inputs={'input': gen_name(13)},
outputs=[gen_name(11)])
if_layer_b.add_block(pattern_block1_block0)
pattern_block1_block1 = PaddleGraph(if_layer_b)
if_layer2.add_block(pattern_block1_block0)
pattern_block1_block1 = PaddleGraph(if_layer2)
pattern_block1_block1.add_layer(
"prim.equal", inputs={'input': gen_name(9)},
outputs=[gen_name(11)])
if_layer_b.inputs["input-1"] = gen_name(9)
if_layer2.inputs["input-1"] = gen_name(9)
pattern_block1.add_layer(
"prim.equal", inputs={'input': gen_name(11)},
outputs=[gen_name(4)])
if_layer_b.add_block(pattern_block1_block1)
if_layer_a.add_block(pattern_block1)
if_layer2.add_block(pattern_block1_block1)
if_layer1.add_block(pattern_block1)
self.pattern.build(
inputs={"input-0": "fc-input-0",
"input-1": "fc-input-0"})
def insert_new_layer(self, graph, matches):
parameters = graph.parameters
new_layer = self.gen_new_layer(parameters, matches)
new_layer_id = list(matches.keys())[0]
graph.layers[new_layer_id] = new_layer
matches.pop(new_layer_id)
class LinearMatcher(PyTorchMatcher):
def __init__(self):
self.linear_index = 0
super(LinearMatcher, self).__init__()
def replace_layer(self, graph, subgraph_global_layers):
subgraph_global_layers_id = list(subgraph_global_layers.keys())
layer = subgraph_global_layers[subgraph_global_layers_id[2]]
def gen_new_layer(self, parameters, matches):
layers_id = list(matches.keys())
layer = matches[layers_id[2]]
input_name = layer.inputs["input"]
layer = subgraph_global_layers[subgraph_global_layers_id[5]]
layer = matches[layers_id[5]]
output_name = layer.outputs[0]
layer = subgraph_global_layers[subgraph_global_layers_id[6]]
layer = matches[layers_id[6]]
weight_name = layer.attrs["value"][8:-2]
layer = subgraph_global_layers[subgraph_global_layers_id[8]]
layer = matches[layers_id[8]]
bias_name = layer.attrs["value"][8:-2]
attrs = {}
attrs["input_dim"] = graph.parameters[weight_name].shape[1]
attrs["output_dim"] = graph.parameters[weight_name].shape[0]
attrs["input_dim"] = parameters[weight_name].shape[1]
attrs["output_dim"] = parameters[weight_name].shape[0]
linear_name = "linear{}".format(self.linear_index)
self.linear_index += 1
graph.parameters["{}.weight".format(linear_name)] = graph.parameters[
parameters["{}.weight".format(linear_name)] = parameters[
weight_name].transpose((1, 0))
graph.parameters["{}.bias".format(linear_name)] = np.squeeze(
graph.parameters[bias_name])
graph.parameters.pop(weight_name)
graph.parameters.pop(bias_name)
for i, layer_id in enumerate(subgraph_global_layers):
if layer_id in graph.layers:
layer = graph.layers[layer_id]
if i == 0:
parameters["{}.bias".format(linear_name)] = np.squeeze(parameters[
bias_name])
new_layer = PaddleLayer(
layer_id,
layers_id[0],
"fluid.dygraph.Linear",
inputs={"input": input_name},
outputs=[linear_name, output_name],
**attrs)
graph.layers[layer_id] = new_layer
else:
graph.layers.pop(layer_id)
graph.build()
return graph
return new_layer
......@@ -12,38 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.linear_pass import LinearPass, LinearMatcher
from x2paddle.optimizer.fusion import *
from x2paddle.optimizer.pass_manager import PassManager
class GraphOptimizer(object):
def __init__(self):
linear_pass = LinearPass()
linear_matcher = LinearMatcher()
self.passes = {linear_pass: linear_matcher}
def run(self, graph):
is_update_graph = False
while True:
for i, (layer_id, layer) in enumerate(graph.layers.items()):
is_match = self.current_matcher.match_pattern(
self.current_pass.pattern, graph, i)
if is_match:
is_update_graph = True
graph = self.current_matcher.replace_layer(graph, is_match)
break
for j, block in enumerate(layer.blocks):
if len(block.layers) > 0:
layer.blocks[j], is_update_block = self.run(block)
if is_update_block:
break
if i + 1 == len(graph.layers):
return graph, is_update_graph
self.passes = ["fc_fuse_pass"]
def optimize(self, graph):
# 开始优化
for _pass, matcher in self.passes.items():
self.current_pass = _pass
self.current_matcher = matcher
graph, _ = self.run(graph)
print("{} done!".format(_pass.__class__.__name__))
for pass_name in self.passes:
pass_ = PassManager.lookup(pass_name)()
pass_.apply(graph)
print("{} done!".format(pass_name))
return graph
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
class Kind(Enum):
Program = 1
Code = 2
class Pass(object):
name = "pass"
def __init__(self, kind):
self.kind = kind
def apply(self, graph):
raise NotImplementedError("The apply function must be implemented!")
@classmethod
def get_name(cls):
return cls.name
class ProgramPass(Pass):
def __init__(self):
super(ProgramPass, self).__init__(Kind.Program)
class CodePass(Pass):
def __init__(self):
super(CodePass, self).__init__(Kind.Code)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class PassManager(object):
""" pass管理器。
"""
# pass_map存储name与其对应的pass
pass_map = dict()
def __init__(self):
pass
@staticmethod
def add_new_pass(name, pass_):
if name not in PassManager.pass_map:
PassManager.pass_map[name] = pass_
@staticmethod
def clear():
PassManager.passes = list()
@staticmethod
def lookup(name):
return PassManager.pass_map[name]
def pass_register(cls):
name = cls.get_name()
PassManager.add_new_pass(name, cls)
return cls
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.program import PaddleGraph
class Pass(object):
def __init__(self):
self.pattern = PaddleGraph()
self.build_pattern()
class Matcher(object):
def __init__(self):
pass
class PyTorchMatcher(Matcher):
def __init__(self):
super(PyTorchMatcher, self).__init__()
def match_pattern(self, pattern, graph, start_index):
pattern_index = 0
pattern_global_layers = pattern.get_global_layers()
subgraph_global_layers = dict()
graph_layers = dict(list(graph.layers.items())[start_index:])
for layer_id, layer in graph_layers.items():
pattern_layer = pattern.layers[list(pattern.layers.keys())[
pattern_index]]
if layer.kernel == pattern_layer.kernel:
subgraph_global_layers[layer_id] = layer
pattern_layer_id = pattern_layer.id
if layer.kernel == "prim.constant":
if layer.attrs["value"] != pattern_layer.attrs["value"]:
return False
elif layer.kernel == "fluid.layers.addmm":
if layer.attrs["beta"] != pattern_layer.attrs["beta"]:
return False
if layer.attrs["alpha"] != pattern_layer.attrs["alpha"]:
return False
if layer_id in graph.edges_in:
if pattern_layer_id not in pattern.edges_in:
return False
else:
if len(graph.edges_in[layer_id]) != len(
pattern.edges_in[pattern_layer_id]):
return False
layer_in = graph.edges_in[layer_id]
pattern_layer_in = pattern.edges_in[pattern_layer_id]
for i in range(len(layer_in)):
layer_id_in = layer_in[i]
pattern_layer_id_in = pattern_layer_in[i]
if pattern_layer_id_in != -1:
pattern_global_layers_id = list(
pattern_global_layers.keys())
subgraph_global_layers_id = list(
subgraph_global_layers.keys())
if pattern_global_layers_id.index(pattern_layer_id_in) == \
subgraph_global_layers_id.index(layer_id_in):
# 判断pattern输入在pattern_global_layers_id的索引
# 和graph输入在subgraph_global_layers_id的索引一致
continue
return False
if layer_id in graph.edges_out:
if pattern_layer_id not in pattern.edges_out:
if not set(pattern_layer.outputs).issubset(
pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return False
else:
if len(graph.edges_out[layer_id]) != len(
pattern.edges_out[pattern_layer_id]):
# 如果在每个节点edges_in相同的情况下,edges_out数目相同则说明无节点在subgraph外被用到
if not set(pattern_layer.outputs).issubset(
pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return False
if layer.kernel == "prim.if":
res = self.match_pattern(pattern_layer.blocks[0],
layer.blocks[0], 0)
if res:
subgraph_global_layers.update(res)
else:
return False
res = self.match_pattern(pattern_layer.blocks[1],
layer.blocks[1], 0)
if res:
subgraph_global_layers.update(res)
else:
return False
pattern_index += 1
if pattern_index == len(pattern.layers):
return subgraph_global_layers
else:
return False
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.program import PaddleGraph
class PatternMatcher(object):
def __init__(self, pattern):
self.pattern = pattern
self.subgraphs = list()
def operate(self, graph):
self.detect_patterns(graph)
self.remove_overlapped_match()
return self.subgraphs
def detect_patterns(self, graph):
""" 找到与模式匹配的子图,
并将子图的id以拓扑排序存放到subgraph_id2layers。
"""
def get_subgraph(pattern, graph, start_index):
pattern_index = 0
pattern_id2layers = pattern.get_global_layers()
pattern_ids = list(pattern_id2layers.keys())
subgraph_id2layers = dict()
graph_layers = dict(list(graph.layers.items())[start_index:])
for layer_id, layer in graph_layers.items():
pattern_layer = pattern.layers[list(pattern.layers.keys())[
pattern_index]]
if layer.kernel == pattern_layer.kernel:
subgraph_id2layers[layer_id] = layer
pattern_layer_id = pattern_layer.id
# 判断输入连接是否一致
if layer_id in graph.edges_in:
if pattern_layer_id not in pattern.edges_in:
return False
else:
if len(graph.edges_in[layer_id]) != len(
pattern.edges_in[pattern_layer_id]):
return False
layer_in = graph.edges_in[layer_id]
pattern_layer_in = pattern.edges_in[pattern_layer_id]
for i in range(len(layer_in)):
layer_id_in = layer_in[i]
pattern_layer_id_in = pattern_layer_in[i]
if pattern_layer_id_in != -1:
subgraph_ids = list(subgraph_id2layers.keys())
if pattern_ids.index(pattern_layer_id_in) == \
subgraph_ids.index(layer_id_in):
# 判断pattern输入在pattern_ids的索引
# 和graph输入在subgraph_ids的索引一致
continue
return False
# 判断subgraph中的节点是否被外部图使用到(如若被使用到则无效)
if layer_id in graph.edges_out:
if pattern_layer_id not in pattern.edges_out:
if not set(pattern_layer.outputs).issubset(
pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return False
else:
if len(graph.edges_out[layer_id]) != len(
pattern.edges_out[pattern_layer_id]):
# 如果在每个节点edges_in相同的情况下,edges_out数目相同则说明无节点在subgraph外被用到
if not set(pattern_layer.outputs).issubset(
pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return False
# 当为控制流时的处理
if layer.kernel == "prim.if":
match_info = get_subgraph(pattern_layer.blocks[0],
layer.blocks[0], 0)
if match_info:
subgraph_id2layers.update(match_info)
else:
return False
match_info = get_subgraph(pattern_layer.blocks[1],
layer.blocks[1], 0)
if match_info:
subgraph_id2layers.update(match_info)
else:
return False
pattern_index += 1
if pattern_index == len(pattern.layers):
return subgraph_id2layers
else:
return False
for i, (layer_id, layer) in enumerate(graph.layers.items()):
match_info = get_subgraph(self.pattern, graph, i)
if match_info:
self.subgraphs.append(match_info)
for j, block in enumerate(layer.blocks):
if len(block.layers) > 0:
self.detect_patterns(layer.blocks[j])
def remove_overlapped_match(self):
""" 如果2个子图有重叠,只取前一个子图。
"""
match_ids = []
for i, subgraph in enumerate(self.subgraphs):
is_overlapped = False
for id in subgraph.keys():
if id in match_ids:
self.subgraphs.pop(i)
is_overlapped = True
break
if not is_overlapped:
match_ids.extend(list(subgraph.keys()))
class FuseBase(object):
def __init__(self):
self.pattern = PaddleGraph()
def operate(self, graph):
self.build_pattern()
self.perform_pattern_matcher(graph)
for subgraph in self.subgraphs:
self.insert_new_layer(graph, subgraph)
self.delete_inter_layer(graph)
graph.build()
def perform_pattern_matcher(self, graph):
""" 执行模式匹配,找到匹配的子图。
"""
pattern_matcher = PatternMatcher(self.pattern)
self.subgraphs = pattern_matcher.operate(graph)
def delete_inter_layer(self, graph):
""" 删除不需要的中间layer及其对应参数。
"""
for subgraph in self.subgraphs:
for layer_id, layer in subgraph.items():
if layer.kernel == "fluid.dygraph.base.to_variable" and \
layer.attrs["value"].startswith("params["):
param_name = layer.attrs["value"][8:-2]
if param_name in graph.parameters:
graph.parameters.pop(param_name)
if layer_id in graph.layers:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图
graph.layers.pop(layer_id)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册