tf_decoder.py 19.0 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.core.graph import GraphNode, Graph
J
jiangjiajun 已提交
16 17
from x2paddle.core.fluid_code import FluidCode
from tensorflow.python.framework import tensor_util
J
jiangjiajun 已提交
18
from tensorflow.core.framework import attr_value_pb2
J
jiangjiajun 已提交
19
import tensorflow as tf
J
jiangjiajun 已提交
20
import copy as cp
J
jiangjiajun 已提交
21
import numpy
J
jiangjiajun 已提交
22
import sys
J
jiangjiajun 已提交
23

24

J
jiangjiajun 已提交
25
class TFGraphNode(GraphNode):
    """Graph node wrapping a single TensorFlow ``NodeDef``.

    The node name is sanitized: '/' and '-' are replaced by '_' and '^' is
    removed.
    """

    def __init__(self, layer, layer_name=None, data_format="NHWC"):
        """
        Args:
            layer: the TensorFlow ``NodeDef`` wrapped by this node.
            layer_name: optional name used instead of ``layer.name``; it is
                sanitized the same way.
            data_format: data format of the TensorFlow model.
        """
        name = layer.name if layer_name is None else layer_name
        super(TFGraphNode, self).__init__(
            layer, name.replace('/', '_').replace('-', '_').replace('^', ''))

        self.layer_type = layer.op
        self.tf_data_format = data_format
        self.pd_data_format = "NCHW"
        self.fluid_code = FluidCode()

        # TensorFlow DataType enum value -> dtype string
        self.dtype_map = {
            1: "float32",
            3: "int32",
            4: "uint8",
            9: "int64",
            10: "bool"
        }

    @property
    def out_shapes(self):
        """Shapes listed in the node's ``_output_shapes`` attribute.

        Returns:
            list of int lists, one per listed output shape.
        """
        values = self.layer.attr["_output_shapes"].list.shape
        return [[dim.size for dim in value.dim] for value in values]

    @property
    def dtype(self):
        """Dtype string of this node, looked up in ``dtype_map``.

        Raises:
            Exception: if the raw TensorFlow dtype has no entry in dtype_map.
        """
        dtype = self.raw_dtype
        if dtype not in self.dtype_map:
            raise Exception("Dtype[{}] not in dtype_map".format(dtype))
        return self.dtype_map[dtype]

    @property
    def raw_dtype(self):
        """TensorFlow DataType enum value of this node.

        The 'dtype', 'Tidx', 'T' and 'DstT' attributes are checked in that
        order and the first non-zero type is returned; 0 if none is set.
        """
        dtype = 0
        for key in ('dtype', 'Tidx', 'T', 'DstT'):
            # Check membership first: indexing a protobuf message map with a
            # missing key would insert an empty entry into the NodeDef.
            if key not in self.layer.attr:
                continue
            dtype = self.layer.attr[key].type
            if dtype > 0:
                break
        return dtype

    @property
    def value(self):
        """Constant value of a ``Const`` node as a numpy array."""
        assert self.layer_type == "Const", "Only Const node has value."

        attr = self.layer.attr['value']
        field = getattr(attr, attr.WhichOneof('value'))
        return tensor_util.MakeNdarray(field)

    def get_attr(self, name):
        """Return the Python value of attribute ``name``.

        Returns:
            None if the attribute is missing or has no value set; a Python
            list for list attributes (integral items converted to ``int``,
            empty list when nothing is set); otherwise the raw field value.
        """
        import numbers

        if name not in self.layer.attr:
            return None
        attr = self.layer.attr[name]
        field = attr.WhichOneof('value')
        value = getattr(attr, field) if field else None

        if isinstance(value, attr_value_pb2.AttrValue.ListValue):
            fields = value.ListFields()
            if not fields:
                # an empty list attribute has no populated field at all
                return []
            # numbers.Integral also covers Python 2 ``long``
            return [
                int(item) if isinstance(item, numbers.Integral) else item
                for item in fields[0][1]
            ]
        return value

J
jiangjiajun 已提交
107 108

class TFGraph(Graph):
    """Graph of ``TFGraphNode`` objects built from a TensorFlow ``GraphDef``.

    ``build`` links the nodes and then simplifies the graph:
    - isolated nodes are dropped;
    - SpaceToBatchND -> ExpandDims -> Conv2D -> Squeeze -> BatchToSpaceND
      chains are folded into a ``dilation`` attribute on the Conv2D;
    - identity-like ops are bypassed;
    - Cast ops fed by a single-consumer Placeholder are folded into the
      placeholder's dtype.
    """

    def __init__(self, model, data_format="NHWC"):
        super(TFGraph, self).__init__(model)
        # name of a removed node -> name of the node that replaced it
        self.identity_map = dict()
        # ops whose outputs are addressed as "name:index"
        self.multi_out_ops = ['Split', 'SplitV']
        self.tf_data_format = data_format

    def build(self):
        """Create the nodes, connect them, then run the graph optimizations.

        Raises:
            Exception: if a node input refers to a node that does not exist.
        """
        for layer in self.model.node:
            self.node_map[layer.name.replace('/', '_').replace(
                '-', '_')] = TFGraphNode(layer, data_format=self.tf_data_format)

        for layer_name, node in self.node_map.items():
            for in_node in node.layer.input:
                in_node = in_node.replace('/', '_').replace('-', '_').replace(
                    '^', '')
                if in_node in self.node_map:
                    self.connect(in_node, layer_name)
                    continue
                # "name:index" refers to one output of a multi-output node
                base_name = in_node.strip().split(':')[0]
                if base_name not in self.node_map:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(in_node, layer_name))
                self.connect(base_name, layer_name)

        super(TFGraph, self).build()

        # tensorflow graph optimize
        self._remove_isolated_node()
        self._optimize_dialiation_conv()
        self._remove_identity_node()
        self._remove_cast_node()

    def get_node(self, node_name, copy=False):
        """Look up a node by (possibly unsanitized, possibly "name:index") name.

        Names of bypassed nodes are redirected through ``identity_map``.
        Switch nodes lose any ``index`` attribute. Multi-output nodes requested
        without an index get ``index = 0``.

        Returns:
            the node, or None if the base graph does not know the name.
        """
        items = node_name.strip().split(':')
        items[0] = items[0].replace('/', '_').replace('-', '_')
        # A node may have been replaced by a node that was itself bypassed
        # later, so follow the whole chain (guarding against cycles).
        visited = set()
        while items[0] in self.identity_map and items[0] not in visited:
            visited.add(items[0])
            items[0] = self.identity_map[items[0]]
        new_node_name = ":".join(items)
        node = super(TFGraph, self).get_node(new_node_name, copy)
        if node is None:
            return None
        if node.layer_type == "Switch" and hasattr(node, 'index'):
            del node.index
        if len(items) == 1 and node.layer_type in self.multi_out_ops:
            node.index = 0
        return node

    def remove_node(self, node_name):
        """Remove ``node_name``, reconnecting its consumers to its first input.

        Every producer of the node forgets it as an output, so no dangling
        references to the removed node remain.

        Raises:
            Exception: if the node is unknown or has no input to bypass to.
        """
        if node_name not in self.node_map:
            raise Exception("Node[{}] not in graph".format(node_name))
        inputs = self.node_map[node_name].inputs
        outputs = self.node_map[node_name].outputs
        if len(inputs) == 0:
            raise Exception("Node[{}] has no input to bypass to".format(
                node_name))

        for in_name in set(inputs):
            producer = self.node_map[in_name]
            producer.outputs[:] = [
                out for out in producer.outputs if out != node_name
            ]

        input_node = self.node_map[inputs[0]]
        for output in outputs:
            consumer = self.node_map[output]
            idx = consumer.inputs.index(node_name)
            consumer.inputs[idx] = inputs[0]
            input_node.outputs.append(output)

        del self.node_map[node_name]
        self.topo_sort.remove(node_name)

    def _optimize_dialiation_conv(self):
        """Mark SpaceToBatchND/BatchToSpaceND pairs around a Conv2D as skipped.

        The Conv2D receives ``dilation`` from the SpaceToBatchND block_shape.
        """
        pattern = ['ExpandDims', 'Conv2D', 'Squeeze', 'BatchToSpaceND']
        for name in list(self.node_map.keys()):
            node = self.node_map[name]
            if node.layer_type != "SpaceToBatchND":
                continue
            chain = list()
            current = node
            for expected_type in pattern:
                # only the first consumer of each node is followed
                if len(current.outputs) == 0:
                    break
                current = self.node_map[current.outputs[0]]
                if current.layer_type != expected_type:
                    break
                chain.append(current)
            if len(chain) != len(pattern):
                continue

            conv_node, batch_to_space = chain[1], chain[3]
            node.skip = True
            batch_to_space.skip = True
            block_shape = self.node_map[node.inputs[1]]
            conv_node.dilation = block_shape.value.tolist()

    def _remove_isolated_node(self):
        """Delete nodes that have neither inputs nor outputs."""
        isolated_nodes = list()
        for node_name in self.node_map.keys():
            node = self.get_node(node_name)
            if len(node.inputs) == 0 and len(node.outputs) == 0:
                isolated_nodes.append(node_name)

        for node_name in isolated_nodes:
            del self.node_map[node_name]
            if node_name in self.input_nodes:
                self.input_nodes.remove(node_name)
            if node_name in self.output_nodes:
                self.output_nodes.remove(node_name)
            self.topo_sort.remove(node_name)

    def _bypass_node(self, node_name, input_node):
        """Remove ``node_name`` and redirect lookups/outputs to ``input_node``."""
        self.remove_node(node_name)
        self.identity_map[node_name] = input_node.layer_name
        if node_name in self.output_nodes:
            idx = self.output_nodes.index(node_name)
            self.output_nodes[idx] = input_node.layer_name

    def _remove_identity_node(self):
        """Bypass identity-like ops, replacing each by its first input."""
        identity_ops = [
            'Identity', 'StopGradient', 'Switch', 'Merge',
            'PlaceholderWithDefault'
        ]
        identity_nodes = [
            node_name for node_name, node in self.node_map.items()
            if node.layer_type in identity_ops
        ]

        for node_name in identity_nodes:
            node = self.get_node(node_name)
            input_node = self.get_node(node.inputs[0])
            self._bypass_node(node_name, input_node)

    def _remove_cast_node(self):
        """Fold a Cast fed by a single-consumer Placeholder into that
        placeholder by giving the placeholder the Cast's dtype."""
        cast_nodes = list()
        for node_name, node in self.node_map.items():
            if node.layer_type != "Cast":
                continue
            producer = self.get_node(node.inputs[0])
            if producer.layer_type != "Placeholder" or len(
                    producer.outputs) != 1:
                continue
            cast_nodes.append(node_name)

        for node_name in cast_nodes:
            node = self.get_node(node_name)
            input_node = self.get_node(node.inputs[0])
            input_node.layer.attr["dtype"].type = node.raw_dtype
            self._bypass_node(node_name, input_node)

    def data_format_propagation(self, node):
        """Copy ``node.tf_data_format`` to every node reachable from ``node``."""
        current_node = self.node_map[node.layer_name]
        for out in current_node.outputs:
            next_node = self.node_map[out]
            next_node.tf_data_format = node.tf_data_format
            self.data_format_propagation(next_node)

J
jiangjiajun 已提交
279

J
jiangjiajun 已提交
280
class TFDecoder(object):
281
    def __init__(self, pb_model, data_format="NHWC", define_input_shape=False):
        """Load a serialized ``GraphDef`` and build the ``TFGraph`` from it.

        Args:
            pb_model: path of the serialized GraphDef file.
            data_format: TensorFlow data format, passed on to ``TFGraph``.
            define_input_shape: if True, the user is prompted for the shape
                of every Placeholder.
        """
        # tf.compat.v1 is missing on old TensorFlow releases
        try:
            self.sess = tf.compat.v1.Session()
        except AttributeError:
            self.sess = tf.Session()
        self.input_info = dict()
        self.define_input_shape = define_input_shape

        try:
            graph_def = tf.compat.v1.GraphDef()
        except AttributeError:
            graph_def = tf.GraphDef()
        with open(pb_model, 'rb') as f:
            graph_def.ParseFromString(f.read())

        # as_default() only takes effect when used as a context manager;
        # build the replacement placeholders, the imported graph and the
        # initializer explicitly inside the session's graph.
        with self.sess.graph.as_default():
            input_map = self._check_input_shape(graph_def)
            self._fix_output_shape(graph_def)
            tf.import_graph_def(graph_def, name='', input_map=input_map)
            try:
                initializer = tf.compat.v1.global_variables_initializer()
            except AttributeError:
                initializer = tf.global_variables_initializer()
        self.sess.run(initializer)

        self.tf_graph = TFGraph(
            self.sess.graph._as_graph_def(add_shapes=True)[0], data_format)
        self.tf_graph.build()
J
jiangjiajun 已提交
308 309 310 311 312 313

    def _fix_output_shape(self, graph):
        for i in range(len(graph.node)):
            node = graph.node[i]
            if node.op == "swish_f32":
                graph.node[i].attr['_disable_call_shape_inference'].b = False
J
jiangjiajun 已提交
314 315

    def _check_input_shape(self, graph_def):
        """Record input shapes, prompting the user where they are unusable.

        A Placeholder needs a user-defined shape when its shape is unknown,
        when it has more than one -1 dimension, or when
        ``self.define_input_shape`` is set. In that case a replacement
        placeholder named ``x2paddle_<name>`` is created and added to the
        returned input map. Every Placeholder's ``(shape, dtype)`` is stored
        in ``self.input_info``, with the unknown dimension stored as -1.

        Args:
            graph_def: the model GraphDef; a deep copy is inspected.

        Returns:
            dict mapping "<placeholder name>:0" to its replacement tensor.
        """
        # presumably makes the random feeds of the infer_* methods
        # reproducible (they use numpy.random)
        numpy.random.seed(13)
        graph_def = cp.deepcopy(graph_def)
        input_map = dict()
        for layer in graph_def.node:
            if layer.op != "Placeholder":
                continue
            graph_node = TFGraphNode(layer)
            dtype = graph_node.layer.attr['dtype'].type

            # 0: shape usable as-is, 1: unknown shape,
            # 2: more than one -1 dimension, 3: user asked to define shapes
            need_define_shape = 0
            if self.define_input_shape:
                need_define_shape = 3
            elif graph_node.layer.attr[
                    'shape'].shape.unknown_rank or not graph_node.get_attr(
                        "shape"):
                need_define_shape = 1
            else:
                value = graph_node.layer.attr["shape"].shape
                shape = [dim.size for dim in value.dim]
                if shape.count(-1) > 1:
                    need_define_shape = 2

            if need_define_shape > 0:
                shape = None
                if graph_node.get_attr("shape"):
                    value = graph_node.layer.attr["shape"].shape
                    shape = [dim.size for dim in value.dim]
                if need_define_shape == 1:
                    print("Unknown shape for input tensor[tensor name: \"{}\"]".
                          format(layer.name))
                elif need_define_shape == 2:
                    print(
                        "\nShape[now is {}] for input tensor[tensor name: \"{}\"] not support yet"
                        .format(shape, layer.name))
                else:
                    print(
                        "Define shape[now is {}] for input tensor[tensor name: \"{}\"]"
                        .format(shape, layer.name))
                print(
                    "Use your keyboard type the shape of input tensor below :)")

                shape = self._prompt_input_shape()
                # tf.compat.v1 is missing on old TensorFlow releases
                try:
                    x2paddle_input = tf.compat.v1.placeholder(
                        dtype=dtype,
                        shape=shape,
                        name="x2paddle_{}".format(layer.name))
                except AttributeError:
                    x2paddle_input = tf.placeholder(
                        dtype=dtype,
                        shape=shape,
                        name="x2paddle_{}".format(layer.name))

                input_map["{}:0".format(layer.name)] = x2paddle_input
                if shape.count(None) > 0:
                    shape[shape.index(None)] = -1
                self.input_info["x2paddle_{}".format(layer.name)] = (shape,
                                                                     dtype)
            else:
                value = graph_node.layer.attr["shape"].shape
                shape = [dim.size for dim in value.dim]
                # NOTE(review): keyed by the sanitized layer_name, but the
                # infer_* methods use the key as a tensor name ("<key>:0");
                # placeholders whose names contain '/' or '-' won't resolve.
                self.input_info[graph_node.layer_name] = (shape, dtype)

        return input_map

    @staticmethod
    def _prompt_input_shape():
        """Read a shape such as "None,224,224,3" from stdin until it parses.

        Returns:
            list with one entry per dimension: ``None`` for the (at most one)
            unknown dimension, ``int`` otherwise.
        """
        prompt = "Shape of Input(e.g. None,224,224,3): "
        while True:
            try:
                text = raw_input(prompt)  # Python 2
            except NameError:
                text = input(prompt)  # Python 3
            try:
                shape = [
                    None if dim.strip() == "None" else int(dim)
                    for dim in text.strip().split(',')
                ]
            except ValueError:
                print("Invalid shape \"{}\", type again:)".format(text))
                continue
            if shape.count(None) > 1:
                print("Only 1 dimension can be None, type again:)")
                continue
            return shape
J
jiangjiajun 已提交
396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455

    # trick method
    # should be removed after PaddlePaddle V1.6 has been released
    def infer_tensor(self, graph_node):
        if hasattr(graph_node, "index"):
            tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
        else:
            tensor_name = graph_node.layer.name + ":0"
        feed = dict()
        for input_name, info in self.input_info.items():
            (shape, dtype) = cp.deepcopy(info)
            input_tensor = self.sess.graph.get_tensor_by_name(input_name + ":0")
            if shape.count(-1) > 0:
                shape[shape.index(-1)] = 2
            feed[input_tensor] = numpy.random.random_sample(shape)
        output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
        return self.sess.run([output_tensor], feed)[0]

    def infer_shape_tensor(self, graph_node, out_shape=None):
        if hasattr(graph_node, "index"):
            tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
        else:
            tensor_name = graph_node.layer.name + ":0"
        feed = dict()
        batch_size = [2, 3, 5]
        results = list()
        for b in batch_size:
            for input_name, info in self.input_info.items():
                (shape, dtype) = cp.deepcopy(info)
                input_tensor = self.sess.graph.get_tensor_by_name(input_name +
                                                                  ":0")
                if shape.count(-1) > 0:
                    shape[shape.index(-1)] = b
                feed[input_tensor] = numpy.random.random_sample(shape)
            output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
            results.append(self.sess.run([output_tensor], feed)[0].flatten())

        compare01 = (results[0] == results[1])
        compare12 = (results[1] == results[2])

        if compare01.all() and compare12.all():
            return results[0].tolist()

        if (compare01 == compare12).all():
            index = numpy.argwhere(compare01 == False).flatten()
            if index.shape[0] != 1:
                raise Exception("There's not only one unstable dimension")
            results[0][index[0]] = -1

            index = numpy.argwhere(results[0] < 0).flatten()
            if index.shape[0] > 2:
                print("Warning: More than two dimension less than zero")
            if index.shape[0] == 2 and out_shape is not None:
                if out_shape[index[1]] > 0:
                    results[0][index[1]] = out_shape[index[1]]
                else:
                    results[0][index[0]] = out_shape[index[0]]
            return results[0].tolist()
        else:
            raise Exception("Couldn't infer a stable shape shape tensor value")
J
jiangjiajun 已提交
456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490

    def infer_tensor_shape(self, graph_node):
        if hasattr(graph_node, "index"):
            tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
        else:
            tensor_name = graph_node.layer.name + ":0"
        feed = dict()
        batch_size = [2, 3, 5]
        shapes = list()
        for b in batch_size:
            for input_name, info in self.input_info.items():
                (shape, dtype) = cp.deepcopy(info)
                input_tensor = self.sess.graph.get_tensor_by_name(input_name +
                                                                  ":0")
                if shape.count(-1) > 0:
                    shape[shape.index(-1)] = b
                feed[input_tensor] = numpy.random.random_sample(shape)
            output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
            shape = self.sess.run([output_tensor], feed)[0].shape
            shapes.append(numpy.array(shape))

        compare01 = (shapes[0] == shapes[1])
        compare12 = (shapes[1] == shapes[2])

        if compare01.all() and compare12.all():
            return shape[0].tolist()

        if (compare01 == compare12).all():
            index = numpy.argwhere(compare01 == False).flatten()
            if index.shape[0] != 1:
                raise Exception("There's not only one unstable dimension")
            if index[0] != 0:
                raise Exception("Batch size not in the first dimension")
            shapes[0][0] = -1
            return shapes[0].tolist()