提交 edd42662 编写于 作者: S SunAhong1993

modify the prim2code

上级 f76d0121
...@@ -357,7 +357,8 @@ class PaddleGraph(object): ...@@ -357,7 +357,8 @@ class PaddleGraph(object):
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get( if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0 and layer.kernel != "prim.assert": layer_id, 0) == 0 and layer.kernel != "prim.assert" \
and layer.kernel != "prim.exception":
continue continue
if "dygraph" in layer.kernel: if "dygraph" in layer.kernel:
line = "{}".format( line = "{}".format(
...@@ -396,9 +397,9 @@ class PaddleGraph(object): ...@@ -396,9 +397,9 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel: elif "prim" in layer.kernel:
func_name = layer.kernel.replace(".", "_") func_name = layer.kernel.replace(".", "_")
from . import convert_prim from x2paddle.op_mapper.pytorch2paddle import prim2code
if hasattr(convert_prim, func_name): if hasattr(prim2code, func_name):
func = getattr(convert_prim, func_name) func = getattr(prim2code, func_name)
func( func(
layer, layer,
indent=indent, indent=indent,
......
...@@ -532,6 +532,41 @@ def aten_dropout(mapper, graph, node): ...@@ -532,6 +532,41 @@ def aten_dropout(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_dropout_(mapper, graph, node):
""" 构造Dropout的PaddleLayer。
TorchScript示例:
%119 : Tensor = aten::dropout_(%result.3, %117, %118)
参数含义:
%119 (Tensor): Dropout后的Tensor。
%result.3 (Tensor): 输入Tensor。
%118 (bool): 是否是训练阶段。
"""
if "dropout" in mapper.dygraph_name_id:
mapper.dygraph_name_id["dropout"] += 1
else:
mapper.dygraph_name_id["dropout"] = 0
dropout_name = "dropout" + str(mapper.dygraph_name_id["dropout"])
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [dropout_name, output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%119
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values())
graph.add_layer(
"fluid.dygraph.Dropout",
inputs=layer_inputs,
outputs=layer_outputs,
p=0.0)
return current_inputs, current_outputs
def aten_eq(mapper, graph, node): def aten_eq(mapper, graph, node):
""" 构造判断数值是否相等的PaddleLayer。 """ 构造判断数值是否相等的PaddleLayer。
...@@ -994,6 +1029,34 @@ def aten___not__(mapper, graph, node): ...@@ -994,6 +1029,34 @@ def aten___not__(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_relu(mapper, graph, node):
""" 构造ReLU激活的PaddleLayer。
TorchScript示例:
%result.3 : Tensor = aten::relu(%input.5)
参数含义:
%result.3 (Tensor): 输出,ReLU后的结果。
%result.5 (Tensor): 需要ReLU的Tensor。
注意: inplace这个参数在paddle中未实现
"""
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer(
"fluid.layers.relu", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
def aten_relu_(mapper, graph, node): def aten_relu_(mapper, graph, node):
""" 构造ReLU激活的PaddleLayer。 """ 构造ReLU激活的PaddleLayer。
......
...@@ -24,26 +24,39 @@ def gen_codes(code_list, indent=0): ...@@ -24,26 +24,39 @@ def gen_codes(code_list, indent=0):
return codes return codes
def get_value(layer, key):
""" 进行optimizer后可能把inputs的value直接用数值代替(ConstantFuser),
会把input换成attr,所以需要此处的操作。
"""
if key in layer.inputs:
return layer.inputs[key]
else:
return str(layer.attrs[key])
def prim_add(layer, indent=1, init_func=[], forward_func=[]): def prim_add(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} + {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} + {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_add_(layer, indent=1, init_func=[], forward_func=[]): def prim_add_(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} + {} * {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} + {} * {}".format(layer.outputs[0],
layer.attrs["alpha"], layer.inputs["y"]) get_value(layer, "x"),
layer.attrs["alpha"],
get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_and(layer, indent=1, init_func=[], forward_func=[]): def prim_and(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} and {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} and {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_append(layer, indent=1, init_func=[], forward_func=[]): def prim_append(layer, indent=1, init_func=[], forward_func=[]):
line = "{}.append({})".format(layer.inputs["list"], layer.inputs["element"]) line = "{}.append({})".format(
get_value(layer, "list"), get_value(layer, "element"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
...@@ -72,23 +85,23 @@ def prim_constant(layer, indent=1, init_func=[], forward_func=[]): ...@@ -72,23 +85,23 @@ def prim_constant(layer, indent=1, init_func=[], forward_func=[]):
def prim_eq(layer, indent=1, init_func=[], forward_func=[]): def prim_eq(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} == {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} == {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_equal(layer, indent=1, init_func=[], forward_func=[]): def prim_equal(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}".format(layer.outputs[0], layer.inputs["input"]) line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_exception(layer, indent=1, init_func=[], forward_func=[]): def prim_exception(layer, indent=1, init_func=[], forward_func=[]):
line = "raise RaiseException({})".format(layer.inputs["input"]) line = "raise RaiseException({})".format(get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_if(layer, indent=1, init_func=[], forward_func=[]): def prim_if(layer, indent=1, init_func=[], forward_func=[]):
line = "if {} :".format(list(layer.inputs.values())[0]) line = "if {} :".format(get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
block = layer.blocks[0] block = layer.blocks[0]
b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1) b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1)
...@@ -105,45 +118,47 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[]): ...@@ -105,45 +118,47 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[]):
def prim_getitem(layer, indent=1, init_func=[], forward_func=[]): def prim_getitem(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}[{}]".format(layer.outputs[0], layer.inputs["list"], line = "{} = {}[{}]".format(layer.outputs[0],
layer.inputs["index"]) get_value(layer, "list"),
get_value(layer, "index"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_gt(layer, indent=1, init_func=[], forward_func=[]): def prim_gt(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} > {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} > {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_le(layer, indent=1, init_func=[], forward_func=[]): def prim_le(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} <= {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} <= {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_len(layer, indent=1, init_func=[], forward_func=[]): def prim_len(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = len({})".format(layer.outputs[0], layer.inputs["input"]) line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_lt(layer, indent=1, init_func=[], forward_func=[]): def prim_lt(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} < {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} < {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_list(layer, indent=1, init_func=[], forward_func=[]): def prim_list(layer, indent=1, init_func=[], forward_func=[]):
inputs_list = list(layer.inputs.values()) input_len = len(layer.inputs) + len(layer.attrs)
inputs_list = list()
for i in range(input_len):
inputs_list.append(get_value(layer, "input{}".format(i)))
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str) line = "{} = [{}]".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_loop(layer, indent=1, init_func=[], forward_func=[]): def prim_loop(layer, indent=1, init_func=[], forward_func=[]):
loop_range = list(layer.inputs.values())[0] loop_range = get_value(layer, "input")
if list(layer.inputs.values())[0] is None:
loop_range = str(layer.attrs[list(layer.inputs.keys())[0]])
line = "for {} in range({}):".format(layer.outputs[1], loop_range) line = "for {} in range({}):".format(layer.outputs[1], loop_range)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
block = layer.blocks[0] block = layer.blocks[0]
...@@ -153,66 +168,71 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[]): ...@@ -153,66 +168,71 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[]):
def prim_min(layer, indent=1, init_func=[], forward_func=[]): def prim_min(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = min({})".format(layer.outputs[0], layer.inputs["input"]) line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_mul(layer, indent=1, init_func=[], forward_func=[]): def prim_mul(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} * {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} * {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_ne(layer, indent=1, init_func=[], forward_func=[]): def prim_ne(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} < {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} < {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_neg(layer, indent=1, init_func=[], forward_func=[]): def prim_neg(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = -{}".format(layer.outputs[0], layer.inputs["input"]) line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_not(layer, indent=1, init_func=[], forward_func=[]): def prim_not(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = not {}".format(layer.outputs[0], layer.inputs["input"]) line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[]): def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = not {}.stop_gradient".format(layer.outputs[0], line = "{} = not {}.stop_gradient".format(layer.outputs[0],
layer.inputs["input"]) get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_select(layer, indent=1, init_func=[], forward_func=[]): def prim_select(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}[".format(layer.outputs[0], layer.inputs["input"]) line = "{} = {}[".format(layer.outputs[0], get_value(layer, "input"))
for dim in range(layer.attrs["dim"]): for dim in range(layer.attrs["dim"]):
line += ":, " line += ":, "
line += (layer.inputs["index"] + "]") line += (get_value(layer, "index") + "]")
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_shape(layer, indent=1, init_func=[], forward_func=[]): def prim_shape(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}.shape".format(layer.outputs[0], layer.inputs["input"]) line = "{} = {}.shape".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_slice(layer, indent=1, init_func=[], forward_func=[]): def prim_slice(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}[{}: {}: {}]".format( line = "{} = {}[{}: {}: {}]".format(layer.outputs[0],
layer.outputs[0], layer.inputs["input"], layer.inputs["start"], get_value(layer, "input"),
layer.inputs["end"], layer.inputs["step"]) get_value(layer, "start"),
get_value(layer, "end"),
get_value(layer, "step"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_sub(layer, indent=1, init_func=[], forward_func=[]): def prim_sub(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} - {}".format(layer.outputs[0], layer.inputs["x"], line = "{} = {} - {}".format(layer.outputs[0],
layer.inputs["y"]) get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_tuple(layer, indent=1, init_func=[], forward_func=[]): def prim_tuple(layer, indent=1, init_func=[], forward_func=[]):
inputs_list = list(layer.inputs.values()) input_len = len(layer.inputs) + len(layer.attrs)
inputs_list = list()
for i in range(input_len):
inputs_list.append(get_value(layer, "input{}".format(i)))
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = ({})".format(layer.outputs[0], inputs_str) line = "{} = ({})".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
...@@ -220,13 +240,13 @@ def prim_tuple(layer, indent=1, init_func=[], forward_func=[]): ...@@ -220,13 +240,13 @@ def prim_tuple(layer, indent=1, init_func=[], forward_func=[]):
def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[]): def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[]):
outputs_str = ', '.join(layer.outputs) outputs_str = ', '.join(layer.outputs)
line = "{} = {}".format(outputs_str, layer.inputs["input"]) line = "{} = {}".format(outputs_str, get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_warnings(layer, indent=1, init_func=[], forward_func=[]): def prim_warnings(layer, indent=1, init_func=[], forward_func=[]):
lines = ["import warnings"] lines = ["import warnings"]
line = "warnings.warn({}, stacklevel={})".format(layer.inputs["input"], line = "warnings.warn({}, stacklevel={})".format(
layer.attrs["stacklevel"]) get_value(layer, "input"), layer.attrs["stacklevel"])
lines.append(line) lines.append(line)
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
...@@ -14,3 +14,5 @@ ...@@ -14,3 +14,5 @@
from .fc_fuser import FcFuser from .fc_fuser import FcFuser
from .fc_fuse_pass import FcFusePass from .fc_fuse_pass import FcFusePass
from .constant_fuser import ConstantFuser
from .constant_fuse_pass import ConstantFusePass
...@@ -12,21 +12,21 @@ ...@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import ProgramPass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion import FcFuser from x2paddle.optimizer.fusion import FcFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class FcFusePass(ProgramPass): class FcFusePass(Pass):
name = "fc_fuse_pass" name = "fc_fuse_pass"
def __init__(self): def __init__(self):
ProgramPass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = FcFuser() fuser = FcFuser()
fuser.operate(graph) fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
......
...@@ -18,7 +18,7 @@ from x2paddle.optimizer.pass_manager import PassManager ...@@ -18,7 +18,7 @@ from x2paddle.optimizer.pass_manager import PassManager
class GraphOptimizer(object): class GraphOptimizer(object):
def __init__(self): def __init__(self):
self.passes = ["fc_fuse_pass"] self.passes = ["fc_fuse_pass", "constant_fuse_pass"]
def optimize(self, graph): def optimize(self, graph):
for pass_name in self.passes: for pass_name in self.passes:
......
...@@ -12,19 +12,10 @@ ...@@ -12,19 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum
class Kind(Enum):
Program = 1
Code = 2
class Pass(object): class Pass(object):
name = "pass" def __init__(self):
pass
def __init__(self, kind):
self.kind = kind
def apply(self, graph): def apply(self, graph):
raise NotImplementedError("The apply function must be implemented!") raise NotImplementedError("The apply function must be implemented!")
...@@ -32,13 +23,3 @@ class Pass(object): ...@@ -32,13 +23,3 @@ class Pass(object):
@classmethod @classmethod
def get_name(cls): def get_name(cls):
return cls.name return cls.name
class ProgramPass(Pass):
def __init__(self):
super(ProgramPass, self).__init__(Kind.Program)
class CodePass(Pass):
def __init__(self):
super(CodePass, self).__init__(Kind.Code)
...@@ -18,14 +18,18 @@ from x2paddle.core.program import PaddleGraph ...@@ -18,14 +18,18 @@ from x2paddle.core.program import PaddleGraph
class PatternMatcher(object): class PatternMatcher(object):
def __init__(self, pattern): def __init__(self, pattern):
self.pattern = pattern self.pattern = pattern
self.subgraphs = list() # matches的每个match是按照拓扑排序组成layer的dict
self.matches = list()
def operate(self, graph): def operate(self, graph, match_kind="topo"):
self.detect_patterns(graph) if match_kind == "topo":
self.detect_patterns_by_topo(graph)
elif match_kind == "edge":
self.detect_patterns_by_edge(graph)
self.remove_overlapped_match() self.remove_overlapped_match()
return self.subgraphs return self.matches
def detect_patterns(self, graph): def detect_patterns_by_topo(self, graph):
""" 找到与模式匹配的子图, """ 找到与模式匹配的子图,
并将子图的id以拓扑排序存放到subgraph_id2layers。 并将子图的id以拓扑排序存放到subgraph_id2layers。
""" """
...@@ -101,49 +105,79 @@ class PatternMatcher(object): ...@@ -101,49 +105,79 @@ class PatternMatcher(object):
for i, (layer_id, layer) in enumerate(graph.layers.items()): for i, (layer_id, layer) in enumerate(graph.layers.items()):
match_info = get_subgraph(self.pattern, graph, i) match_info = get_subgraph(self.pattern, graph, i)
if match_info: if match_info:
self.subgraphs.append(match_info) self.matches.append(match_info)
for j, block in enumerate(layer.blocks): for j, block in enumerate(layer.blocks):
if len(block.layers) > 0: if len(block.layers) > 0:
self.detect_patterns(layer.blocks[j]) self.detect_patterns_by_topo(layer.blocks[j])
def detect_patterns_by_edge(self, graph):
"""当遇见顺序没有强制规定的pattern时使用该方式
"""
pass
def remove_overlapped_match(self): def remove_overlapped_match(self):
""" 如果2个子图有重叠,只取前一个子图。 """ 如果2个子图有重叠,只取前一个子图。
""" """
match_ids = [] match_ids = []
for i, subgraph in enumerate(self.subgraphs): for i, match in enumerate(self.matches):
is_overlapped = False is_overlapped = False
for id in subgraph.keys(): for id in match.keys():
if id in match_ids: if id in match_ids:
self.subgraphs.pop(i) self.matches.pop(i)
is_overlapped = True is_overlapped = True
break break
if not is_overlapped: if not is_overlapped:
match_ids.extend(list(subgraph.keys())) match_ids.extend(list(match.keys()))
def get_subgraph(prefix_layer_id, suffix_layer_id, graph):
""" 根据prefix_layer_id和suffix_layer_id获取需要子图。
Args:
prefix_layer_id (str): 起初为一个空字符串,之后为suffix_layer_id分割出来的前缀。
suffix_layer_id (str): 起初为以一个layer的id,之后将分割部分给prefix_layer_id;例如”57.0.1“;
graph (x2paddle.core.program.PaddleGraph): 需要进行pass的子图。
"""
id_part = suffix_layer_id.split(".")
if len(id_part) == 1:
return graph
if prefix_layer_id == "":
layer_id = id_part[0]
prefix_layer_id += ".".join(id_part[:2])
else:
layer_id = prefix_layer_id + "." + id_part[0]
prefix_layer_id += ("." + ".".join(id_part[:2]))
subgraph = graph.layers[layer_id].blocks[int(id_part[1])]
suffix_layer_id = ".".join(id_part[2:])
return get_subgraph(prefix_layer_id, suffix_layer_id, subgraph)
class FuseBase(object): class FuseBase(object):
def __init__(self): def __init__(self):
self.pattern = PaddleGraph() self.pattern = PaddleGraph()
def operate(self, graph): def operate(self, graph, match_kind="topo"):
self.build_pattern() self.build_pattern()
self.perform_pattern_matcher(graph) self.perform_pattern_matcher(graph, match_kind)
for subgraph in self.subgraphs: for match in self.matches:
self.insert_new_layer(graph, subgraph) first_layer_id = list(match.keys())[0]
subgraph = get_subgraph("", first_layer_id, graph)
self.insert_new_layer(subgraph, match)
self.delete_inter_layer(graph) self.delete_inter_layer(graph)
graph.build() graph.build()
def perform_pattern_matcher(self, graph): def perform_pattern_matcher(self, graph, match_kind="topo"):
""" 执行模式匹配,找到匹配的子图。 """ 执行模式匹配,找到匹配的子图。
""" """
pattern_matcher = PatternMatcher(self.pattern) pattern_matcher = PatternMatcher(self.pattern)
self.subgraphs = pattern_matcher.operate(graph) self.matches = pattern_matcher.operate(graph, match_kind)
def delete_inter_layer(self, graph): def delete_inter_layer(self, graph):
""" 删除不需要的中间layer及其对应参数。 """ 删除不需要的中间layer及其对应参数。
""" """
for subgraph in self.subgraphs: for match in self.matches:
for layer_id, layer in subgraph.items(): first_layer_id = list(match.keys())[0]
subgraph = get_subgraph("", first_layer_id, graph)
for layer_id, layer in match.items():
if layer.kernel == "fluid.dygraph.base.to_variable" and \ if layer.kernel == "fluid.dygraph.base.to_variable" and \
layer.attrs["value"].startswith("params["): layer.attrs["value"].startswith("params["):
param_name = layer.attrs["value"][8:-2] param_name = layer.attrs["value"][8:-2]
...@@ -151,4 +185,4 @@ class FuseBase(object): ...@@ -151,4 +185,4 @@ class FuseBase(object):
graph.parameters.pop(param_name) graph.parameters.pop(param_name)
if layer_id in graph.layers: if layer_id in graph.layers:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图 # layer_id可能是属于子图的,此时删除父layer,即删除整个子图
graph.layers.pop(layer_id) subgraph.layers.pop(layer_id)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册