提交 34d61263 编写于 作者: S SunAhong1993

add pytorch

上级 5b6f11fd
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import torch
class PyTorchDecoder(object):
def __init__(self, script_path):
self.script = torch.jit.load(script_path)
self.graph = self._optimize_graph(self.script.inlined_graph)
def _optimize_graph(self, graph):
torch._C._jit_pass_constant_propagation(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_peephole(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
graph = torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
return graph
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from x2paddle.core.util import *
def prim_Constant(mapper, graph, node):
""" 构造constant的PaddleLayer,该节点实现常量赋值。
PyTorch Script 示例:
%2 : int = prim::Constant[value=-1]()
参数含义:
%2 (常量类型由赋值类型定义,该示例中为int型): 常量赋值结果输出。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
output = list(node.outputs())[0]
value = output.toIValue()
mapper.attrs[output_name] = value
if isinstance(value, str):
value = string(value)
graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], value=value)
return [], node_outputs
def prim_GetAttr(mapper, graph, node):
""" 获取attribute信息。
PyTorch Script 示例:
%27 : Tensor? = prim::GetAttr[name="bias"](%7)
参数含义:
%7 (Tensor): 输入Tensor。
%27 (Tensor): 输入Tensor。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
field_name_list = [node.s('name')]
while True:
input_node = list(node.inputs())[0].node()
try:
field_name_list.insert(0, input_node.s('name'))
node = input_node
except Exception:
break
part_script = mapper.script
for field_name in field_name_list:
if hasattr(part_script, field_name):
param = getattr(part_script, field_name)
if isinstance(param, torch.Tensor):
param = param.detach().numpy()
mapper.pytorch_params[output_name] = param
part_script = param
return [], node_outputs
def prim_ListConstruct(mapper, graph, node):
""" 构造list的PaddleLayer。
PyTorch Script 示例:
%86 : int[] = prim::ListConstruct(%84, %85)
参数含义:
%84 (int/其他): list第一个元素信息。
%85 (int/其他): list第二个元素信息。
%86 (list): list节点输出。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
inputs = {}
for i, input_ivalue in enumerate(node.inputs()):
input_node = input_ivalue.node()
script_input_unique_id = input_ivalue.unique()
input_node_name = mapper.outputs_info[script_input_unique_id]
inputs['input{}'.format(i)] = input_node_name
graph.add_layer("prim.list", inputs=inputs, outputs=[output_name])
return list(inputs.values()), node_outputs
def prim_RaiseException(mapper, graph, node):
""" 构造抛出异常的PaddleLayer。
PyTorch Script 示例:
= prim::RaiseException(%76)
参数含义:
%76 (str): 异常信息。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, node_outputs)
graph.add_layer(
"prim.exception",
inputs={'input': input_node_name},
outputs=[output_name])
return [input_node_name], node_outputs
def prim_Loop(mapper, graph, node):
""" 构造loop循环的PaddleLayer。
PyTorch Script 示例:
%x : Tensor = prim::Loop(%4, %3, %x.3)
block0(%i : int, %x.12 : Tensor):
%72 : int[] = prim::Constant[value=[6, 6]]()
...
%x.5 : Tensor = aten::adaptive_avg_pool2d(%x.12, %_output_size.1)
-> (%3, %x.5)
参数含义:
%4 (int): 循环次数。
%3 (bool): 是否进入退出。
%x.3 (Tensor): 循环中修改的Tensor。
%x (Tensor): loop循环的输出,与%x.5对应。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
loop_inputs = {}
block = list(node.blocks())[0]
loop_outputs = [output_name]
for i, block_input_ivalue in enumerate(block.inputs()):
block_input_node_name = 'x' + str(mapper.output_index)
unique_id = block_input_ivalue.unique()
if unique_id not in mapper.outputs_info:
mapper.outputs_info[unique_id] = block_input_node_name
mapper.output_index += 1
if i == 0:
loop_input_node = list(node.inputs())[0].node()
script_loop_input_unique_id = list(node.inputs())[0].unique()
loop_input_node_name = mapper.outputs_info[
script_loop_input_unique_id]
mapper._check_input(graph, loop_input_node, loop_input_node_name,
node_outputs)
loop_inputs['input'] = loop_input_node_name
loop_outputs.append(block_input_node_name)
node_outputs.append(block_input_node_name)
else:
loop_input_node = list(node.inputs())[i + 1].node()
script_loop_input_unique_id = list(node.inputs())[i + 1].unique()
loop_input_node_name = mapper.outputs_info[
script_loop_input_unique_id]
mapper._check_input(graph, loop_input_node, loop_input_node_name,
node_outputs)
graph.add_layer(
"prim.equal",
inputs={'input': loop_input_node_name},
outputs=[block_input_node_name])
node_outputs.append(block_input_node_name)
graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs)
current_layer = list(graph.layers.values())[-1]
block_graph, graph_inputs = mapper.traverse(block, node, current_layer)
for i, input_name in enumerate(graph_inputs):
if input_name == loop_outputs[1]:
continue
current_layer.inputs['input-{}'.format(i)] = input_name
current_layer.add_block(block_graph)
return list(current_layer.inputs.values()), node_outputs
def prim_If(mapper, graph, node):
""" 构造if控制流的PaddleLayer。
PyTorch Script 示例:
%input.5 : Tensor = prim::If(%107)
block0():
%109 : Tensor = aten::t(%102)
%ret.2 : Tensor = aten::addmm(%103, %101, %109, %104, %104)
-> (%ret.2)
block1():
%111 : Tensor = aten::t(%102)
...
-> (%output.4)
参数含义:
%107 (bool): if判断条件。
%input.5 (Tensor): if控制流的输出,与%output.4对应。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, node_outputs)
graph.add_layer("prim.if", {'input': input_node_name}, [output_name])
current_layer = list(graph.layers.values())[-1]
block0 = list(node.blocks())[0]
block0_graph, graph_inputs0 = mapper.traverse(block0, node, current_layer)
len0 = 0
for i, input_name in enumerate(graph_inputs0):
current_layer.inputs['input-{}'.format(i)] = input_name
len0 = i
current_layer.add_block(block0_graph)
block1 = list(node.blocks())[1]
block1_graph, graph_inputs1 = mapper.traverse(block1, node, current_layer)
for i, input_name in enumerate(graph_inputs1):
current_layer.inputs['input-{}'.format(len0 + 1 + i)] = input_name
current_layer.add_block(block1_graph)
return list(current_layer.inputs.values()), node_outputs
def prim_min(mapper, graph, node):
""" 构造min的PaddleLayer。
PyTorch Script 示例:
%87 : int = prim::min(%86)
参数含义:
%86 (list): 输入。
%87 (int): 输出。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, node_outputs)
graph.add_layer(
"prim.min", inputs={'input': input_node_name}, outputs=[output_name])
return [input_node_name], node_outputs
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
from x2paddle.core.paddle_graph import PaddleGraph
from x2paddle.op_mapper.pytorch2paddle import prim
from x2paddle.op_mapper.pytorch2paddle import aten
class PyTorchOpMapper(OpMapper):
def __init__(self, decoder):
super(PyTorchOpMapper, self).__init__()
self.script = decoder.script
self.paddle_params = dict()
self.outputs_info = {} # key为output unique id,value为当前节点的输出名字
self.pytorch_params = {} # key为节点名,value为参数
self.attrs = {} # key为节点名,value为属性值
self.output_index = 0
self.dygraph_name_id = {} # 动态图__init__输出名字中的id,key为kernel类型,value为id
# 转换
self.graph, _ = self.traverse(decoder.graph)
def traverse(self, script_graph, control_node=None, father_layer=None):
# 用于获取graph的输入
def _update_graph_inputs(inputs, outputs):
current_node_outputs.extend(outputs)
for name in inputs:
if name not in current_node_outputs:
graph_inputs.append(name)
# 初始化
graph = PaddleGraph(father_layer)
current_node_outputs = []
graph_inputs = []
# 转换输入节点
if isinstance(script_graph, torch._C.Graph):
for i, ivalue in enumerate(script_graph.inputs()):
node = ivalue.node()
if str(ivalue.type()) != "Tensor":
graph.set_name(str(ivalue.type()).split(".")[-1])
inputs, outputs = self.data(graph, node, ivalue.unique())
# 转换中间节点
for node in script_graph.nodes():
kind = node.kind()
func_name = kind.replace('::', '_')
if hasattr(prim, func_name):
func = getattr(prim, func_name)
inputs, outputs = func(self, graph, node)
_update_graph_inputs(inputs, outputs)
elif hasattr(aten, func_name):
func = getattr(aten, func_name)
inputs, outputs = func(self, graph, node)
_update_graph_inputs(inputs, outputs)
else:
raise Exception("The kind {} in model is not supported yet.".
format(node.kind()))
# 转换输出节点
if hasattr(script_graph, 'returnNode'):
for i, ivalue in enumerate(script_graph.returnNode().inputs()):
if control_node.kind() == "prim::Loop" and i == 0:
continue
node = ivalue.node()
script_unique_id = ivalue.unique()
inputs, outputs = self.equal(
graph,
node,
uid=script_unique_id,
control_node=control_node,
index=i)
_update_graph_inputs(inputs, outputs)
# 设置graph的参数
if isinstance(script_graph, torch._C.Graph):
graph.set_parameters(self.paddle_params)
return graph, graph_inputs
def _get_outputs_name(self, node):
outputs_name = []
for output_ivalue in node.outputs():
output_name = 'x' + str(self.output_index)
script_unique_id = output_ivalue.unique()
if script_unique_id in self.outputs_info:
output_name = self.outputs_info[script_unique_id]
self.outputs_info[script_unique_id] = output_name
self.output_index += 1
outputs_name.append(output_name)
# if节点没有输出的情况
if len(list(node.outputs())) == 0:
output_name = 'x' + str(self.output_index)
self.output_index += 1
outputs_name.append(output_name)
return outputs_name
def _check_input(self,
graph,
node,
output_name,
node_outputs,
add_dim=False):
if node.kind() == "prim::GetAttr":
param = self.pytorch_params[output_name]
if isinstance(param, np.ndarray):
if add_dim:
param = param[np.newaxis, :]
self.paddle_params[output_name] = param
graph.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
outputs=[output_name],
value="params[{}]".format(string(output_name)))
else:
graph.add_layer(
"prim.constant",
inputs={},
outputs=[output_name],
value=string(param) if isinstance(param, str) else param)
node_outputs.append(output_name)
def data(self, graph, node, uid):
for output_ivalue in node.outputs():
script_unique_id = output_ivalue.unique()
if script_unique_id in self.outputs_info or script_unique_id != uid:
continue
node_name = 'x' + str(self.output_index)
self.outputs_info[script_unique_id] = node_name
self.output_index += 1
output_name = self.outputs_info[uid]
graph.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
outputs=[node_name],
value=output_name)
return [], [output_name]
def equal(self, graph, node, uid=None, control_node=None, index=None):
if control_node is not None and index is not None:
kind = control_node.kind()
# block的输出
input_node_name = self.outputs_info[uid]
control_output_id = index
if kind == "prim::Loop":
control_output_id = index - 1
output_ivalue = list(control_node.outputs())[
control_output_id].unique()
output_node_name = self.outputs_info[output_ivalue]
graph.add_layer(
"prim.equal",
inputs={'input': input_node_name},
outputs=[output_node_name])
return [input_node_name], [output_node_name]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from x2paddle.core.util import *
from x2paddle.core.paddle_graph import PaddleLayer, PaddleGraph
from x2paddle.optimizer.passes import Pass, Matcher, PyTorchMatcher
class LinearPass(Pass):
def __init__(self):
self.linear_index = 0
super(LinearPass, self).__init__()
def build_pattern(self):
""" 构造fc层的模式。
fc层模式python实现代码示例:
x149 = 2
x151 = x146.shape
x151 = len(x151)
x152 = x151 == x149
if x152 :
x147 = self.x147
x154 = fluid.layers.transpose(x=x147, perm=[1, 0])
x148 = self.x148
x155 = fluid.layers.addmm(input=x148, x=x146, y=x154, beta=1, alpha=1)
x153 = x155
else:
x147 = self.x147
x157 = fluid.layers.transpose(x=x147, perm=[1, 0])
x158 = fluid.layers.matmul(x=x146, y=x157)
x159 = True
if x159 :
x148 = self.x148
x161 = x158 + 1 * x148
x160 = x161
else:
x160 = x158
x153 = x160
"""
def gen_name(id):
return "x" + str(id)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(0)], value=2)
self.pattern.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(1)], value=1)
self.pattern.add_layer(
"prim.shape", inputs={'input': "fc-input-0"},
outputs=[gen_name(2)])
self.pattern.add_layer(
"prim.len", inputs={'input': gen_name(2)}, outputs=[gen_name(2)])
self.pattern.add_layer(
"prim.eq",
inputs={"eq0": gen_name(2),
"eq1": gen_name(0)},
outputs=[gen_name(3)])
self.pattern.add_layer("prim.if", {'input': gen_name(3)}, [gen_name(4)])
self.pattern.outputs.append(gen_name(4))
if_layer_a = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
pattern_block0 = PaddleGraph(if_layer_a)
pattern_block0.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
outputs=[gen_name(5)],
value="params[{}]".format(string(gen_name(5))))
pattern_block0.add_layer(
"fluid.layers.transpose",
inputs={"x": gen_name(5)},
outputs=[gen_name(6)],
perm=[1, 0])
pattern_block0.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
outputs=[gen_name(7)],
value="params[{}]".format(string(gen_name(7))))
pattern_block0.add_layer(
"fluid.layers.addmm",
inputs={"input": gen_name(7),
"x": "fc-input-0",
"y": gen_name(6)},
outputs=[gen_name(8)],
beta=1,
alpha=1)
if_layer_a.inputs["input-0"] = "fc-input-0"
self.pattern.inputs.append("fc-input-0")
pattern_block0.add_layer(
"prim.equal", inputs={'input': gen_name(8)}, outputs=[gen_name(4)])
if_layer_a.add_block(pattern_block0)
pattern_block1 = PaddleGraph(if_layer_a)
pattern_block1.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
outputs=[gen_name(5)],
value="params[{}]".format(string(gen_name(5))))
pattern_block1.add_layer(
"fluid.layers.transpose",
inputs={"x": gen_name(5)},
outputs=[gen_name(6)],
perm=[1, 0])
pattern_block1.add_layer(
"fluid.layers.matmul",
inputs={"x": "fc-input-0",
"y": gen_name(6)},
outputs=[gen_name(9)])
if_layer_a.inputs["input-1"] = "fc-input-0"
pattern_block1.add_layer(
"prim.constant", inputs={}, outputs=[gen_name(10)], value=True)
pattern_block1.add_layer("prim.if", {'input': gen_name(10)},
[gen_name(11)])
if_layer_b = pattern_block1.layers[list(pattern_block1.layers.keys())[
-1]]
pattern_block1_block0 = PaddleGraph(if_layer_b)
pattern_block1_block0.add_layer(
"fluid.dygraph.base.to_variable",
inputs={},
outputs=[gen_name(12)],
value="params[{}]".format(string(gen_name(12))))
pattern_block1_block0.add_layer(
"prim.add",
inputs={"x": gen_name(9),
"y": gen_name(12)},
outputs=[gen_name(13)],
alpha=1)
if_layer_b.inputs["input-0"] = gen_name(9)
pattern_block1_block0.add_layer(
"prim.equal",
inputs={'input': gen_name(13)},
outputs=[gen_name(11)])
if_layer_b.add_block(pattern_block1_block0)
pattern_block1_block1 = PaddleGraph(if_layer_b)
pattern_block1_block1.add_layer(
"prim.equal", inputs={'input': gen_name(9)},
outputs=[gen_name(11)])
if_layer_b.inputs["input-1"] = gen_name(9)
pattern_block1.add_layer(
"prim.equal", inputs={'input': gen_name(11)},
outputs=[gen_name(4)])
if_layer_b.add_block(pattern_block1_block1)
if_layer_a.add_block(pattern_block1)
self.pattern.build(
inputs={"input-0": "fc-input-0",
"input-1": "fc-input-0"})
class LinearMatcher(PyTorchMatcher):
def __init__(self):
self.linear_index = 0
super(LinearMatcher, self).__init__()
def replace_layer(self, graph, subgraph_global_layers):
subgraph_global_layers_id = list(subgraph_global_layers.keys())
layer = subgraph_global_layers[subgraph_global_layers_id[2]]
input_name = layer.inputs["input"]
layer = subgraph_global_layers[subgraph_global_layers_id[5]]
output_name = layer.outputs[0]
layer = subgraph_global_layers[subgraph_global_layers_id[6]]
weight_name = layer.attrs["value"][8:-2]
layer = subgraph_global_layers[subgraph_global_layers_id[8]]
bias_name = layer.attrs["value"][8:-2]
attrs = {}
attrs["input_dim"] = graph.parameters[weight_name].shape[1]
attrs["output_dim"] = graph.parameters[weight_name].shape[0]
linear_name = "linear{}".format(self.linear_index)
self.linear_index += 1
graph.parameters["{}.weight".format(linear_name)] = graph.parameters[
weight_name].transpose((1, 0))
graph.parameters["{}.bias".format(linear_name)] = np.squeeze(
graph.parameters[bias_name])
graph.parameters.pop(weight_name)
graph.parameters.pop(bias_name)
for i, layer_id in enumerate(subgraph_global_layers):
if layer_id in graph.layers:
layer = graph.layers[layer_id]
if i == 0:
new_layer = PaddleLayer(
layer_id,
"fluid.dygraph.Linear",
inputs={"input": input_name},
outputs=[linear_name, output_name],
**attrs)
graph.layers[layer_id] = new_layer
else:
graph.layers.pop(layer_id)
graph.build()
return graph
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.linear_pass import LinearPass, LinearMatcher
class GraphOptimizer(object):
def __init__(self):
linear_pass = LinearPass()
linear_matcher = LinearMatcher()
self.passes = {linear_pass: linear_matcher}
def run(self, graph):
is_update_graph = False
while True:
for i, (layer_id, layer) in enumerate(graph.layers.items()):
is_match = self.current_matcher.match_pattern(
self.current_pass.pattern, graph, i)
if is_match:
is_update_graph = True
graph = self.current_matcher.replace_layer(graph, is_match)
break
for j, block in enumerate(layer.blocks):
if len(block.layers) > 0:
layer.blocks[j], is_update_block = self.run(block)
if is_update_block:
break
if i + 1 == len(graph.layers):
return graph, is_update_graph
def optimize(self, graph):
# 开始优化
for _pass, matcher in self.passes.items():
self.current_pass = _pass
self.current_matcher = matcher
graph, _ = self.run(graph)
print("{} done!".format(pa.__class__.__name__))
return graph
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.paddle_graph import PaddleGraph
class Pass(object):
def __init__(self):
self.pattern = PaddleGraph()
self.build_pattern()
class Matcher(object):
def __init__(self):
self.unique_id_layer = dict()
class PyTorchMatcher(Matcher):
def __init__(self):
super(PyTorchMatcher, self).__init__()
def match_pattern(self, pattern, graph, start_id):
pattern_index = 0
pattern_global_layers = pattern.get_global_layers()
subgraph_global_layers = dict()
graph_layers = dict(list(graph.layers.items())[start_id:])
for layer_id, layer in graph_layers.items():
pattern_layer = pattern.layers[list(pattern.layers.keys())[
pattern_index]]
if layer.kernel == pattern_layer.kernel:
subgraph_global_layers[layer_id] = layer
pattern_layer_id = pattern_layer.id
if layer.kernel == "prim.constant":
if layer.attrs["value"] != pattern_layer.attrs["value"]:
return False
elif layer.kernel == "fluid.layers.addmm":
if layer.attrs["beta"] != pattern_layer.attrs["beta"]:
return False
if layer.attrs["alpha"] != pattern_layer.attrs["alpha"]:
return False
if layer_id in graph.edges_in:
if pattern_layer_id not in pattern.edges_in:
return False
else:
if len(graph.edges_in[layer_id]) != len(
pattern.edges_in[pattern_layer_id]):
return False
layer_in = graph.edges_in[layer_id]
pattern_layer_in = pattern.edges_in[pattern_layer_id]
for i in range(len(layer_in)):
layer_id_in = layer_in[i]
pattern_layer_id_in = pattern_layer_in[i]
if pattern_layer_id_in != -1:
pattern_global_layers_id = list(
pattern_global_layers.keys())
subgraph_global_layers_id = list(
subgraph_global_layers.keys())
if pattern_global_layers_id.index(pattern_layer_id_in) == \
subgraph_global_layers_id.index(layer_id_in):
# 判断pattern输入在pattern_global_layers_id的索引
# 和graph输入在subgraph_global_layers_id的索引一致
continue
return False
if layer_id in graph.edges_out:
if pattern_layer_id not in pattern.edges_out:
if not set(pattern_layer.outputs).issubset(
pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return False
else:
if len(graph.edges_out[layer_id]) != len(
pattern.edges_out[pattern_layer_id]):
# 如果在每个节点edges_in相同的情况下,edges_out数目相同则说明无节点在subgraph外被用到
if not set(pattern_layer.outputs).issubset(
pattern.outputs):
# 若pattern当前layer的输出是pattern的输出,则是正确的
return False
if layer.kernel == "prim.if":
res = self.match_pattern(pattern_layer.blocks[0],
layer.blocks[0], 0)
if res:
subgraph_global_layers.update(res)
else:
return False
res = self.match_pattern(pattern_layer.blocks[1],
layer.blocks[1], 0)
if res:
subgraph_global_layers.update(res)
else:
return False
pattern_index += 1
if pattern_index == len(pattern.layers):
return subgraph_global_layers
else:
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册