tf_decoder.py 7.1 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.core.graph import GraphNode, Graph
J
jiangjiajun 已提交
16 17
from x2paddle.core.fluid_code import FluidCode
from tensorflow.python.framework import tensor_util
J
jiangjiajun 已提交
18
from tensorflow.python.platform import gfile
J
jiangjiajun 已提交
19
from tensorflow.core.framework import attr_value_pb2
J
jiangjiajun 已提交
20
import tensorflow as tf
J
jiangjiajun 已提交
21 22
import copy as cp
import sys
J
jiangjiajun 已提交
23

24

J
jiangjiajun 已提交
25 26
class TFGraphNode(GraphNode):
    """Graph node wrapping one TensorFlow ``NodeDef`` (``layer``).

    The node name is ``layer_name`` (or ``layer.name`` when not given) with
    every ``/`` replaced by ``_``.
    """

    def __init__(self, layer, layer_name=None):
        if layer_name is None:
            layer_name = layer.name
        super(TFGraphNode, self).__init__(layer, layer_name.replace('/', '_'))

        self.layer_type = layer.op
        self.fluid_code = FluidCode()

        # TensorFlow DataType enum value -> dtype string used by this module.
        self.dtype_map = {1: "float32", 3: "int32", 9: "int64"}

    @property
    def out_shapes(self):
        """Output shapes from the ``_output_shapes`` attr.

        Returns a list with one entry per listed shape; each entry is the
        list of that shape's ``dim.size`` values.
        """
        values = self.layer.attr["_output_shapes"].list.shape
        return [[dim.size for dim in value.dim] for value in values]

    @property
    def dtype(self):
        """The node's ``dtype`` attr translated through ``dtype_map``.

        Raises:
            Exception: if the dtype enum value is not in ``dtype_map``.
        """
        dtype = self.layer.attr["dtype"].type
        if dtype not in self.dtype_map:
            raise Exception("Dtype[{}] not in dtype_map".format(dtype))
        return self.dtype_map[dtype]

    @property
    def value(self):
        """The tensor stored in a Const node's ``value`` attr, via
        ``tensor_util.MakeNdarray``."""
        assert self.layer_type == "Const", "Only Const node has value."

        attr = self.layer.attr['value']
        field = getattr(attr, attr.WhichOneof('value'))
        return tensor_util.MakeNdarray(field)

    def get_attr(self, name):
        """Return the value of attribute ``name``, or None if it is absent.

        List attributes are returned as plain Python lists, with integer
        items normalized to ``int``. Any other attribute is returned as the
        raw field selected by the ``AttrValue`` ``value`` oneof (None when no
        field is set).
        """
        if name not in self.layer.attr:
            return None
        attr = self.layer.attr[name]
        field = attr.WhichOneof('value')
        value = getattr(attr, field) if field else None

        if not isinstance(value, attr_value_pb2.AttrValue.ListValue):
            return value

        populated = value.ListFields()
        if not populated:
            # An empty list attribute has no populated field, so there is
            # nothing to index into.
            return []

        try:
            # Python 2 stores int64 list items as ``long``.
            integer_types = (int, long)
        except NameError:
            integer_types = (int, )
        return [
            int(item) if isinstance(item, integer_types) else item
            for item in populated[0][1]
        ]

J
jiangjiajun 已提交
84 85 86 87

class TFGraph(Graph):
    """Graph of ``TFGraphNode`` objects built from a TensorFlow GraphDef."""

    def __init__(self, model):
        super(TFGraph, self).__init__(model)
        # Name of each removed Identity node -> name of the node it forwarded.
        self.identity_map = dict()

    def build(self):
        """Create a node per NodeDef, connect inputs, then simplify.

        Raises:
            Exception: if a node input names a node that is not in the graph.
        """
        for layer in self.model.node:
            self.node_map[layer.name.replace('/', '_')] = TFGraphNode(layer)

        for layer_name, node in self.node_map.items():
            for in_node in node.layer.input:
                in_node = in_node.replace('/', '_')
                if in_node in self.node_map:
                    self.connect(in_node, layer_name)
                    continue
                # Inputs may carry a ":N" suffix; connect to the node part.
                src_name = in_node.strip().split(':')[0]
                if src_name in self.node_map:
                    self.connect(src_name, layer_name)
                else:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(in_node, layer_name))

        super(TFGraph, self).build()

        # tensorflow graph optimize
        self._remove_isolated_node()
        self._remove_identity_node()

    def get_node(self, node_name, copy=False):
        """Look up a node, redirecting names of removed Identity nodes.

        A ``:N`` suffix on ``node_name`` is kept unchanged.
        """
        items = node_name.strip().split(':')
        name = items[0].replace('/', '_')
        # Follow chains such as Identity -> Identity -> X to the live node;
        # ``visited`` guards against a malformed cyclic mapping.
        visited = set()
        while name in self.identity_map and name not in visited:
            visited.add(name)
            name = self.identity_map[name]
        items[0] = name
        new_node_name = ":".join(items)
        return super(TFGraph, self).get_node(new_node_name, copy)

    def _remove_isolated_node(self):
        """Remove every node that has neither inputs nor outputs."""
        isolated_nodes = list()
        for node_name in self.node_map:
            node = self.get_node(node_name)
            if len(node.inputs) == 0 and len(node.outputs) == 0:
                isolated_nodes.append(node_name)

        for node_name in isolated_nodes:
            self.remove_node(node_name)

    def _remove_identity_node(self):
        """Splice every Identity node out of the graph.

        Consumers of an Identity are rewired to read from the Identity's
        input, and the Identity is dropped from ``topo_sort``. The node stays
        in ``node_map``; ``get_node`` redirects its name via ``identity_map``.
        """
        identity_nodes = [
            node_name for node_name, node in self.node_map.items()
            if node.layer_type == "Identity"
        ]

        for node_name in identity_nodes:
            node = self.get_node(node_name)
            # Remind: Only 1 input for Identity node
            input_node = self.get_node(node.inputs[0])
            input_name = input_node.layer_name

            self.identity_map[node_name] = input_name
            input_node.outputs.remove(node_name)

            # ``outputs`` holds one entry per edge. Each consumer must also
            # be recorded on ``input_node``; otherwise, if ``input_node`` is
            # an Identity removed later, these consumers would never be
            # rewired past it.
            for output_name in node.outputs:
                output_node = self.get_node(output_name)
                idx = output_node.inputs.index(node_name)
                output_node.inputs[idx] = input_name
                input_node.outputs.append(output_name)

            self.topo_sort.remove(node_name)

J
jiangjiajun 已提交
157

J
jiangjiajun 已提交
158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
def check_input_shape(graph_def):
    """Ask the user for shapes of Placeholder inputs lacking a shape attr.

    For each Placeholder in ``graph_def`` whose ``shape`` attr is missing,
    prompt for a comma-separated shape, e.g. ``None, 224, 224, 3``. Either
    ``None`` or ``-1`` marks an unknown dimension. Then create a replacement
    ``tf.placeholder`` named ``x2paddle_<name>`` in the current default graph.

    Returns:
        dict mapping ``"<placeholder name>:0"`` to its replacement
        placeholder (used by ``TFDecoder`` as ``import_graph_def``'s
        ``input_map``).
    """
    try:
        # Python 2's ``input`` evaluates the typed text; read it raw instead.
        read_line = raw_input
    except NameError:
        read_line = input

    input_map = dict()
    for layer in graph_def.node:
        if layer.op != "Placeholder":
            continue
        graph_node = TFGraphNode(layer)
        # Read for every placeholder, so an unsupported dtype raises even
        # when the shape is known.
        dtype = graph_node.dtype
        if graph_node.get_attr("shape"):
            continue

        sys.stderr.write("Unknown shape for input tensor[{}]\n".format(
            layer.name))
        shape = read_line("Please define shape of input here: ")
        dims = [piece.strip() for piece in shape.strip().split(',')]
        shape = [None if dim in ("None", "-1") else int(dim) for dim in dims]
        x2paddle_input = tf.placeholder(dtype=dtype,
                                        shape=shape,
                                        name="x2paddle_{}".format(layer.name))
        input_map["{}:0".format(layer.name)] = x2paddle_input
    return input_map


J
jiangjiajun 已提交
183 184
class TFDecoder(object):
    """Load a serialized TensorFlow GraphDef and expose it as a ``TFGraph``."""

    def __init__(self, pb_model):
        """Parse ``pb_model`` and build ``self.tf_graph``.

        Args:
            pb_model: path of the serialized ``GraphDef`` file to load.
        """
        with gfile.GFile(pb_model, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Import into a dedicated graph rather than the process-wide default
        # graph, so decoders created in the same process do not share nodes.
        graph = tf.Graph()
        with graph.as_default():
            # Replacement placeholders must be created in the graph that the
            # model is imported into.
            input_map = check_input_shape(graph_def)
            tf.import_graph_def(graph_def, name='', input_map=input_map)
            with tf.Session(graph=graph) as sess:
                sess.run(tf.global_variables_initializer())

        self.tf_graph = TFGraph(graph.as_graph_def(add_shapes=True))
        self.tf_graph.build()