#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.parser.tf_parser import TFGraph
from x2paddle.core.emitter import Emitter
from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import *
import numpy

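# A minimal usage sketch (hypothetical driver; the `TFParser` name is an
# assumption, only `parser.tf_graph` is relied on below):
#
#     parser = TFParser('frozen_model.pb')
#     emitter = TFEmitter(parser)
#     emitter.run()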

class TFEmitter(Emitter):
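    """Emit PaddlePaddle fluid code from a parsed TensorFlow graph.

    Walks the graph in topological order, dispatching each node to the
    method named after its op type, and collects constant tensors in
    ``self.weights`` for export as Paddle parameter files.
    """
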
    def __init__(self, parser):
        super(TFEmitter, self).__init__()
        self.parser = parser
        self.graph = parser.tf_graph
        # attr_node records nodes that exist only to
        # define attributes of other ops
        self.attr_node = list()
        self.omit_nodes = list()
        self.weights = dict()

    def run(self):
        print("Total nodes: {}".format(len(self.graph.topo_sort)))
        for node_name in self.graph.topo_sort:
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if hasattr(self, op):
                emit_func = getattr(self, op)
                emit_func(node)

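        # Pass 2: print the generated fluid code, skipping nodes that were
        # folded into other ops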
        for node_name in self.graph.topo_sort:
            if node_name in self.omit_nodes:
                continue
            node = self.graph.get_node(node_name)
            for layer in node.fluid_code.layers:
                print(layer.get_code())

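        # Pass 3: dump the collected weights as Paddle parameter files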
        for name, param in self.weights.items():
            node = self.graph.get_node(name)
            export_paddle_param(param, node.layer_name.replace('/', '_'),
                                "params1")

    def Placeholder(self, node):
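        # a TF Placeholder becomes a fluid data layer with fixed shape/dtype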
        shape = node.out_shapes[0]
        dtype = node.dtype
        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name)
        }
        node.fluid_code.add_layer("data",
                                  inputs=None,
                                  output=node,
                                  param_attr=attr)

    def Const(self, node):
        shape = node.out_shapes[0]
        dtype = node.dtype
        value = node.value
        initializer = "Constant(0.0)"
        if len(shape) == 0:
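            # a scalar constant becomes a 1-element parameter whose
            # initializer carries the value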
            assert value.size == 1, "Unexpected situation happened"
            shape = [1]
            initializer = "Constant({})".format(value)

        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'default_initializer': initializer
        }
        node.fluid_code.add_layer("create_parameter",
                                  inputs=None,
                                  output=node,
                                  param_attr=attr)
        self.weights[node.layer_name] = node.value

    def Transpose(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        perm = self.graph.get_node(node.layer.input[1], copy=True)
        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
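        # fold the constant perm tensor into a transpose attribute and drop
        # the Const node's weight and code so it isn't emitted separately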
        del self.weights[perm.layer_name]
        perm.fluid_code.clear()
        perm = perm.value.tolist()

        attr = {'perm': perm}
        node.fluid_code.add_layer("transpose",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def RealDiv(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {'x': x, 'y': y}
        node.fluid_code.add_layer("elementwise_div",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=None)

    def Relu(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        node.fluid_code.add_layer("relu",
                                  inputs=input,
                                  output=node,
                                  param_attr=None)

    def Squeeze(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        squeeze_dims = node.get_attr('squeeze_dims')
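        # TF's squeeze_dims attribute maps directly to Paddle's axes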
        attr = {'axes': squeeze_dims}
        node.fluid_code.add_layer("squeeze",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def BiasAdd(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        bias = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {'x': input, 'y': bias}
        node.fluid_code.add_layer("elementwise_add",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=None)

    def Identity(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        node.fluid_code.add_layer("assign",
                                  inputs=input,
                                  output=node,
                                  param_attr=None)

    def MaxPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        in_shape = input.out_shapes[0]
        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
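            # Paddle's pool2d expects NCHW, so transpose NHWC inputs and
            # permute the shape/stride/ksize attributes to match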
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer("transpose",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]

        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[2], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[3], strides[3])
            pad_h = pad_h[0] + pad_h[1]
            pad_w = pad_w[0] + pad_w[1]
            attr = {"paddings": [0, pad_h, 0, pad_w], "pad_value": -10000.0}
            if pad_h + pad_w != 0:
                node.fluid_code.add_layer(
                    "pad2d",
                    inputs=input if channel_first else node,
                    output=node,
                    param_attr=attr)
        attr = {
            "pool_size": k_size[1:3],
            "pool_type": string("max"),
            "pool_stride": strides[2:4]
        }
        node.fluid_code.add_layer("pool2d",
                                  inputs=input if channel_first else node,
                                  output=node,
                                  param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer("transpose",
                                      inputs=node,
                                      output=node,
                                      param_attr=attr)

    def Conv2D(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
        self.omit_nodes.append(kernel.layer_name)

        in_shape = input.out_shapes[0]
        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
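            # TF stores conv kernels as HWIO; transpose to Paddle's OIHW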
            self.weights[kernel.layer_name] = numpy.transpose(
                kernel.value, (3, 2, 0, 1))
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer("transpose",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]

        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
            if pad_h[0] + pad_h[1] + pad_w[0] + pad_w[1] != 0:
                node.fluid_code.add_layer(
                    "pad2d",
                    inputs=input if channel_first else node,
                    output=node,
                    param_attr=attr)
        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[3],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4]
        }
        node.fluid_code.add_layer("conv2d",
                                  inputs=input if channel_first else node,
                                  output=node,
                                  param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer("transpose",
                                      inputs=node,
                                      output=node,
                                      param_attr=attr)