tf_decoder.py 14.4 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.core.graph import GraphNode, Graph
J
jiangjiajun 已提交
16 17
from x2paddle.core.fluid_code import FluidCode
from tensorflow.python.framework import tensor_util
J
jiangjiajun 已提交
18
from tensorflow.python.platform import gfile
J
jiangjiajun 已提交
19
from tensorflow.core.framework import attr_value_pb2
J
jiangjiajun 已提交
20
import tensorflow as tf
J
jiangjiajun 已提交
21
import copy as cp
J
jiangjiajun 已提交
22
import numpy
J
jiangjiajun 已提交
23
import sys
J
jiangjiajun 已提交
24

25

J
jiangjiajun 已提交
26
class TFGraphNode(GraphNode):
    """GraphNode wrapping one TensorFlow ``NodeDef``.

    The node name is sanitized ('/' and '-' replaced by '_') before it is
    handed to ``GraphNode``; ``layer_type`` is the TF op type string.
    """

    def __init__(self, layer, layer_name=None, data_format="NHWC"):
        if layer_name is None:
            layer_name = layer.name
        super(TFGraphNode, self).__init__(
            layer, layer_name.replace('/', '_').replace('-', '_'))

        self.layer_type = layer.op
        self.tf_data_format = data_format
        self.pd_data_format = "NCHW"
        self.fluid_code = FluidCode()

        # TensorFlow DataType enum value -> dtype name used by the converter.
        self.dtype_map = {1: "float32", 3: "int32", 4: "uint8", 9: "int64"}

    @property
    def out_shapes(self):
        """Shapes from the ``_output_shapes`` attr, one list of dim sizes per output.

        Returns an empty list when the attr is absent. The attr map is checked
        with ``in`` first, because indexing a protobuf map inserts a missing
        key into the NodeDef.
        """
        if "_output_shapes" not in self.layer.attr:
            return []
        return [[dim.size for dim in shape.dim]
                for shape in self.layer.attr["_output_shapes"].list.shape]

    @property
    def dtype(self):
        """Dtype name taken from the first of 'dtype', 'Tidx', 'T' with a nonzero type.

        Raises:
            Exception: if the resolved type enum is not in ``dtype_map``.
        """
        dtype = 0
        for key in ('dtype', 'Tidx', 'T'):
            # Membership test avoids inserting empty attrs into the NodeDef.
            if key in self.layer.attr:
                dtype = self.layer.attr[key].type
                if dtype > 0:
                    break
        if dtype not in self.dtype_map:
            raise Exception("Dtype[{}] not in dtype_map".format(dtype))
        return self.dtype_map[dtype]

    @property
    def value(self):
        """Numpy array held by a ``Const`` node's 'value' attr."""
        assert self.layer_type == "Const", "Only Const node has value."

        attr = self.layer.attr['value']
        field = getattr(attr, attr.WhichOneof('value'))
        return tensor_util.MakeNdarray(field)

    def get_attr(self, name):
        """Return the Python value of attr ``name``.

        Returns None when the attr is missing or has no value set. List attrs
        are returned as a plain Python list (``[]`` for an empty list), with
        integer elements (including Python 2 ``long``) converted to ``int``.
        Any other value is returned as the raw protobuf field.
        """
        import numbers

        if name not in self.layer.attr:
            return None
        attr = self.layer.attr[name]
        field = attr.WhichOneof('value')
        value = getattr(attr, field) if field else None

        if not isinstance(value, attr_value_pb2.AttrValue.ListValue):
            return value

        populated = value.ListFields()
        if not populated:
            # An empty list attr has no populated field; indexing [0] used to
            # raise IndexError here.
            return []
        return [
            int(item) if isinstance(item, numbers.Integral) else item
            for item in populated[0][1]
        ]

J
jiangjiajun 已提交
93 94

class TFGraph(Graph):
    """Graph of ``TFGraphNode`` objects built from a TensorFlow ``GraphDef``.

    Node names are sanitized like in ``TFGraphNode`` ('/' and '-' -> '_').
    After ``build``, isolated nodes are dropped and Identity nodes are
    bypassed. Removed Identity names are recorded in ``identity_map`` so that
    ``get_node`` still resolves them.
    """

    def __init__(self, model, data_format="NHWC"):
        super(TFGraph, self).__init__(model)
        # removed Identity node name -> name of the node it forwarded
        self.identity_map = dict()
        self.multi_out_ops = ['Split', 'SplitV']
        self.tf_data_format = data_format

    def build(self):
        """Create one node per NodeDef, connect the edges, then simplify the graph."""
        for layer in self.model.node:
            name = layer.name.replace('/', '_').replace('-', '_')
            self.node_map[name] = TFGraphNode(
                layer, data_format=self.tf_data_format)

        for layer_name, node in self.node_map.items():
            for in_node in node.layer.input:
                in_node = in_node.replace('/', '_').replace('-', '_')
                if in_node in self.node_map:
                    self.connect(in_node, layer_name)
                    continue
                # "name:k" refers to output k of node "name"; edges are kept
                # at node granularity.
                # NOTE(review): control inputs ("^name") are not resolved here
                # and end up in the exception below.
                source = in_node.strip().split(':')[0]
                if source not in self.node_map:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(in_node, layer_name))
                self.connect(source, layer_name)

        super(TFGraph, self).build()

        # tensorflow graph optimize
        self._remove_isolated_node()
        self._remove_identity_node()

    def get_node(self, node_name, copy=False):
        """Look up ``node_name`` ("name" or "name:k").

        The name is sanitized and redirected through ``identity_map``. When a
        multi-output op is requested without an output index, ``index`` is set
        to 0 on the returned node.
        """
        items = node_name.strip().split(':')
        items[0] = items[0].replace('/', '_').replace('-', '_')
        items[0] = self.identity_map.get(items[0], items[0])
        node = super(TFGraph, self).get_node(":".join(items), copy)
        if len(items) == 1 and node.layer_type in self.multi_out_ops:
            node.index = 0
        return node

    def remove_node(self, node_name):
        """Remove a single-input node and wire its consumers to that input.

        Raises:
            Exception: if ``node_name`` is not in the graph.
        """
        if node_name not in self.node_map:
            raise Exception("Node[{}] not in graph".format(node_name))
        removed = self.node_map[node_name]
        assert len(removed.inputs) == 1, \
            "Node[{}] must have exactly one input to be removed".format(node_name)
        source_name = removed.inputs[0]
        source = self.node_map[source_name]
        source.outputs.remove(node_name)
        for output in removed.outputs:
            consumer = self.node_map[output]
            # Replace every occurrence: a consumer may take the removed node
            # as several of its inputs (e.g. Add(x, x)).
            consumer.inputs[:] = [
                source_name if name == node_name else name
                for name in consumer.inputs
            ]
            source.outputs.append(output)

        del self.node_map[node_name]
        self.topo_sort.remove(node_name)

    def _remove_isolated_node(self):
        """Delete nodes with no inputs and no outputs from all bookkeeping lists."""
        isolated_nodes = list()
        for node_name in self.node_map.keys():
            node = self.get_node(node_name)
            if not node.inputs and not node.outputs:
                isolated_nodes.append(node_name)

        for node_name in isolated_nodes:
            del self.node_map[node_name]
            if node_name in self.input_nodes:
                self.input_nodes.remove(node_name)
            if node_name in self.output_nodes:
                self.output_nodes.remove(node_name)
            self.topo_sort.remove(node_name)

    def _remove_identity_node(self):
        """Bypass every Identity node and record the redirection in ``identity_map``."""
        identity_nodes = [
            name for name, node in self.node_map.items()
            if node.layer_type == "Identity"
        ]

        for node_name in identity_nodes:
            node = self.get_node(node_name)
            input_node = self.get_node(node.inputs[0])
            target = input_node.layer_name
            # NOTE(review): when the Identity reads output k>0 of a
            # multi-output op (e.g. "split:1"), build() connected it to the
            # bare node name, so the output index is not preserved here.
            self.remove_node(node_name)

            # Identities removed earlier that forwarded to this one must now
            # point at its input; otherwise they would reference a deleted node.
            stale = [alias for alias, mapped in self.identity_map.items()
                     if mapped == node_name]
            for alias in stale:
                self.identity_map[alias] = target
            self.identity_map[node_name] = target

            if node_name in self.output_nodes:
                idx = self.output_nodes.index(node_name)
                self.output_nodes[idx] = target

    def data_format_propagation(self, node):
        """Copy ``node.tf_data_format`` to every node reachable through outputs.

        The original version overwrote the looked-up node with a string and
        then read ``.outputs`` from it, so it always failed. This version
        walks the graph iteratively and visits each node once, which avoids
        deep recursion and repeated visits through diamond paths.
        """
        data_format = node.tf_data_format
        pending = list(self.node_map[node.layer_name].outputs)
        visited = set()
        while pending:
            name = pending.pop()
            if name in visited:
                continue
            visited.add(name)
            next_node = self.node_map[name]
            next_node.tf_data_format = data_format
            pending.extend(next_node.outputs)

J
jiangjiajun 已提交
203

J
jiangjiajun 已提交
204
class TFDecoder(object):
    """Load a frozen TensorFlow model (.pb) and build a ``TFGraph`` from it.

    The GraphDef is imported into ``self.sess.graph``. Placeholders are
    replaced (through ``input_map``) by new placeholders named
    ``x2paddle_<name>``, whose shape is typed in on the keyboard, when:
    ``define_input_shape`` is True, the shape is unknown, or the shape has
    more than one -1 dimension.

    ``self.input_info`` maps the TF name of every graph input (without the
    ``:0`` suffix) to ``(shape, dtype)``. ``shape`` is a list with at most one
    -1, and ``dtype`` is the placeholder's TF DataType enum value.
    """

    def __init__(self, pb_model, data_format="NHWC", define_input_shape=False):
        self.sess = tf.Session()
        self.input_info = dict()
        self.define_input_shape = define_input_shape

        with gfile.FastGFile(pb_model, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Graph.as_default() returns a context manager, so calling it bare has
        # no effect. Entering it creates the replacement placeholders, the
        # imported nodes and the initializer in the session's graph.
        with self.sess.graph.as_default():
            input_map = self._check_input_shape(graph_def)
            self._fix_output_shape(graph_def)
            tf.import_graph_def(graph_def, name='', input_map=input_map)
            self.sess.run(tf.global_variables_initializer())

        self.tf_graph = TFGraph(
            self.sess.graph._as_graph_def(add_shapes=True)[0], data_format)
        self.tf_graph.build()

    def _fix_output_shape(self, graph):
        """Set ``_disable_call_shape_inference`` to False on ``swish_f32`` nodes (in place)."""
        for node in graph.node:
            if node.op == "swish_f32":
                node.attr['_disable_call_shape_inference'].b = False

    def _check_input_shape(self, graph_def):
        """Fill ``self.input_info`` and build replacement placeholders where needed.

        Returns:
            dict: ``input_map`` for ``tf.import_graph_def``, mapping
            ``"<placeholder>:0"`` to the new ``x2paddle_<placeholder>`` tensor.
        """
        numpy.random.seed(13)
        # A copy is inspected, so attr lookups (indexing a protobuf map
        # inserts missing keys) never modify the caller's graph_def.
        graph_def = cp.deepcopy(graph_def)
        input_map = dict()
        for layer in graph_def.node:
            if layer.op != "Placeholder":
                continue
            graph_node = TFGraphNode(layer)
            dtype = graph_node.layer.attr['dtype'].type
            # Value of the 'shape' attr, or None when no shape is recorded.
            shape_value = graph_node.get_attr("shape")

            if self.define_input_shape:
                need_define_shape = 3
            elif (not shape_value
                  or graph_node.layer.attr['shape'].shape.unknown_rank):
                need_define_shape = 1
            else:
                shape = [
                    dim.size for dim in graph_node.layer.attr["shape"].shape.dim
                ]
                need_define_shape = 2 if shape.count(-1) > 1 else 0

            if need_define_shape == 0:
                # Keyed by the TF name, not the sanitized layer_name, so that
                # infer_tensor can resolve it with get_tensor_by_name (the same
                # way the x2paddle_ entries below are resolved).
                self.input_info[layer.name] = (shape, dtype)
                continue

            shape = None
            if shape_value:
                shape = [
                    dim.size for dim in graph_node.layer.attr["shape"].shape.dim
                ]
            if need_define_shape == 1:
                print("Unknown shape for input tensor[tensor name: \"{}\"]".
                      format(layer.name))
            elif need_define_shape == 2:
                print(
                    "\nShape[now is {}] for input tensor[tensor name: \"{}\"] not support yet"
                    .format(shape, layer.name))
            else:
                print(
                    "Define shape[now is {}] for input tensor[tensor name: \"{}\"]"
                    .format(shape, layer.name))
            print("Use your keyboard type the shape of input tensor below :)")
            shape = self._read_shape_from_keyboard()

            x2paddle_input = tf.placeholder(dtype=dtype,
                                            shape=shape,
                                            name="x2paddle_{}".format(
                                                layer.name))
            input_map["{}:0".format(layer.name)] = x2paddle_input
            if None in shape:
                shape[shape.index(None)] = -1
            self.input_info["x2paddle_{}".format(layer.name)] = (shape, dtype)

        return input_map

    @staticmethod
    def _read_shape_from_keyboard():
        """Prompt until a comma-separated shape with at most one None is entered."""
        while True:
            text = input("Shape of Input(e.g. None,224,224,3): ")
            try:
                shape = [
                    None if dim.strip() == "None" else int(dim)
                    for dim in text.strip().split(',')
                ]
            except ValueError:
                print("Invalid shape, use comma separated integers or None, "
                      "type again:)")
                continue
            if shape.count(None) > 1:
                print("Only 1 dimension can be None, type again:)")
                continue
            return shape

    def _tensor_name(self, graph_node):
        """TF tensor name of ``graph_node``: its ``index``-th output, or output 0."""
        return "{}:{}".format(graph_node.layer.name,
                              getattr(graph_node, "index", 0))

    def _random_feed(self, unknown_dim):
        """Random feed for every graph input; a -1 dimension becomes ``unknown_dim``."""
        feed = dict()
        for input_name, (shape, _) in self.input_info.items():
            shape = list(shape)
            input_tensor = self.sess.graph.get_tensor_by_name(input_name + ":0")
            if -1 in shape:
                shape[shape.index(-1)] = unknown_dim
            feed[input_tensor] = numpy.random.random_sample(shape)
        return feed

    # trick method
    # should be removed after PaddlePaddle V1.6 been released
    def infer_tensor(self, graph_node):
        """Evaluate ``graph_node``'s tensor on random inputs (unknown dim set to 2)."""
        print("========== Use infer_tensor for tensor: ", graph_node.layer.name)
        tensor_name = self._tensor_name(graph_node)
        feed = self._random_feed(2)
        output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
        return self.sess.run([output_tensor], feed)[0]

    def infer_shape_tensor(self, graph_node, out_shape=None):
        """Infer the value of a shape-like tensor, using -1 for the batch-dependent entry.

        The tensor is evaluated with the unknown input dimension set to 2, 3
        and 5. Entries that stay constant are kept. The single entry that
        changes consistently across runs becomes -1. If two entries end up
        negative and ``out_shape`` is given, one of them is filled from
        ``out_shape``.

        Raises:
            Exception: if no stable value can be inferred.
        """
        print("========== Use infer_shape_tensor for tensor: ",
              graph_node.layer.name)
        tensor_name = self._tensor_name(graph_node)
        results = list()
        for batch in [2, 3, 5]:
            feed = self._random_feed(batch)
            output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
            results.append(self.sess.run([output_tensor], feed)[0].flatten())

        # Element-wise comparison below requires equally sized results.
        if len(set(result.size for result in results)) != 1:
            raise Exception("Couldn't infer a stable shape tensor value")

        compare01 = (results[0] == results[1])
        compare12 = (results[1] == results[2])

        if compare01.all() and compare12.all():
            return results[0].tolist()

        if not (compare01 == compare12).all():
            raise Exception("Couldn't infer a stable shape tensor value")

        unstable = numpy.flatnonzero(~compare01)
        if unstable.shape[0] != 1:
            raise Exception("There's not only one unstable dimension")
        results[0][unstable[0]] = -1

        negative = numpy.flatnonzero(results[0] < 0)
        if negative.shape[0] > 2:
            print("Warning: More than two dimension less than zero")
        if negative.shape[0] == 2 and out_shape is not None:
            if out_shape[negative[1]] > 0:
                results[0][negative[1]] = out_shape[negative[1]]
            else:
                results[0][negative[0]] = out_shape[negative[0]]
        return results[0].tolist()