tensorflow_graph.py 6.4 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from graph import GraphNode, Graph
from tensorflow.core.framework import attr_value_pb2
J
jiangjiajun 已提交
17
from utils import *
J
jiangjiajun 已提交
18

J
jiangjiajun 已提交
19

J
jiangjiajun 已提交
20
class TensorflowGraphNode(GraphNode):
    """Graph node wrapping a single TensorFlow NodeDef.

    Adds TensorFlow-specific helpers (op type, dtype, attribute decoding)
    on top of the generic ``GraphNode`` and carries the node's FluidCode
    buffer plus the data format currently assigned to it.
    """

    # TensorFlow DataType enum value -> Fluid dtype name. Only these types
    # are accepted by the ``dtype`` property.
    dtype_map = {1: "float32", 3: "int32", 9: "int64"}

    def __init__(self, layer, input_format, layer_name=None):
        """
        Args:
            layer: the NodeDef wrapped by this node (must expose ``op``,
                ``name`` and ``attr``).
            input_format: initial data format of the node; may later be
                overwritten by ``TensorflowGraph.set_data_format``.
            layer_name: optional name override. ``TensorflowGraph.build``
                passes a ``"name:k"`` tensor name here when it creates a
                dedicated node for one output of a Split/SplitV op.
        """
        super(TensorflowGraphNode, self).__init__(layer, layer_name)
        self.codes = list()
        self.code = FluidCode()
        self.ref_as_const = 0
        self.data_format = input_format

    @property
    def layer_type(self):
        """Lower-cased TensorFlow op type of the wrapped NodeDef."""
        return self.layer.op.lower()

    @property
    def shape_dim_size(self):
        """Rank of the first entry of the ``_output_shapes`` attribute.

        NOTE(review): assumes the graph was exported with
        ``_output_shapes`` populated; otherwise indexing ``shape[0]`` fails.
        """
        shape = self.layer.attr['_output_shapes']
        return len(shape.list.shape[0].dim)

    @property
    def dtype(self):
        """Fluid dtype name corresponding to the node's ``dtype`` attribute.

        Raises:
            Exception: if the attribute is missing or its value is not a
                key of ``dtype_map``.
        """
        dtype = self.get_attr("dtype")
        if dtype not in self.dtype_map:
            raise Exception("Unknow dtype: {}".format(dtype))
        return self.dtype_map[dtype]

    def get_attr(self, name, default_value=None):
        """Return attribute ``name`` of the wrapped NodeDef as a Python value.

        Scalar attributes are returned as the raw protobuf field value
        (e.g. bytes for ``s``, an int for ``i``). List attributes are
        returned as a plain Python list in which integer values (including
        bools, and Python 2 longs) are converted with ``int()``; an empty
        list attribute yields ``[]``.

        Returns ``default_value`` if the attribute is absent or has no
        value set.
        """
        if name not in self.layer.attr:
            return default_value
        attr = self.layer.attr[name]
        field = attr.WhichOneof('value')
        if not field:
            return default_value
        val = getattr(attr, field)
        if not isinstance(val, attr_value_pb2.AttrValue.ListValue):
            return val

        # ListFields() only reports populated fields, so an empty list
        # attribute has none; indexing [0] would raise IndexError.
        populated = val.ListFields()
        if not populated:
            return []

        try:
            integer_types = (int, long)  # Python 2: protobuf may yield long
        except NameError:
            integer_types = (int, )  # Python 3: no separate long type
        # isinstance(True, int) holds, so bool lists become 0/1 ints, as
        # before.
        return [
            int(item) if isinstance(item, integer_types) else item
            for item in populated[0][1]
        ]

    def clear_code(self):
        """Clear this node's FluidCode buffer."""
        self.code.clear()
J
jiangjiajun 已提交
71 72 73


class TensorflowGraph(Graph):
    """Graph built from a TensorFlow graph definition.

    ``build`` wraps every NodeDef in a ``TensorflowGraphNode``, wires the
    data edges, removes pass-through ops listed in ``useless_type`` and
    propagates explicit ``data_format`` attributes downstream.
    """

    # Lower-cased op types spliced out of the graph; each is replaced by its
    # first input (for multi-input ops such as Switch/Merge only inputs[0]
    # is kept).
    useless_type = ['identity', 'placeholderwithdefault', 'switch', 'merge']

    def __init__(self, tf_graph):
        """
        Args:
            tf_graph: graph definition whose ``node`` list is converted
                (presumably a TensorFlow GraphDef).
        """
        super(TensorflowGraph, self).__init__(tf_graph)
        self.tf_graph = tf_graph
        # Removed useless node's layer name -> layer name of the node that
        # replaced it (its first input).
        self.identity_relation = dict()

    def build(self, input_format):
        """Create nodes for every NodeDef, connect them and post-process.

        Input names of the form ``"name:k"`` that are not themselves in
        ``node_map`` are resolved as follows:

        * producer is a Switch: connect the Switch node directly;
        * producer is a Split/SplitV: create a dedicated node named
          ``"name:k"`` (sharing the Split's NodeDef) between the Split and
          the consumer;
        * any other producer: unsupported, raise.

        The inputs of Const nodes are not wired. After wiring, the base
        ``Graph.build`` runs, then useless nodes are removed and data
        formats are propagated.

        Args:
            input_format: data format initially assigned to every node.

        Raises:
            Exception: for an unsupported ``"name:k"`` input, or an input
                that names no known node.
        """
        skip_node = set(['const'])
        for layer in self.tf_graph.node:
            self.node_map[layer.name] = TensorflowGraphNode(
                layer, input_format)

        for layer in self.tf_graph.node:
            if layer.op.lower() in skip_node:
                continue
            consumer = self.node_map[layer.name]
            for pred in layer.input:
                producer_name = pred.split(':')[0]
                if pred not in self.node_map and producer_name in self.node_map:
                    pred_node = self.node_map[producer_name]
                    if pred_node.layer_type == "switch":
                        self._make_connection(pred_node, consumer)
                    elif pred_node.layer_type in ("split", "splitv"):
                        # One node per Split output so consumers of
                        # different outputs can be told apart.
                        self.node_map[pred] = TensorflowGraphNode(
                            pred_node.layer, input_format, pred)
                        self._make_connection(self.node_map[pred], consumer)
                        self._make_connection(pred_node, self.node_map[pred])
                    else:
                        raise Exception("\nUnsupported situation(name:[{}],"
                                        "OP[{}])".format(layer.name, layer.op))
                elif pred in self.node_map:
                    self._make_connection(self.node_map[pred], consumer)
                else:
                    raise Exception("input: {} not in node_map".format(pred))
        super(TensorflowGraph, self).build(input_format)

        self._process_useless_nodes()
        self._check_dataformat(input_format)

    def _check_dataformat(self, input_format):
        """Propagate explicit ``data_format`` attributes through the graph.

        Nodes are visited in ``topological_sort`` order. Every node carrying
        a ``data_format`` attribute pushes that value onto itself and its
        downstream nodes via ``set_data_format``.

        Args:
            input_format: not used by this method.

        Raises:
            Exception: if a ``data_format`` value equals neither ``NHWC``
                nor ``NCHW`` (module-level names not defined in this file;
                presumably provided by the ``utils`` star import).
        """
        for name in self.topological_sort:
            current_node = self.node_map[name]
            # Use the same str key for the membership test and the lookup:
            # with the pure-Python protobuf backend a bytes key never
            # matches the str keys of the ``attr`` map.
            if 'data_format' not in current_node.layer.attr:
                continue
            s = current_node.layer.attr['data_format'].s
            if s != NHWC and s != NCHW:
                raise Exception('Unkown dataformat {}'.format(s))
            self.set_data_format(current_node, s)

    def _process_useless_nodes(self):
        """Splice every node whose op is in ``useless_type`` out of the graph.

        Each such node is replaced by its first input: consumers are rewired
        to that input, the mapping is recorded in ``identity_relation``, and
        the node is detached from all of its producers and removed from
        ``node_map``, ``output_nodes``, ``input_nodes`` and
        ``topological_sort``.
        """
        removed = set()
        for index, name in enumerate(self.topological_sort):
            current_node = self.node_map[name]
            if current_node.layer_type not in self.useless_type:
                continue
            replacement = current_node.inputs[0]
            self.identity_relation[current_node.layer.name] = \
                replacement.layer.name
            for node in current_node.outputs:
                for k in range(len(node.inputs)):
                    if node.inputs[k] == current_node:
                        node.inputs[k] = replacement
                        if node not in replacement.outputs:
                            replacement.outputs.append(node)
            # Detach from *every* producer, not just inputs[0]: Switch and
            # Merge have several inputs, and leaving the others untouched
            # kept references to a node no longer in node_map.
            for producer in current_node.inputs:
                while current_node in producer.outputs:
                    producer.outputs.remove(current_node)
            del self.node_map[name]
            if name in self.output_nodes:
                self.output_nodes.remove(name)
            if name in self.input_nodes:
                self.input_nodes.remove(name)
            removed.add(index)

        # Filter in place so existing references to the list stay valid.
        self.topological_sort[:] = [
            name for index, name in enumerate(self.topological_sort)
            if index not in removed
        ]

    def set_data_format(self, node, data_format):
        """Assign ``data_format`` to ``node`` and every node downstream of it.

        Propagation stops at nodes that already have ``data_format``. An
        explicit work list is used instead of recursion so that very deep
        graphs cannot exceed Python's recursion limit.

        Args:
            node: starting TensorflowGraphNode.
            data_format: ``b'NHWC'`` or ``b'NCHW'`` (asserted).
        """
        assert data_format == 'NHWC'.encode() or data_format == 'NCHW'.encode()
        pending = [node]
        while pending:
            current = pending.pop()
            if current.data_format == data_format:
                continue
            current.data_format = data_format
            pending.extend(current.outputs)