tf_decoder.py 12.5 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.core.graph import GraphNode, Graph
J
jiangjiajun 已提交
16
from x2paddle.core.fluid_code import FluidCode
J
jiangjiajun 已提交
17
from x2paddle.core.util import *
J
jiangjiajun 已提交
18
from tensorflow.python.framework import tensor_util
J
jiangjiajun 已提交
19
from tensorflow.python.platform import gfile
J
jiangjiajun 已提交
20
from tensorflow.core.framework import attr_value_pb2
J
jiangjiajun 已提交
21
import tensorflow as tf
J
jiangjiajun 已提交
22
import copy as cp
J
jiangjiajun 已提交
23
import numpy
J
jiangjiajun 已提交
24
import sys
J
jiangjiajun 已提交
25

26

J
jiangjiajun 已提交
27 28
class TFGraphNode(GraphNode):
    """GraphNode wrapper around a single TensorFlow ``NodeDef``.

    The name stored on the base class is the TF node name (or ``layer_name``
    when given) with '/' and '-' replaced by '_'.
    """

    # Integral types that ListValue items may come back as. Python 2 also
    # has ``long``; on Python 3 that name does not exist.
    try:
        _INTEGER_TYPES = (int, long)  # noqa: F821
    except NameError:
        _INTEGER_TYPES = (int, )

    def __init__(self, layer, layer_name=None):
        """Wrap ``layer``.

        Args:
            layer: the TensorFlow ``NodeDef`` to wrap.
            layer_name: optional name to use instead of ``layer.name``.
        """
        if layer_name is None:
            layer_name = layer.name
        super(TFGraphNode, self).__init__(
            layer, layer_name.replace('/', '_').replace('-', '_'))

        self.layer_type = layer.op
        self.fluid_code = FluidCode()

        # TensorFlow DataType enum value -> dtype string. Enum values follow
        # tensorflow/core/framework/types.proto (DT_UINT8 is 4, DT_INT8 is 6).
        self.dtype_map = {
            1: "float32",
            3: "int32",
            4: "uint8",
            6: "int8",
            9: "int64"
        }

    @property
    def out_shapes(self):
        """Shapes from the ``_output_shapes`` attr, one dim-size list per output.

        Returns [] when the node has no ``_output_shapes`` attr.
        """
        # Test membership first: indexing a missing key of a protobuf message
        # map inserts a default entry, silently mutating the NodeDef.
        if "_output_shapes" not in self.layer.attr:
            return []
        return [[dim.size for dim in shape.dim]
                for shape in self.layer.attr["_output_shapes"].list.shape]

    @property
    def dtype(self):
        """The node's ``dtype`` attr translated through ``dtype_map``.

        Raises:
            Exception: if the dtype (0 when the attr is missing) is not in
                ``dtype_map``.
        """
        attr = self.layer.attr
        # 0 is DT_INVALID, the value a missing attr reads as; checking
        # membership avoids inserting an empty "dtype" entry into the NodeDef.
        dtype = attr["dtype"].type if "dtype" in attr else 0
        if dtype not in self.dtype_map:
            raise Exception("Dtype[{}] not in dtype_map".format(dtype))
        return self.dtype_map[dtype]

    @property
    def value(self):
        """The tensor stored in a ``Const`` node's ``value`` attr, as ndarray."""
        assert self.layer_type == "Const", "Only Const node has value."

        attr = self.layer.attr['value']
        field = getattr(attr, attr.WhichOneof('value'))
        return tensor_util.MakeNdarray(field)

    def get_attr(self, name):
        """Return the Python value of attr ``name``, or None if it is absent.

        List attrs are returned as a plain Python list; integral items
        (bools included, since bool subclasses int) are converted with
        ``int()``. An empty list attr returns [].
        """
        if name not in self.layer.attr:
            return None
        attr = self.layer.attr[name]
        field = attr.WhichOneof('value')
        value = getattr(attr, field) if field else None

        if not isinstance(value, attr_value_pb2.AttrValue.ListValue):
            return value

        # A ListValue keeps its items in one populated repeated field;
        # ListFields() is empty when the list has no items at all.
        populated = value.ListFields()
        if not populated:
            return []
        return [
            int(item) if isinstance(item, self._INTEGER_TYPES) else item
            for item in populated[0][1]
        ]

J
jiangjiajun 已提交
88 89 90 91

class TFGraph(Graph):
    """x2paddle ``Graph`` built from a TensorFlow ``GraphDef``.

    ``build()`` creates one ``TFGraphNode`` per NodeDef and wires the edges.
    It then drops isolated nodes and bypasses ``Identity`` nodes;
    ``identity_map`` records each bypassed Identity name -> the name of the
    node it forwarded, so ``get_node`` can still resolve the old name.
    """

    def __init__(self, model):
        super(TFGraph, self).__init__(model)
        # Bypassed Identity node name -> name of the node that fed it.
        self.identity_map = dict()
        # Ops whose bare name (no ":N" suffix) is treated as output 0.
        self.multi_out_ops = ['Split', 'SplitV']

    @staticmethod
    def _sanitize(name):
        """Apply the same renaming TFGraphNode uses ('/' and '-' -> '_')."""
        return name.replace('/', '_').replace('-', '_')

    def build(self):
        """Create nodes, connect their inputs, then simplify the graph.

        Raises:
            Exception: if an input refers to a node that is not in the graph.
        """
        for layer in self.model.node:
            self.node_map[self._sanitize(layer.name)] = TFGraphNode(layer)

        for layer_name, node in self.node_map.items():
            for in_name in node.layer.input:
                # Control-dependency inputs are written "^node". Strip the
                # marker so they resolve to their producer instead of failing
                # the lookup below; the edge presumably keeps the producer
                # ahead of the consumer in the base class's topo sort.
                in_name = self._sanitize(in_name.strip().lstrip('^'))
                if in_name in self.node_map:
                    self.connect(in_name, layer_name)
                    continue
                # "node:N" names output N of "node"; edges are per node.
                src_name = in_name.split(':')[0]
                if src_name not in self.node_map:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(in_name, layer_name))
                self.connect(src_name, layer_name)

        super(TFGraph, self).build()

        # tensorflow graph optimize
        self._remove_isolated_node()
        self._remove_identity_node()

    def get_node(self, node_name, copy=False):
        """Look up a node by TF name or tensor name ("node" or "node:N").

        Names are sanitised like in ``build`` and bypassed Identity nodes are
        resolved (following chains) to the node they forwarded. A bare name
        of a multi-output op gets ``index = 0``.
        """
        items = node_name.strip().split(':')
        name = self._sanitize(items[0].lstrip('^'))
        # Identity -> Identity chains can map to another bypassed name, so
        # follow the map until it stops (the set guards against a cycle).
        seen = set()
        while name in self.identity_map and name not in seen:
            seen.add(name)
            name = self.identity_map[name]
        items[0] = name

        node = super(TFGraph, self).get_node(":".join(items), copy)
        if len(items) == 1 and node.layer_type in self.multi_out_ops:
            node.index = 0
        return node

    def _remove_isolated_node(self):
        """Remove nodes that have neither inputs nor outputs."""
        isolated_nodes = [
            node_name for node_name in self.node_map.keys()
            if not self.get_node(node_name).inputs
            and not self.get_node(node_name).outputs
        ]
        for node_name in isolated_nodes:
            self.remove_node(node_name)

    def _remove_identity_node(self):
        """Bypass every Identity node, rewiring its consumers to its input.

        A removed node stays in ``node_map`` but leaves ``topo_sort``;
        ``identity_map`` lets ``get_node`` resolve its name afterwards.
        """
        identity_names = [
            node_name for node_name, node in self.node_map.items()
            if node.layer_type == "Identity"
        ]

        for node_name in identity_names:
            node = self.get_node(node_name)
            # TensorFlow lists data inputs before control inputs, so the
            # forwarded tensor is inputs[0]; anything after is a control dep.
            input_node = self.get_node(node.inputs[0])
            input_name = input_node.layer_name

            self.identity_map[node_name] = input_name
            input_node.outputs.remove(node_name)
            for extra_name in node.inputs[1:]:
                extra_node = self.get_node(extra_name)
                if node_name in extra_node.outputs:
                    extra_node.outputs.remove(node_name)

            for output_name in node.outputs:
                output_node = self.get_node(output_name)
                # Replace every occurrence in place: a consumer may read the
                # Identity more than once.
                output_node.inputs[:] = [
                    input_name if name == node_name else name
                    for name in output_node.inputs
                ]
                # Keep the producer's adjacency in sync with the new edge.
                input_node.outputs.append(output_name)

            self.topo_sort.remove(node_name)

J
jiangjiajun 已提交
166

J
jiangjiajun 已提交
167 168
class TFDecoder(object):
    """Load a frozen TensorFlow ``.pb`` model and expose it as a ``TFGraph``.

    Attributes:
        sess: ``tf.Session`` holding the imported graph.
        input_info: input name -> ``(shape, dtype)``; the one unknown
            dimension, if any, is recorded as -1.
        tf_graph: the built ``TFGraph``.
    """

    def __init__(self, pb_model):
        self.sess = tf.Session()
        self.input_info = dict()
        # input_info key -> TensorFlow op name to feed. Keys for original
        # placeholders are sanitised layer names ('/' and '-' replaced), which
        # get_tensor_by_name cannot resolve on their own.
        self._feed_names = dict()

        with gfile.FastGFile(pb_model, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # as_default() is a context manager and only takes effect in `with`;
        # the replacement placeholders and the import must share sess.graph.
        with self.sess.graph.as_default():
            input_map = self._check_input_shape(graph_def)
            self._fix_output_shape(graph_def)
            tf.import_graph_def(graph_def, name='', input_map=input_map)
            self.sess.run(tf.global_variables_initializer())

        self.tf_graph = TFGraph(
            self.sess.graph._as_graph_def(add_shapes=True)[0])
        self.tf_graph.build()

    def _fix_output_shape(self, graph):
        """Set ``_disable_call_shape_inference`` to False on ``swish_f32`` nodes.

        NOTE(review): presumably so TensorFlow infers (and ``add_shapes``
        records) their output shapes - confirm.
        """
        for node in graph.node:
            if node.op == "swish_f32":
                node.attr['_disable_call_shape_inference'].b = False

    @staticmethod
    def _read_shape_from_user():
        """Prompt until a valid shape is typed; return it as a list.

        At most one dimension may be "None"; it is returned as None.
        """
        while True:
            text = input("Shape of Input(e.g. None,224,224,3): ")
            if text.count("None") > 1:
                color_log("Only 1 dimension can be None, type again:)")
                continue
            try:
                return [
                    None if dim.strip() == "None" else int(dim)
                    for dim in text.strip().split(',')
                ]
            except ValueError:
                color_log("Invalid shape \"{}\", type again:)".format(text))

    def _check_input_shape(self, graph_def):
        """Record every Placeholder in ``input_info``, asking for bad shapes.

        Placeholders with no shape, an unknown rank, or more than one -1
        dimension get a user-typed shape. A new placeholder named
        ``x2paddle_<name>`` is created for each of them in the current
        default graph.

        Returns:
            dict: ``"<placeholder>:0"`` -> replacement placeholder, suitable
            as ``input_map`` for ``tf.import_graph_def``.
        """
        # Seeds numpy's global RNG used by the random feeds in the infer_*
        # methods.
        numpy.random.seed(13)
        input_map = dict()
        # graph_def is only read here, so no defensive copy of the (possibly
        # large) model is needed.
        for layer in graph_def.node:
            if layer.op != "Placeholder":
                continue
            graph_node = TFGraphNode(layer)
            dtype = graph_node.dtype

            shape_attr = graph_node.get_attr("shape")
            if not shape_attr or shape_attr.unknown_rank:
                reason = ("\nUnknown shape for input tensor[tensor name: \"{}\"]".
                          format(layer.name))
            else:
                shape = [dim.size for dim in shape_attr.dim]
                if shape.count(-1) <= 1:
                    # Usable as-is: at most one unknown dimension.
                    self.input_info[graph_node.layer_name] = (shape, dtype)
                    self._feed_names[graph_node.layer_name] = layer.name
                    continue
                reason = (
                    "\nShape[now is {}] for input tensor[tensor name: \"{}\"] not support yet"
                    .format(shape, layer.name))

            color_log(reason)
            color_log(
                "Use your keyboard type the shape of input tensor below :)")
            shape = self._read_shape_from_user()

            x2paddle_name = "x2paddle_{}".format(layer.name)
            x2paddle_input = tf.placeholder(dtype=dtype,
                                            shape=shape,
                                            name=x2paddle_name)
            input_map["{}:0".format(layer.name)] = x2paddle_input
            # A fully specified typed shape has no None to replace.
            if None in shape:
                shape[shape.index(None)] = -1
            self.input_info[x2paddle_name] = (shape, dtype)
            # TF may uniquify the requested name; feed the real op.
            self._feed_names[x2paddle_name] = x2paddle_input.op.name

        return input_map

    @staticmethod
    def _tensor_name(graph_node):
        """TF tensor name "<op>:<index>" for ``graph_node`` (index 0 default)."""
        return "{}:{}".format(graph_node.layer.name,
                              getattr(graph_node, "index", 0))

    def _random_feed(self, batch_size):
        """Build a random feed dict for all inputs in ``input_info``.

        The -1 dimension of each input shape is replaced by ``batch_size``;
        ``input_info`` itself is not modified.
        """
        feed = dict()
        for input_name, (shape, _) in self.input_info.items():
            shape = list(shape)
            if -1 in shape:
                shape[shape.index(-1)] = batch_size
            op_name = self._feed_names.get(input_name, input_name)
            input_tensor = self.sess.graph.get_tensor_by_name(op_name + ":0")
            feed[input_tensor] = numpy.random.random_sample(shape)
        return feed

    # trick method
    # should be removed after PaddlePaddle V1.6 been released
    def infer_tensor(self, graph_node):
        """Evaluate ``graph_node``'s tensor on random inputs (batch size 2)."""
        print("========== Use infer_tensor for tensor: ", graph_node.layer.name)
        feed = self._random_feed(2)
        output_tensor = self.sess.graph.get_tensor_by_name(
            self._tensor_name(graph_node))
        return self.sess.run([output_tensor], feed)[0]

    def infer_shape_tensor(self, graph_node, out_shape=None):
        """Infer the value of a shape-like tensor, marking batch dims as -1.

        The tensor is evaluated with batch sizes 2, 3 and 5. If it is
        constant, that value is returned. Otherwise the single element that
        varies with the batch size becomes -1. If two negative entries then
        remain, one of them is filled from ``out_shape`` when given.

        Raises:
            Exception: if more than one element varies, or it varies
                inconsistently across the three runs.
        """
        print("========== Use infer_shape_tensor for tensor: ",
              graph_node.layer.name)
        output_tensor = self.sess.graph.get_tensor_by_name(
            self._tensor_name(graph_node))
        results = [
            self.sess.run([output_tensor],
                          self._random_feed(batch_size))[0].flatten()
            for batch_size in (2, 3, 5)
        ]

        compare01 = (results[0] == results[1])
        compare12 = (results[1] == results[2])

        if compare01.all() and compare12.all():
            return results[0].tolist()

        if not (compare01 == compare12).all():
            raise Exception("Couldn't infer a stable shape shape tensor value")

        unstable = numpy.flatnonzero(~compare01)
        if unstable.shape[0] != 1:
            raise Exception("There's not only one unstable dimension")
        results[0][unstable[0]] = -1

        negative = numpy.flatnonzero(results[0] < 0)
        if negative.shape[0] > 2:
            print("Warning: More than two dimension less than zero")
        if negative.shape[0] == 2 and out_shape is not None:
            # Two unknowns (e.g. batch plus a -1 already in the value):
            # resolve one of them from the expected output shape.
            if out_shape[negative[1]] > 0:
                results[0][negative[1]] = out_shape[negative[1]]
            else:
                results[0][negative[0]] = out_shape[negative[0]]
        return results[0].tolist()