From 34d612630f047b7672061e5d25b4119fc8afb64b Mon Sep 17 00:00:00 2001
From: SunAhong1993
Date: Tue, 11 Aug 2020 17:15:51 +0800
Subject: [PATCH] add pytorch

---
 x2paddle/decoder/pytorch_decoder.py           |  35 ++
 x2paddle/op_mapper/pytorch2paddle/__init__.py |   0
 x2paddle/op_mapper/pytorch2paddle/aten.py     | 520 ++++++++++++++++++
 x2paddle/op_mapper/pytorch2paddle/prim.py     | 230 ++++++++
 .../pytorch2paddle/pytorch_op_mapper.py       | 163 ++++++
 x2paddle/optimizer/linear_pass.py             | 198 +++++++
 x2paddle/optimizer/optimizer.py               |  49 ++
 x2paddle/optimizer/passes.py                  | 109 ++++
 8 files changed, 1304 insertions(+)
 create mode 100644 x2paddle/decoder/pytorch_decoder.py
 create mode 100644 x2paddle/op_mapper/pytorch2paddle/__init__.py
 create mode 100644 x2paddle/op_mapper/pytorch2paddle/aten.py
 create mode 100644 x2paddle/op_mapper/pytorch2paddle/prim.py
 create mode 100644 x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py
 create mode 100644 x2paddle/optimizer/linear_pass.py
 create mode 100644 x2paddle/optimizer/optimizer.py
 create mode 100644 x2paddle/optimizer/passes.py

diff --git a/x2paddle/decoder/pytorch_decoder.py b/x2paddle/decoder/pytorch_decoder.py
new file mode 100644
index 0000000..60e4677
--- /dev/null
+++ b/x2paddle/decoder/pytorch_decoder.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import torch
+
+
+class PyTorchDecoder(object):
+    def __init__(self, script_path):
+        self.script = torch.jit.load(script_path)
+        self.graph = self._optimize_graph(self.script.inlined_graph)
+
+    def _optimize_graph(self, graph):
+        # Standard TorchScript clean-up passes: fold constants, drop dead
+        # code, simplify with peephole rewrites, then canonicalize.
+        torch._C._jit_pass_constant_propagation(graph)
+        torch._C._jit_pass_dce(graph)
+        torch._C._jit_pass_lint(graph)
+        torch._C._jit_pass_peephole(graph)
+        torch._C._jit_pass_lint(graph)
+        torch._C._jit_pass_dce(graph)
+        torch._C._jit_pass_lint(graph)
+        graph = torch._C._jit_pass_canonicalize(graph)
+        torch._C._jit_pass_lint(graph)
+        return graph
diff --git a/x2paddle/op_mapper/pytorch2paddle/__init__.py b/x2paddle/op_mapper/pytorch2paddle/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/x2paddle/op_mapper/pytorch2paddle/aten.py b/x2paddle/op_mapper/pytorch2paddle/aten.py
new file mode 100644
index 0000000..15f5657
--- /dev/null
+++ b/x2paddle/op_mapper/pytorch2paddle/aten.py
@@ -0,0 +1,520 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
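+
+# Editor's note: handlers in this module are looked up by name from the
+# TorchScript node kind, with "::" replaced by "_" (see
+# PyTorchOpMapper.traverse), so e.g. aten::adaptive_avg_pool2d dispatches to
+# aten_adaptive_avg_pool2d below. Every handler appends one or more layers to
+# the PaddleGraph and returns (input_names, output_names) so the caller can
+# work out which names must become inputs of the enclosing (sub)graph.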
+
+from x2paddle.core.util import *
+
+
+def aten_adaptive_avg_pool2d(mapper, graph, node):
+    """ Construct a PaddleLayer for adaptive average pool2d.
+
+    PyTorch Script example:
+        %x.5 : Tensor = aten::adaptive_avg_pool2d(%x.3, %_output_size.1)
+    Parameter meanings:
+        %x.5 (Tensor): the pooled output Tensor.
+        %x.3 (Tensor): the input Tensor.
+        %_output_size.1 (list): target height/width of the pooled Tensor.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    adapool2d_inputs = []
+    input_node = list(node.inputs())[0].node()
+    script_input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[script_input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    adapool2d_inputs.append(input_node_name)
+    attr_node = list(node.inputs())[1].node()
+    attr_unique_id = list(node.inputs())[1].unique()
+    attr_node_name = mapper.outputs_info[attr_unique_id]
+    attrs = {}
+    # Fold the output size into an attribute when it is a compile-time
+    # constant; otherwise keep it as a dynamic input.
+    attrs["pool_size"] = mapper.attrs[
+        attr_node_name] if attr_node_name in mapper.attrs else attr_node_name
+    if attr_node_name not in mapper.attrs:
+        adapool2d_inputs.append(attr_node_name)
+    attrs["pool_type"] = string("avg")
+    graph.add_layer(
+        "fluid.layers.adaptive_pool2d",
+        inputs={"input": input_node_name},
+        outputs=[output_name],
+        **attrs)
+    return adapool2d_inputs, node_outputs
+
+
+def aten_addmm(mapper, graph, node):
+    """ Construct a PaddleLayer for addmm, which computes
+        out = alpha * x * y + beta * input.
+
+    PyTorch Script example:
+        %ret.2 : Tensor = aten::addmm(%150, %input.3, %156, %151, %152)
+    Parameter meanings:
+        %ret.2 (Tensor): the addmm result Tensor.
+        %150 (Tensor): the input Tensor input.
+        %input.3 (Tensor): the input Tensor x.
+        %156 (Tensor): the input Tensor y.
+        %151 (int/float): the input beta.
+        %152 (int/float): the input alpha.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    inputs = {}
+    attrs = {}
+    addmm_inputs = []
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    script_input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[script_input_unique_id]
+    mapper._check_input(
+        graph, input_node, input_node_name, node_outputs, add_dim=True)
+    inputs['input'] = input_node_name
+    addmm_inputs.append(input_node_name)
+    x_node = list(node.inputs())[1].node()
+    x_unique_id = list(node.inputs())[1].unique()
+    x_node_name = mapper.outputs_info[x_unique_id]
+    mapper._check_input(graph, x_node, x_node_name, node_outputs)
+    inputs['x'] = x_node_name
+    addmm_inputs.append(x_node_name)
+    y_node = list(node.inputs())[2].node()
+    y_unique_id = list(node.inputs())[2].unique()
+    y_node_name = mapper.outputs_info[y_unique_id]
+    mapper._check_input(graph, y_node, y_node_name, node_outputs)
+    inputs['y'] = y_node_name
+    addmm_inputs.append(y_node_name)
+    beta_node = list(node.inputs())[3].node()
+    beta_unique_id = list(node.inputs())[3].unique()
+    beta_node_name = mapper.outputs_info[beta_unique_id]
+    attrs['beta'] = mapper.attrs[
+        beta_node_name] if beta_node_name in mapper.attrs else beta_node_name
+    if beta_node_name not in mapper.attrs:
+        addmm_inputs.append(beta_node_name)
+    alpha_node = list(node.inputs())[4].node()
+    alpha_unique_id = list(node.inputs())[4].unique()
+    alpha_node_name = mapper.outputs_info[alpha_unique_id]
+    attrs['alpha'] = mapper.attrs[
+        alpha_node_name] if alpha_node_name in mapper.attrs else alpha_node_name
+    if alpha_node_name not in mapper.attrs:
+        addmm_inputs.append(alpha_node_name)
+    graph.add_layer(
+        "fluid.layers.addmm", inputs=inputs, outputs=[output_name], **attrs)
+    return addmm_inputs, node_outputs
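+
+
+# Editor's note: an illustrative, numpy-only restatement of the addmm
+# semantics assumed above, out = beta * input + alpha * (x @ y). It is not
+# called anywhere; it exists purely as documentation of the formula.
+def _addmm_reference_sketch():
+    import numpy as np
+    inp = np.ones((2, 3))
+    x = np.arange(8.0).reshape(2, 4)
+    y = np.arange(12.0).reshape(4, 3)
+    beta, alpha = 1.0, 1.0
+    return beta * inp + alpha * x.dot(y)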
+
+
+def aten_add_(mapper, graph, node):
+    """ Construct a PaddleLayer for in-place add, which computes
+        out = x + alpha * y.
+
+    PyTorch Script example:
+        %output.5 : Tensor = aten::add_(%output.2, %150, %151)
+    Parameter meanings:
+        %output.5 (Tensor): the add result Tensor.
+        %output.2 (Tensor): the input Tensor x.
+        %150 (Tensor): the input Tensor y.
+        %151 (int/float): the input alpha.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    inputs = {}
+    attrs = {}
+    add_inputs = []
+    node_outputs = [output_name]
+    x_node = list(node.inputs())[0].node()
+    x_unique_id = list(node.inputs())[0].unique()
+    x_node_name = mapper.outputs_info[x_unique_id]
+    mapper._check_input(graph, x_node, x_node_name, node_outputs)
+    inputs['x'] = x_node_name
+    add_inputs.append(x_node_name)
+    y_node = list(node.inputs())[1].node()
+    y_unique_id = list(node.inputs())[1].unique()
+    y_node_name = mapper.outputs_info[y_unique_id]
+    mapper._check_input(graph, y_node, y_node_name, node_outputs, add_dim=True)
+    inputs['y'] = y_node_name
+    add_inputs.append(y_node_name)
+    alpha_node = list(node.inputs())[2].node()
+    alpha_unique_id = list(node.inputs())[2].unique()
+    alpha_node_name = mapper.outputs_info[alpha_unique_id]
+    attrs['alpha'] = mapper.attrs[
+        alpha_node_name] if alpha_node_name in mapper.attrs else alpha_node_name
+    if alpha_node_name not in mapper.attrs:
+        add_inputs.append(alpha_node_name)
+    graph.add_layer("prim.add", inputs=inputs, outputs=[output_name], **attrs)
+    return add_inputs, node_outputs
+
+
+def aten_append(mapper, graph, node):
+    # Appends an element to a list: input 0 is the list, input 1 the element.
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    inputs = {}
+    for i, input_ivalue in enumerate(node.inputs()):
+        input_node = input_ivalue.node()
+        input_unique_id = input_ivalue.unique()
+        input_node_name = mapper.outputs_info[input_unique_id]
+        mapper._check_input(graph, input_node, input_node_name, node_outputs)
+        if i == 0:
+            inputs['list'] = input_node_name
+        else:
+            inputs['element'] = input_node_name
+    graph.add_layer("prim.append", inputs=inputs, outputs=[output_name])
+    return list(inputs.values()), node_outputs
+
+
+def aten_conv2d(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    inputs = {}
+    attrs = {}
+    conv2d_inputs = []
+    node_outputs = [output_name]
+    if "conv" in mapper.dygraph_name_id:
+        mapper.dygraph_name_id["conv"] += 1
+    else:
+        mapper.dygraph_name_id["conv"] = 0
+    conv2d_name = "conv" + str(mapper.dygraph_name_id["conv"])
+    # input
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    inputs['input'] = input_node_name
+    conv2d_inputs.append(input_node_name)
+    # weight
+    weight_node = list(node.inputs())[1].node()
+    weight_unique_id = list(node.inputs())[1].unique()
+    weight_node_name = mapper.outputs_info[weight_unique_id]
+    weights = mapper.pytorch_params[weight_node_name]
+    mapper.paddle_params[conv2d_name + '.weight'] = weights
+    attrs['num_filters'] = weights.shape[0]
+    attrs['filter_size'] = weights.shape[2:]
+    # bias
+    bias_node = list(node.inputs())[2].node()
+    bias_unique_id = list(node.inputs())[2].unique()
+    bias_node_name = mapper.outputs_info[bias_unique_id]
+    if bias_node_name in mapper.pytorch_params:
+        bias = mapper.pytorch_params[bias_node_name]
+        mapper.paddle_params[conv2d_name + '.bias'] = bias
+    else:
+        mapper.paddle_params[conv2d_name + '.bias'] = False
+    # stride
+    stride_node = list(node.inputs())[3].node()
+    stride_unique_id = list(node.inputs())[3].unique()
+    stride_node_name = mapper.outputs_info[stride_unique_id]
+    attrs['stride'] = mapper.attrs[stride_node_name]
+    # padding
+    padding_node = list(node.inputs())[4].node()
+    padding_unique_id = list(node.inputs())[4].unique()
+    padding_node_name = mapper.outputs_info[padding_unique_id]
+    attrs['padding'] = mapper.attrs[padding_node_name]
+    # dilation
+    dilation_node = list(node.inputs())[5].node()
+    dilation_unique_id = list(node.inputs())[5].unique()
+    dilation_node_name = mapper.outputs_info[dilation_unique_id]
+    attrs['dilation'] = mapper.attrs[dilation_node_name]
+    # groups
+    groups_node = list(node.inputs())[6].node()
+    groups_unique_id = list(node.inputs())[6].unique()
+    groups_node_name = mapper.outputs_info[groups_unique_id]
+    attrs['groups'] = mapper.attrs[groups_node_name]
+    attrs['num_channels'] = weights.shape[1] * mapper.attrs[groups_node_name]
+    graph.add_layer(
+        "fluid.dygraph.Conv2D",
+        inputs=inputs,
+        outputs=[conv2d_name, output_name],
+        **attrs)
+    return conv2d_inputs, node_outputs
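+
+
+# Editor's note: a sketch of the shape bookkeeping done in aten_conv2d above.
+# For a grouped convolution the TorchScript weight has shape
+# (out_channels, in_channels // groups, kH, kW), which is why num_channels is
+# recovered as weights.shape[1] * groups. Not called anywhere; the numbers
+# are made up for illustration.
+def _conv2d_shape_sketch():
+    out_channels, in_channels, groups = 8, 4, 2
+    weight_shape = (out_channels, in_channels // groups, 3, 3)
+    num_filters = weight_shape[0]            # 8
+    filter_size = weight_shape[2:]           # (3, 3)
+    num_channels = weight_shape[1] * groups  # 4
+    return num_filters, filter_size, num_channels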
+
+
+def aten_dim(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    graph.add_layer(
+        "prim.shape", inputs={'input': input_node_name}, outputs=[output_name])
+    graph.add_layer(
+        "prim.len", inputs={'input': output_name}, outputs=[output_name])
+    return [input_node_name], node_outputs
+
+
+def aten_dropout(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    if "dropout" in mapper.dygraph_name_id:
+        mapper.dygraph_name_id["dropout"] += 1
+    else:
+        mapper.dygraph_name_id["dropout"] = 0
+    dropout_name = "dropout" + str(mapper.dygraph_name_id["dropout"])
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    # p=0.0: dropout is the identity at inference time.
+    graph.add_layer(
+        "fluid.dygraph.Dropout",
+        inputs={"input": input_node_name},
+        outputs=[dropout_name, output_name],
+        p=0.0)
+    return [input_node_name], node_outputs
+
+
+def aten_eq(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    inputs = {}
+    eq_inputs = []
+    for i, input_ivalue in enumerate(node.inputs()):
+        input_node = input_ivalue.node()
+        input_unique_id = input_ivalue.unique()
+        input_node_name = mapper.outputs_info[input_unique_id]
+        mapper._check_input(graph, input_node, input_node_name, node_outputs)
+        inputs['eq{}'.format(i)] = input_node_name
+        eq_inputs.append(input_node_name)
+    graph.add_layer("prim.eq", inputs=inputs, outputs=[output_name])
+    return list(inputs.values()), node_outputs
+
+
+def aten_flatten(mapper, graph, node):
+    # Only flattening from the first dimension is supported for now, i.e.
+    # start_dim must be 1 and end_dim must be -1; the prim.assert layers
+    # below enforce this at conversion time.
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    flatten_inputs = []
+    for i, input_ivalue in enumerate(node.inputs()):
+        if i == 0:
+            continue
+        input_node = input_ivalue.node()
+        input_unique_id = input_ivalue.unique()
+        input_node_name = mapper.outputs_info[input_unique_id]
+        mapper._check_input(graph, input_node, input_node_name, node_outputs)
+        graph.add_layer(
+            "prim.assert",
+            inputs={},
+            outputs=[output_name + '_assert'],
+            type='eq',
+            key=mapper.attrs[input_node_name],
+            value=1 if i == 1 else -1)
+        flatten_inputs.append(input_node_name)
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    graph.add_layer(
+        "fluid.layers.flatten",
+        inputs={'x': input_node_name},
+        outputs=[output_name],
+        axis=1)
+    flatten_inputs.append(input_node_name)
+    return flatten_inputs, node_outputs
+
+
+def aten___getitem__(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    inputs = {}
+    for i, input_ivalue in enumerate(node.inputs()):
+        input_node = input_ivalue.node()
+        input_unique_id = input_ivalue.unique()
+        input_node_name = mapper.outputs_info[input_unique_id]
+        mapper._check_input(graph, input_node, input_node_name, node_outputs)
+        if i == 0:
+            inputs['list'] = input_node_name
+        else:
+            inputs['index'] = input_node_name
+    graph.add_layer("prim.getitem", inputs=inputs, outputs=[output_name])
+    return list(inputs.values()), node_outputs
+
+
+def aten_le(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    inputs = {}
+    for i, input_ivalue in enumerate(node.inputs()):
+        input_node = input_ivalue.node()
+        input_unique_id = input_ivalue.unique()
+        input_node_name = mapper.outputs_info[input_unique_id]
+        mapper._check_input(graph, input_node, input_node_name, node_outputs)
+        inputs['input{}'.format(i)] = input_node_name
+    graph.add_layer("prim.le", inputs=inputs, outputs=[output_name])
+    return list(inputs.values()), node_outputs
+
+
+def aten_len(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    graph.add_layer(
+        "prim.len", inputs={'input': input_node_name}, outputs=[output_name])
+    return [input_node_name], node_outputs
+
+
+def aten_max_pool2d(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    inputs = {}
+    attrs = {}
+    pool_inputs = []
+    if "pool" in mapper.dygraph_name_id:
+        mapper.dygraph_name_id["pool"] += 1
+    else:
+        mapper.dygraph_name_id["pool"] = 0
+    pool_name = "pool" + str(mapper.dygraph_name_id["pool"])
+    for i, input_ivalue in enumerate(node.inputs()):
+        input_node = input_ivalue.node()
+        input_unique_id = input_ivalue.unique()
+        input_node_name = mapper.outputs_info[input_unique_id]
+        if i == 0:
+            mapper._check_input(graph, input_node, input_node_name,
+                                node_outputs)
+            inputs['input'] = input_node_name
+            pool_inputs.append(input_node_name)
+        elif i == 1:
+            attrs['pool_size'] = mapper.attrs[input_node_name]
+        elif i == 2:
+            attrs['pool_stride'] = mapper.attrs[input_node_name]
+        elif i == 3:
+            attrs['pool_padding'] = mapper.attrs[input_node_name]
+        elif i == 4:
+            # input 4 is the dilation; fluid.dygraph.Pool2D has no dilation,
+            # so conversion only proceeds when it is 1 / [1, 1]
+            graph.add_layer(
+                "prim.assert",
+                inputs={},
+                outputs=[output_name + '_assert'],
+                type='eq',
+                key=mapper.attrs[input_node_name],
+                value=[1, [1, 1]])
+            pool_inputs.append(input_node_name)
+        elif i == 5:
+            attrs['ceil_mode'] = mapper.attrs[
+                input_node_name] if input_node_name in mapper.attrs else input_node_name
+            if input_node_name not in mapper.attrs:
+                pool_inputs.append(input_node_name)
+    attrs['pool_type'] = string('max')
+    graph.add_layer(
+        "fluid.dygraph.Pool2D",
+        inputs=inputs,
+        outputs=[pool_name, output_name],
+        **attrs)
+    return pool_inputs, node_outputs
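+
+
+# Editor's note (my reading of the conventions above): dygraph kernels such
+# as fluid.dygraph.Conv2D, Pool2D and Dropout take a two-element outputs
+# list, [layer_name, tensor_name]: the first names the layer instance created
+# in the generated __init__ (hence the conv0/pool0/dropout0 counters kept in
+# mapper.dygraph_name_id), the second names the tensor produced in forward.
+# Functional kernels (fluid.layers.*, prim.*) carry only the tensor name.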
+
+
+def aten_matmul(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    inputs = {}
+    x_node = list(node.inputs())[0].node()
+    x_unique_id = list(node.inputs())[0].unique()
+    x_node_name = mapper.outputs_info[x_unique_id]
+    mapper._check_input(graph, x_node, x_node_name, node_outputs)
+    inputs['x'] = x_node_name
+    y_node = list(node.inputs())[1].node()
+    y_unique_id = list(node.inputs())[1].unique()
+    y_node_name = mapper.outputs_info[y_unique_id]
+    inputs['y'] = y_node_name
+    mapper._check_input(graph, y_node, y_node_name, node_outputs)
+    graph.add_layer("fluid.layers.matmul", inputs=inputs, outputs=[output_name])
+    return list(inputs.values()), node_outputs
+
+
+def aten_relu_(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    # the inplace flag has no counterpart in Paddle, so it is ignored
+    graph.add_layer(
+        "fluid.layers.relu",
+        inputs={"x": input_node_name},
+        outputs=[output_name])
+    return [input_node_name], node_outputs
+
+
+def aten_relu6(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    # the inplace flag has no counterpart in Paddle, so it is ignored
+    graph.add_layer(
+        "fluid.layers.relu6",
+        inputs={"x": input_node_name},
+        outputs=[output_name],
+        threshold=6.0)
+    return [input_node_name], node_outputs
+
+
+def aten_size(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    graph.add_layer(
+        "prim.shape", inputs={'input': input_node_name}, outputs=[output_name])
+    return [input_node_name], node_outputs
+
+
+def aten_slice(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    attrs = {}
+    slice_inputs = []
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    slice_inputs.append(input_node_name)
+    # start/end/step are folded into attrs when constant, otherwise kept as
+    # dynamic inputs
+    start_node = list(node.inputs())[1].node()
+    start_unique_id = list(node.inputs())[1].unique()
+    start_node_name = mapper.outputs_info[start_unique_id]
+    attrs['start'] = mapper.attrs[
+        start_node_name] if start_node_name in mapper.attrs else start_node_name
+    if start_node_name not in mapper.attrs:
+        mapper._check_input(graph, start_node, start_node_name, node_outputs)
+        slice_inputs.append(start_node_name)
+    end_node = list(node.inputs())[2].node()
+    end_unique_id = list(node.inputs())[2].unique()
+    end_node_name = mapper.outputs_info[end_unique_id]
+    attrs['end'] = mapper.attrs[
+        end_node_name] if end_node_name in mapper.attrs else end_node_name
+    if end_node_name not in mapper.attrs:
+        mapper._check_input(graph, end_node, end_node_name, node_outputs)
+        slice_inputs.append(end_node_name)
+    step_node = list(node.inputs())[3].node()
+    step_unique_id = list(node.inputs())[3].unique()
+    step_node_name = mapper.outputs_info[step_unique_id]
+    attrs['step'] = mapper.attrs[
+        step_node_name] if step_node_name in mapper.attrs else step_node_name
+    if step_node_name not in mapper.attrs:
+        mapper._check_input(graph, step_node, step_node_name, node_outputs)
+        slice_inputs.append(step_node_name)
+    graph.add_layer(
+        "prim.slice",
+        inputs={'input': input_node_name},
+        outputs=[output_name],
+        **attrs)
+    return slice_inputs, node_outputs
+
+
+def aten_t(mapper, graph, node):
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    graph.add_layer(
+        "fluid.layers.transpose",
+        inputs={"x": input_node_name},
+        outputs=[output_name],
+        perm=[1, 0])
+    return [input_node_name], node_outputs
diff --git a/x2paddle/op_mapper/pytorch2paddle/prim.py b/x2paddle/op_mapper/pytorch2paddle/prim.py
new file mode 100644
index 0000000..744c540
--- /dev/null
+++ b/x2paddle/op_mapper/pytorch2paddle/prim.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from x2paddle.core.util import *
+
+
+def prim_Constant(mapper, graph, node):
+    """ Construct a PaddleLayer for a constant assignment.
+
+    PyTorch Script example:
+        %2 : int = prim::Constant[value=-1]()
+    Parameter meanings:
+        %2 (type taken from the assigned value; int in this example): the
+            constant output.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    output = list(node.outputs())[0]
+    value = output.toIValue()
+    mapper.attrs[output_name] = value
+    if isinstance(value, str):
+        value = string(value)
+    graph.add_layer(
+        "prim.constant", inputs={}, outputs=[output_name], value=value)
+    return [], node_outputs
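+
+
+# Editor's note: prim_Constant above records every constant in mapper.attrs,
+# keyed by the generated output name. The aten_* handlers in aten.py later
+# test membership in mapper.attrs to decide between folding a value into
+# layer attributes and wiring it up as a dynamic input, typically via
+#     attrs["alpha"] = mapper.attrs[name] if name in mapper.attrs else name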
+
+
+def prim_GetAttr(mapper, graph, node):
+    """ Retrieve attribute information (typically module parameters).
+
+    PyTorch Script example:
+        %27 : Tensor? = prim::GetAttr[name="bias"](%7)
+    Parameter meanings:
+        %7 (Tensor): the object the attribute is read from.
+        %27 (Tensor): the attribute value, i.e. the output Tensor.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    field_name_list = [node.s('name')]
+    # walk up chained GetAttr nodes to build the dotted attribute path
+    while True:
+        input_node = list(node.inputs())[0].node()
+        try:
+            field_name_list.insert(0, input_node.s('name'))
+            node = input_node
+        except Exception:
+            break
+    part_script = mapper.script
+    for field_name in field_name_list:
+        if hasattr(part_script, field_name):
+            param = getattr(part_script, field_name)
+            if isinstance(param, torch.Tensor):
+                param = param.detach().numpy()
+            mapper.pytorch_params[output_name] = param
+            part_script = param
+    return [], node_outputs
+
+
+def prim_ListConstruct(mapper, graph, node):
+    """ Construct a PaddleLayer that builds a list.
+
+    PyTorch Script example:
+        %86 : int[] = prim::ListConstruct(%84, %85)
+    Parameter meanings:
+        %84 (int/other): the first list element.
+        %85 (int/other): the second list element.
+        %86 (list): the output list.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    inputs = {}
+    for i, input_ivalue in enumerate(node.inputs()):
+        input_node = input_ivalue.node()
+        script_input_unique_id = input_ivalue.unique()
+        input_node_name = mapper.outputs_info[script_input_unique_id]
+        inputs['input{}'.format(i)] = input_node_name
+    graph.add_layer("prim.list", inputs=inputs, outputs=[output_name])
+    return list(inputs.values()), node_outputs
+
+
+def prim_RaiseException(mapper, graph, node):
+    """ Construct a PaddleLayer that raises an exception.
+
+    PyTorch Script example:
+        = prim::RaiseException(%76)
+    Parameter meanings:
+        %76 (str): the exception message.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    script_input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[script_input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    graph.add_layer(
+        "prim.exception",
+        inputs={'input': input_node_name},
+        outputs=[output_name])
+    return [input_node_name], node_outputs
+
+
+def prim_Loop(mapper, graph, node):
+    """ Construct a PaddleLayer for a loop.
+
+    PyTorch Script example:
+        %x : Tensor = prim::Loop(%4, %3, %x.3)
+        block0(%i : int, %x.12 : Tensor):
+            %72 : int[] = prim::Constant[value=[6, 6]]()
+            ...
+            %x.5 : Tensor = aten::adaptive_avg_pool2d(%x.12, %_output_size.1)
+            -> (%3, %x.5)
+    Parameter meanings:
+        %4 (int): the trip count.
+        %3 (bool): the loop-continue condition.
+        %x.3 (Tensor): the Tensor carried (and modified) across iterations.
+        %x (Tensor): the loop output, corresponding to %x.5.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    loop_inputs = {}
+    block = list(node.blocks())[0]
+    loop_outputs = [output_name]
+    for i, block_input_ivalue in enumerate(block.inputs()):
+        block_input_node_name = 'x' + str(mapper.output_index)
+        unique_id = block_input_ivalue.unique()
+        if unique_id not in mapper.outputs_info:
+            mapper.outputs_info[unique_id] = block_input_node_name
+            mapper.output_index += 1
+        if i == 0:
+            # the first block input is the loop counter
+            loop_input_node = list(node.inputs())[0].node()
+            script_loop_input_unique_id = list(node.inputs())[0].unique()
+            loop_input_node_name = mapper.outputs_info[
+                script_loop_input_unique_id]
+            mapper._check_input(graph, loop_input_node, loop_input_node_name,
+                                node_outputs)
+            loop_inputs['input'] = loop_input_node_name
+            loop_outputs.append(block_input_node_name)
+            node_outputs.append(block_input_node_name)
+        else:
+            # remaining block inputs are the loop-carried values
+            loop_input_node = list(node.inputs())[i + 1].node()
+            script_loop_input_unique_id = list(node.inputs())[i + 1].unique()
+            loop_input_node_name = mapper.outputs_info[
+                script_loop_input_unique_id]
+            mapper._check_input(graph, loop_input_node, loop_input_node_name,
+                                node_outputs)
+            graph.add_layer(
+                "prim.equal",
+                inputs={'input': loop_input_node_name},
+                outputs=[block_input_node_name])
+            node_outputs.append(block_input_node_name)
+
+    graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs)
+    current_layer = list(graph.layers.values())[-1]
+    block_graph, graph_inputs = mapper.traverse(block, node, current_layer)
+    for i, input_name in enumerate(graph_inputs):
+        if input_name == loop_outputs[1]:
+            continue
+        current_layer.inputs['input-{}'.format(i)] = input_name
+    current_layer.add_block(block_graph)
+    return list(current_layer.inputs.values()), node_outputs
+
+
+def prim_If(mapper, graph, node):
+    """ Construct a PaddleLayer for an if control-flow branch.
+
+    PyTorch Script example:
+        %input.5 : Tensor = prim::If(%107)
+            block0():
+                %109 : Tensor = aten::t(%102)
+                %ret.2 : Tensor = aten::addmm(%103, %101, %109, %104, %104)
+                -> (%ret.2)
+            block1():
+                %111 : Tensor = aten::t(%102)
+                ...
+                -> (%output.4)
+    Parameter meanings:
+        %107 (bool): the if condition.
+        %input.5 (Tensor): the if output, corresponding to %output.4.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    script_input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[script_input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    graph.add_layer("prim.if", {'input': input_node_name}, [output_name])
+    current_layer = list(graph.layers.values())[-1]
+    block0 = list(node.blocks())[0]
+    block0_graph, graph_inputs0 = mapper.traverse(block0, node, current_layer)
+    len0 = 0
+    for i, input_name in enumerate(graph_inputs0):
+        current_layer.inputs['input-{}'.format(i)] = input_name
+        len0 = i
+    current_layer.add_block(block0_graph)
+    block1 = list(node.blocks())[1]
+    block1_graph, graph_inputs1 = mapper.traverse(block1, node, current_layer)
+    for i, input_name in enumerate(graph_inputs1):
+        current_layer.inputs['input-{}'.format(len0 + 1 + i)] = input_name
+    current_layer.add_block(block1_graph)
+    return list(current_layer.inputs.values()), node_outputs
+
+
+def prim_min(mapper, graph, node):
+    """ Construct a PaddleLayer for min.
+
+    PyTorch Script example:
+        %87 : int = prim::min(%86)
+    Parameter meanings:
+        %86 (list): the input.
+        %87 (int): the output.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    node_outputs = [output_name]
+    input_node = list(node.inputs())[0].node()
+    script_input_unique_id = list(node.inputs())[0].unique()
+    input_node_name = mapper.outputs_info[script_input_unique_id]
+    mapper._check_input(graph, input_node, input_node_name, node_outputs)
+    graph.add_layer(
+        "prim.min", inputs={'input': input_node_name}, outputs=[output_name])
+    return [input_node_name], node_outputs
diff --git a/x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py
new file mode 100644
index 0000000..6284e19
--- /dev/null
+++ b/x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
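+
+# Editor's note on the naming scheme assumed throughout this mapper: every
+# TorchScript value gets a generated name "x<N>" (see _get_outputs_name),
+# recorded in outputs_info under the value's unique id, so later nodes can
+# look their inputs up by id. A hypothetical trace:
+#     %input.3 (unique id 7)   ->  outputs_info[7] = "x12"
+#     aten::relu_(%input.3)    ->  reads outputs_info[7], emits output "x13"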
+
+import torch
+import numpy as np
+from x2paddle.core.op_mapper import OpMapper
+from x2paddle.core.util import *
+from x2paddle.core.paddle_graph import PaddleGraph
+from x2paddle.op_mapper.pytorch2paddle import prim
+from x2paddle.op_mapper.pytorch2paddle import aten
+
+
+class PyTorchOpMapper(OpMapper):
+    def __init__(self, decoder):
+        super(PyTorchOpMapper, self).__init__()
+        self.script = decoder.script
+        self.paddle_params = dict()
+        self.outputs_info = {}  # key: unique id of a value; value: its generated output name
+        self.pytorch_params = {}  # key: node name; value: parameter
+        self.attrs = {}  # key: node name; value: attribute value
+        self.output_index = 0
+        self.dygraph_name_id = {}  # key: kernel type; value: id used in dygraph __init__ output names
+        # run the conversion
+        self.graph, _ = self.traverse(decoder.graph)
+
+    def traverse(self, script_graph, control_node=None, father_layer=None):
+        # collect the names that must become inputs of this (sub)graph
+        def _update_graph_inputs(inputs, outputs):
+            current_node_outputs.extend(outputs)
+            for name in inputs:
+                if name not in current_node_outputs:
+                    graph_inputs.append(name)
+
+        # initialization
+        graph = PaddleGraph(father_layer)
+        current_node_outputs = []
+        graph_inputs = []
+        # convert input nodes
+        if isinstance(script_graph, torch._C.Graph):
+            for i, ivalue in enumerate(script_graph.inputs()):
+                node = ivalue.node()
+                if str(ivalue.type()) != "Tensor":
+                    # the non-Tensor input is the scripted module itself;
+                    # use its type name as the graph name
+                    graph.set_name(str(ivalue.type()).split(".")[-1])
+                inputs, outputs = self.data(graph, node, ivalue.unique())
+        # convert intermediate nodes
+        for node in script_graph.nodes():
+            kind = node.kind()
+            func_name = kind.replace('::', '_')
+            if hasattr(prim, func_name):
+                func = getattr(prim, func_name)
+                inputs, outputs = func(self, graph, node)
+                _update_graph_inputs(inputs, outputs)
+            elif hasattr(aten, func_name):
+                func = getattr(aten, func_name)
+                inputs, outputs = func(self, graph, node)
+                _update_graph_inputs(inputs, outputs)
+            else:
+                raise Exception("The kind {} in model is not supported yet.".
+                                format(node.kind()))
+        # convert output nodes
+        if hasattr(script_graph, 'returnNode'):
+            for i, ivalue in enumerate(script_graph.returnNode().inputs()):
+                if control_node.kind() == "prim::Loop" and i == 0:
+                    continue
+                node = ivalue.node()
+                script_unique_id = ivalue.unique()
+                inputs, outputs = self.equal(
+                    graph,
+                    node,
+                    uid=script_unique_id,
+                    control_node=control_node,
+                    index=i)
+                _update_graph_inputs(inputs, outputs)
+        # attach the collected parameters to the top-level graph
+        if isinstance(script_graph, torch._C.Graph):
+            graph.set_parameters(self.paddle_params)
+        return graph, graph_inputs
+
+    def _get_outputs_name(self, node):
+        outputs_name = []
+        for output_ivalue in node.outputs():
+            output_name = 'x' + str(self.output_index)
+            script_unique_id = output_ivalue.unique()
+            if script_unique_id in self.outputs_info:
+                output_name = self.outputs_info[script_unique_id]
+            self.outputs_info[script_unique_id] = output_name
+            self.output_index += 1
+            outputs_name.append(output_name)
+        # a prim::If node may have no outputs; still allocate a name
+        if len(list(node.outputs())) == 0:
+            output_name = 'x' + str(self.output_index)
+            self.output_index += 1
+            outputs_name.append(output_name)
+        return outputs_name
+
+    def _check_input(self,
+                     graph,
+                     node,
+                     output_name,
+                     node_outputs,
+                     add_dim=False):
+        if node.kind() == "prim::GetAttr":
+            param = self.pytorch_params[output_name]
+            if isinstance(param, np.ndarray):
+                if add_dim:
+                    param = param[np.newaxis, :]
+                self.paddle_params[output_name] = param
+                graph.add_layer(
+                    "fluid.dygraph.base.to_variable",
+                    inputs={},
+                    outputs=[output_name],
+                    value="params[{}]".format(string(output_name)))
+            else:
+                graph.add_layer(
+                    "prim.constant",
+                    inputs={},
+                    outputs=[output_name],
+                    value=string(param) if isinstance(param, str) else param)
+            node_outputs.append(output_name)
+
+    def data(self, graph, node, uid):
+        for output_ivalue in node.outputs():
+            script_unique_id = output_ivalue.unique()
+            if script_unique_id in self.outputs_info or script_unique_id != uid:
+                continue
+            node_name = 'x' + str(self.output_index)
+            self.outputs_info[script_unique_id] = node_name
+            self.output_index += 1
+        output_name = self.outputs_info[uid]
+        graph.add_layer(
+            "fluid.dygraph.base.to_variable",
+            inputs={},
+            outputs=[node_name],
+            value=output_name)
+        return [], [output_name]
+
+    def equal(self, graph, node, uid=None, control_node=None, index=None):
+        if control_node is not None and index is not None:
+            kind = control_node.kind()
+            # re-link the block's output to the control node's output
+            input_node_name = self.outputs_info[uid]
+            control_output_id = index
+            if kind == "prim::Loop":
+                control_output_id = index - 1
+            output_ivalue = list(control_node.outputs())[
+                control_output_id].unique()
+            output_node_name = self.outputs_info[output_ivalue]
+            graph.add_layer(
+                "prim.equal",
+                inputs={'input': input_node_name},
+                outputs=[output_node_name])
+            return [input_node_name], [output_node_name]
diff --git a/x2paddle/optimizer/linear_pass.py b/x2paddle/optimizer/linear_pass.py
new file mode 100644
index 0000000..6af4819
--- /dev/null
+++ b/x2paddle/optimizer/linear_pass.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
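+
+# Editor's note: torch.nn.Linear is scripted into the rank-test/if pattern
+# reproduced in build_pattern below (addmm on rank-2 inputs, matmul plus a
+# bias add otherwise); this pass collapses each matched occurrence into a
+# single fluid.dygraph.Linear layer.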
+
+import numpy as np
+from x2paddle.core.util import *
+from x2paddle.core.paddle_graph import PaddleLayer, PaddleGraph
+from x2paddle.optimizer.passes import Pass, Matcher, PyTorchMatcher
+
+
+class LinearPass(Pass):
+    def __init__(self):
+        self.linear_index = 0
+        super(LinearPass, self).__init__()
+
+    def build_pattern(self):
+        """ Build the pattern for a fc (Linear) layer.
+            Example of the generated Python code this pattern matches:
+            x149 = 2
+            x151 = x146.shape
+            x151 = len(x151)
+            x152 = x151 == x149
+            if x152 :
+                x147 = self.x147
+                x154 = fluid.layers.transpose(x=x147, perm=[1, 0])
+                x148 = self.x148
+                x155 = fluid.layers.addmm(input=x148, x=x146, y=x154, beta=1, alpha=1)
+                x153 = x155
+            else:
+                x147 = self.x147
+                x157 = fluid.layers.transpose(x=x147, perm=[1, 0])
+                x158 = fluid.layers.matmul(x=x146, y=x157)
+                x159 = True
+                if x159 :
+                    x148 = self.x148
+                    x161 = x158 + 1 * x148
+                    x160 = x161
+                else:
+                    x160 = x158
+                x153 = x160
+        """
+
+        def gen_name(id):
+            return "x" + str(id)
+
+        self.pattern.add_layer(
+            "prim.constant", inputs={}, outputs=[gen_name(0)], value=2)
+        self.pattern.add_layer(
+            "prim.constant", inputs={}, outputs=[gen_name(1)], value=1)
+        self.pattern.add_layer(
+            "prim.shape", inputs={'input': "fc-input-0"},
+            outputs=[gen_name(2)])
+        self.pattern.add_layer(
+            "prim.len", inputs={'input': gen_name(2)}, outputs=[gen_name(2)])
+        self.pattern.add_layer(
+            "prim.eq",
+            inputs={"eq0": gen_name(2),
+                    "eq1": gen_name(0)},
+            outputs=[gen_name(3)])
+        self.pattern.add_layer("prim.if", {'input': gen_name(3)}, [gen_name(4)])
+        self.pattern.outputs.append(gen_name(4))
+        if_layer_a = self.pattern.layers[list(self.pattern.layers.keys())[-1]]
+        pattern_block0 = PaddleGraph(if_layer_a)
+        pattern_block0.add_layer(
+            "fluid.dygraph.base.to_variable",
+            inputs={},
+            outputs=[gen_name(5)],
+            value="params[{}]".format(string(gen_name(5))))
+        pattern_block0.add_layer(
+            "fluid.layers.transpose",
+            inputs={"x": gen_name(5)},
+            outputs=[gen_name(6)],
+            perm=[1, 0])
+        pattern_block0.add_layer(
+            "fluid.dygraph.base.to_variable",
+            inputs={},
+            outputs=[gen_name(7)],
+            value="params[{}]".format(string(gen_name(7))))
+        pattern_block0.add_layer(
+            "fluid.layers.addmm",
+            inputs={"input": gen_name(7),
+                    "x": "fc-input-0",
+                    "y": gen_name(6)},
+            outputs=[gen_name(8)],
+            beta=1,
+            alpha=1)
+        if_layer_a.inputs["input-0"] = "fc-input-0"
+        self.pattern.inputs.append("fc-input-0")
+        pattern_block0.add_layer(
+            "prim.equal", inputs={'input': gen_name(8)}, outputs=[gen_name(4)])
+        if_layer_a.add_block(pattern_block0)
+        pattern_block1 = PaddleGraph(if_layer_a)
+        pattern_block1.add_layer(
+            "fluid.dygraph.base.to_variable",
+            inputs={},
+            outputs=[gen_name(5)],
+            value="params[{}]".format(string(gen_name(5))))
+        pattern_block1.add_layer(
+            "fluid.layers.transpose",
+            inputs={"x": gen_name(5)},
+            outputs=[gen_name(6)],
+            perm=[1, 0])
+        pattern_block1.add_layer(
+            "fluid.layers.matmul",
+            inputs={"x": "fc-input-0",
+                    "y": gen_name(6)},
+            outputs=[gen_name(9)])
+        if_layer_a.inputs["input-1"] = "fc-input-0"
+        pattern_block1.add_layer(
+            "prim.constant", inputs={}, outputs=[gen_name(10)], value=True)
+        pattern_block1.add_layer("prim.if", {'input': gen_name(10)},
+                                 [gen_name(11)])
+        if_layer_b = pattern_block1.layers[list(pattern_block1.layers.keys())[
+            -1]]
+        pattern_block1_block0 = PaddleGraph(if_layer_b)
+        pattern_block1_block0.add_layer(
+            "fluid.dygraph.base.to_variable",
+            inputs={},
+            outputs=[gen_name(12)],
+            value="params[{}]".format(string(gen_name(12))))
+        pattern_block1_block0.add_layer(
+            "prim.add",
+            inputs={"x": gen_name(9),
+                    "y": gen_name(12)},
+            outputs=[gen_name(13)],
+            alpha=1)
+        if_layer_b.inputs["input-0"] = gen_name(9)
+        pattern_block1_block0.add_layer(
+            "prim.equal",
+            inputs={'input': gen_name(13)},
+            outputs=[gen_name(11)])
+        if_layer_b.add_block(pattern_block1_block0)
+        pattern_block1_block1 = PaddleGraph(if_layer_b)
+        pattern_block1_block1.add_layer(
+            "prim.equal", inputs={'input': gen_name(9)},
+            outputs=[gen_name(11)])
+        if_layer_b.inputs["input-1"] = gen_name(9)
+        pattern_block1.add_layer(
+            "prim.equal", inputs={'input': gen_name(11)},
+            outputs=[gen_name(4)])
+        if_layer_b.add_block(pattern_block1_block1)
+        if_layer_a.add_block(pattern_block1)
+        self.pattern.build(
+            inputs={"input-0": "fc-input-0",
+                    "input-1": "fc-input-0"})
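+
+
+# Editor's note: replace_layer below indexes the matched layers positionally.
+# Under the pattern built above, my reading of the match order (top-level
+# layers first, then the layers of each matched block in turn) gives:
+#     index 2 -> prim.shape (its input is the fc input)
+#     index 5 -> prim.if    (its first output is the fc output)
+#     index 6 -> to_variable holding params["<weight>"]
+#     index 8 -> to_variable holding params["<bias>"]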
+
+
+class LinearMatcher(PyTorchMatcher):
+    def __init__(self):
+        self.linear_index = 0
+        super(LinearMatcher, self).__init__()
+
+    def replace_layer(self, graph, subgraph_global_layers):
+        subgraph_global_layers_id = list(subgraph_global_layers.keys())
+        layer = subgraph_global_layers[subgraph_global_layers_id[2]]
+        input_name = layer.inputs["input"]
+        layer = subgraph_global_layers[subgraph_global_layers_id[5]]
+        output_name = layer.outputs[0]
+        layer = subgraph_global_layers[subgraph_global_layers_id[6]]
+        weight_name = layer.attrs["value"][8:-2]
+        layer = subgraph_global_layers[subgraph_global_layers_id[8]]
+        bias_name = layer.attrs["value"][8:-2]
+        attrs = {}
+        attrs["input_dim"] = graph.parameters[weight_name].shape[1]
+        attrs["output_dim"] = graph.parameters[weight_name].shape[0]
+        linear_name = "linear{}".format(self.linear_index)
+        self.linear_index += 1
+        graph.parameters["{}.weight".format(linear_name)] = graph.parameters[
+            weight_name].transpose((1, 0))
+        graph.parameters["{}.bias".format(linear_name)] = np.squeeze(
+            graph.parameters[bias_name])
+        graph.parameters.pop(weight_name)
+        graph.parameters.pop(bias_name)
+        for i, layer_id in enumerate(subgraph_global_layers):
+            if layer_id in graph.layers:
+                layer = graph.layers[layer_id]
+                if i == 0:
+                    new_layer = PaddleLayer(
+                        layer_id,
+                        "fluid.dygraph.Linear",
+                        inputs={"input": input_name},
+                        outputs=[linear_name, output_name],
+                        **attrs)
+                    graph.layers[layer_id] = new_layer
+                else:
+                    graph.layers.pop(layer_id)
+        graph.build()
+        return graph
diff --git a/x2paddle/optimizer/optimizer.py b/x2paddle/optimizer/optimizer.py
new file mode 100644
index 0000000..c03b9e9
--- /dev/null
+++ b/x2paddle/optimizer/optimizer.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
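+
+# Editor's note: a rough end-to-end sketch of how the pieces in this patch
+# are meant to fit together (the model path and the surrounding script are
+# hypothetical, not part of this patch):
+#     from x2paddle.decoder.pytorch_decoder import PyTorchDecoder
+#     from x2paddle.op_mapper.pytorch2paddle.pytorch_op_mapper import PyTorchOpMapper
+#     from x2paddle.optimizer.optimizer import GraphOptimizer
+#     decoder = PyTorchDecoder("model.pt")   # a torch.jit.save'd module
+#     mapper = PyTorchOpMapper(decoder)      # builds mapper.graph
+#     graph = GraphOptimizer().optimize(mapper.graph)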
+
+from x2paddle.optimizer.linear_pass import LinearPass, LinearMatcher
+
+
+class GraphOptimizer(object):
+    def __init__(self):
+        linear_pass = LinearPass()
+        linear_matcher = LinearMatcher()
+        self.passes = {linear_pass: linear_matcher}
+
+    def run(self, graph):
+        is_update_graph = False
+        while True:
+            for i, (layer_id, layer) in enumerate(graph.layers.items()):
+                is_match = self.current_matcher.match_pattern(
+                    self.current_pass.pattern, graph, i)
+                if is_match:
+                    is_update_graph = True
+                    graph = self.current_matcher.replace_layer(graph, is_match)
+                    break
+                for j, block in enumerate(layer.blocks):
+                    if len(block.layers) > 0:
+                        layer.blocks[j], is_update_block = self.run(block)
+                        if is_update_block:
+                            break
+                if i + 1 == len(graph.layers):
+                    return graph, is_update_graph
+
+    def optimize(self, graph):
+        # run every registered pass until its pattern no longer matches
+        for _pass, matcher in self.passes.items():
+            self.current_pass = _pass
+            self.current_matcher = matcher
+            graph, _ = self.run(graph)
+            print("{} done!".format(_pass.__class__.__name__))
+        return graph
diff --git a/x2paddle/optimizer/passes.py b/x2paddle/optimizer/passes.py
new file mode 100644
index 0000000..028a211
--- /dev/null
+++ b/x2paddle/optimizer/passes.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
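+
+# Editor's note: the contract assumed by GraphOptimizer: a Pass builds a
+# pattern PaddleGraph in build_pattern(), and its Matcher walks the real
+# graph in lockstep with that pattern (match_pattern) and rewrites each hit
+# (replace_layer). A new fusion would be registered as another such pair,
+# sketched below with hypothetical names:
+#     class MyPass(Pass):
+#         def build_pattern(self): ...
+#     class MyMatcher(PyTorchMatcher):
+#         def replace_layer(self, graph, matched_layers): ...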
+
+from x2paddle.core.paddle_graph import PaddleGraph
+
+
+class Pass(object):
+    def __init__(self):
+        self.pattern = PaddleGraph()
+        self.build_pattern()
+
+
+class Matcher(object):
+    def __init__(self):
+        self.unique_id_layer = dict()
+
+
+class PyTorchMatcher(Matcher):
+    def __init__(self):
+        super(PyTorchMatcher, self).__init__()
+
+    def match_pattern(self, pattern, graph, start_id):
+        pattern_index = 0
+        pattern_global_layers = pattern.get_global_layers()
+        subgraph_global_layers = dict()
+        graph_layers = dict(list(graph.layers.items())[start_id:])
+        for layer_id, layer in graph_layers.items():
+            pattern_layer = pattern.layers[list(pattern.layers.keys())[
+                pattern_index]]
+            if layer.kernel == pattern_layer.kernel:
+                subgraph_global_layers[layer_id] = layer
+                pattern_layer_id = pattern_layer.id
+                if layer.kernel == "prim.constant":
+                    if layer.attrs["value"] != pattern_layer.attrs["value"]:
+                        return False
+                elif layer.kernel == "fluid.layers.addmm":
+                    if layer.attrs["beta"] != pattern_layer.attrs["beta"]:
+                        return False
+                    if layer.attrs["alpha"] != pattern_layer.attrs["alpha"]:
+                        return False
+                if layer_id in graph.edges_in:
+                    if pattern_layer_id not in pattern.edges_in:
+                        return False
+                    else:
+                        if len(graph.edges_in[layer_id]) != len(
+                                pattern.edges_in[pattern_layer_id]):
+                            return False
+                        layer_in = graph.edges_in[layer_id]
+                        pattern_layer_in = pattern.edges_in[pattern_layer_id]
+                        for i in range(len(layer_in)):
+                            layer_id_in = layer_in[i]
+                            pattern_layer_id_in = pattern_layer_in[i]
+                            if pattern_layer_id_in != -1:
+                                pattern_global_layers_id = list(
+                                    pattern_global_layers.keys())
+                                subgraph_global_layers_id = list(
+                                    subgraph_global_layers.keys())
+                                if pattern_global_layers_id.index(pattern_layer_id_in) == \
+                                subgraph_global_layers_id.index(layer_id_in):
+                                    # the pattern input's index in
+                                    # pattern_global_layers_id must match the
+                                    # graph input's index in
+                                    # subgraph_global_layers_id
                                    continue
+                                return False
+                if layer_id in graph.edges_out:
+                    if pattern_layer_id not in pattern.edges_out:
+                        if not set(pattern_layer.outputs).issubset(
+                                pattern.outputs):
+                            # still a match when this layer's outputs are
+                            # outputs of the whole pattern
+                            return False
+                    else:
+                        if len(graph.edges_out[layer_id]) != len(
+                                pattern.edges_out[pattern_layer_id]):
+                            # with edges_in matching on every node, equal
+                            # edges_out counts mean no node is consumed
+                            # outside the subgraph
+                            if not set(pattern_layer.outputs).issubset(
+                                    pattern.outputs):
+                                # still a match when this layer's outputs are
+                                # outputs of the whole pattern
+                                return False
+                if layer.kernel == "prim.if":
+                    res = self.match_pattern(pattern_layer.blocks[0],
+                                             layer.blocks[0], 0)
+                    if res:
+                        subgraph_global_layers.update(res)
+                    else:
+                        return False
+                    res = self.match_pattern(pattern_layer.blocks[1],
+                                             layer.blocks[1], 0)
+                    if res:
+                        subgraph_global_layers.update(res)
+                    else:
+                        return False
+                pattern_index += 1
+                if pattern_index == len(pattern.layers):
+                    return subgraph_global_layers
+            else:
+                return False
-- 
GitLab