# Copyright (c) 2022  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from .opset7 import OpSet7
from .opset8 import OpSet8
from .opset9 import OpSet9
from .opset10 import OpSet10
from .opset11 import OpSet11
from .opset12 import OpSet12
from .opset13 import OpSet13
from .opset14 import OpSet14
from .opset15 import OpSet15
from x2paddle.decoder.onnx_decoder import ONNXGraphNode
from x2paddle.core.program import PaddleGraph


class ONNXOpMapper:
    """Convert a decoded ONNX graph into a ``PaddleGraph``.

    All of the work happens in ``__init__``:

    1. An opset-specific mapper (``OpSet7`` ... ``OpSet15``) is chosen from
       the model's ``op_set``.
    2. Every node is checked for a converter.
    3. Each node is dispatched, in topological order, to its converter.

    The converted result is left in ``self.paddle_graph``.
    """

    def __init__(self, decoder):
        """Run the full ONNX -> Paddle conversion.

        Args:
            decoder: ONNX decoder exposing ``graph`` (the parsed ONNX graph)
                and ``op_set`` (the model's opset version).

        Raises:
            Exception: if any node's op type has no converter in the selected
                opset mapper.
        """
        # Opset versions that have a dedicated mapper class.
        self.support_op_sets = [7, 8, 9, 10, 11, 12, 13, 14, 15]
        # Used when the model's opset is unsupported and lower than this.
        self.default_op_set = 9
        self.graph = decoder.graph
        self.paddle_graph = PaddleGraph(parent_layer=None, source_type="onnx")
        self.paddle_graph.outputs = self.graph.output_nodes
        self.opset = self.create_opset(decoder)
        if not self.op_checker():
            raise Exception("Model is not supported yet.")

        print("Total nodes: {}".format(
            sum(
                isinstance(node, ONNXGraphNode)
                for node in self.graph.node_map.values())))
        print("Nodes converting ...")
        for i, node_name in enumerate(self.graph.topo_sort, start=1):
            sys.stderr.write("\rConverting node {} ...     ".format(i))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            # Dispatch priority: a dedicated converter method on the opset
            # mapper first, then the generic direct/elementwise mappings.
            if hasattr(self.opset, op):
                getattr(self.opset, op)(node)
            elif op in self.opset.directly_map_ops:
                self.opset.directly_map(node)
            elif op in self.opset.elementwise_ops:
                self.opset.elementwise_map(node)
        print("\nNodes converted.")
        self.paddle_graph.set_name(self.graph.graph_name)
        self.paddle_graph.set_parameters(self.opset.weights)
        self.paddle_graph.set_inputs_info(self.opset.inputs_info)

    def _is_supported_op(self, op):
        """Return True if the selected opset mapper can convert op type ``op``."""
        return (hasattr(self.opset, op) or
                op in self.opset.directly_map_ops or
                op in self.opset.elementwise_ops)

    def op_checker(self):
        """Check that every node in the graph has a converter.

        Returns:
            True if all op types are supported. Otherwise False, after printing
            each unsupported op type.
        """
        unsupported_ops = set()
        for node_name in self.graph.topo_sort:
            op = self.graph.get_node(node_name).layer_type
            if not self._is_supported_op(op):
                unsupported_ops.add(op)
        if not unsupported_ops:
            return True
        print("\n========= {} OPs are not supported yet ===========".format(
            len(unsupported_ops)))
        # Sorted so the report is stable across runs.
        for op in sorted(unsupported_ops):
            print("========== {} ============".format(op))
        return False

    def _select_op_set(self, op_set):
        """Pick the supported opset version used to convert a model.

        Args:
            op_set: the model's opset version.

        Returns:
            ``op_set`` itself if it is supported. Otherwise ``default_op_set``
            if ``op_set`` is below it. Otherwise the highest supported version
            lower than ``op_set``.
        """
        if op_set in self.support_op_sets:
            return op_set
        if op_set < self.default_op_set:
            return self.default_op_set
        # Newer than what is supported: fall back to the closest older one.
        return max(
            (v for v in self.support_op_sets if v < op_set),
            default=self.default_op_set)

    def create_opset(self, decoder):
        """Instantiate the opset mapper for ``decoder.op_set``.

        Args:
            decoder: ONNX decoder exposing ``op_set``. It is passed through to
                the mapper constructor together with ``self.paddle_graph``.

        Returns:
            An instance of the selected ``OpSetN`` class.
        """
        # Explicit mapping instead of eval('OpSet' + str(n)). It is built here
        # rather than at class level so that defining the class does not
        # depend on the OpSet imports.
        opset_classes = {
            7: OpSet7,
            8: OpSet8,
            9: OpSet9,
            10: OpSet10,
            11: OpSet11,
            12: OpSet12,
            13: OpSet13,
            14: OpSet14,
            15: OpSet15,
        }
        run_op_set = self._select_op_set(decoder.op_set)
        print('Now, onnx2paddle support convert onnx model opset_version {}, '
              'opset_version of your onnx model is {}.'
              .format(self.support_op_sets, decoder.op_set))
        return opset_classes[run_op_set](decoder, self.paddle_graph)