#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import tensor_util
from utils import *
from functools import *
from six import string_types as _string_types
import framework_pb2 as framework
import logging
import math
import struct
import numpy
logging.basicConfig(level=logging.DEBUG)


class PaddleEmitter(object):
    def __init__(self, parser, save_dir):
        self.graph = parser.tf_graph
        self.weights = parser.weights
        self.infer = parser.infer
        self.inputs_sample_data = dict()
        self.save_dir = save_dir
        self.body_code = ""
        self.tab = " " * 4

        self.outputs = parser.outputs
        self.inputs = parser.inputs
        outputs = list()
        for output in self.outputs:
            while True:
                if output in self.graph.identity_relation:
                    output = self.graph.identity_relation[output]
                else:
                    break
            outputs.append(output)
        self.outputs = outputs

    @staticmethod
    def compute_padding_size(in_size, filter_size, stride):
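        # TensorFlow-style "SAME" padding: for example, in_size=7,
        # filter_size=3, stride=2 gives new_size=4, pad_size=2 and the
        # returned [pad_before, pad_after] is [1, 1].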
        new_size = int(math.ceil(in_size * 1.0 / stride))
        pad_size = (new_size - 1) * stride + filter_size - in_size
        pad_0 = int(pad_size / 2)
        pad_1 = pad_size - pad_0
        return [pad_0, pad_1]

    def check_op(self, node_name_list):
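        # Returns True only when every node's layer_type has a matching
        # emit_<layer_type> method on this class; otherwise the unsupported
        # OP types are logged and False is returned.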
        uncovered_ops = set()
        for name in node_name_list:
            node = self.graph.get_node(name)
            if len(node.inputs) == 0 and len(node.outputs) == 0:
                continue
            if not hasattr(self, "emit_" + node.layer_type):
                uncovered_ops.add(node.layer_type)
        if len(uncovered_ops) > 0:
            logging.error("{} OP are not supported".format(len(uncovered_ops)))
            for op in uncovered_ops:
                logging.error("Unsupported OP: {}".format(op))
            return False
        return True

    # trick method to work around the NHWC/NCHW layout mismatch
    def get_axis(self, node1, node2):
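        # Decides the broadcast axis for elementwise ops: a 4-D NHWC tensor
        # paired with a 1-D tensor broadcasts along the channel dimension,
        # which becomes axis 1 once the data is emitted as NCHW; otherwise
        # the default trailing alignment (-1) is used.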
        shape1 = node1.shape_dim_size
        shape2 = node2.shape_dim_size
        if shape1 == 4 and shape2 == 1 and node1.data_format == NHWC:
            axis = 1
        elif shape2 == 4 and shape1 == 1 and node2.data_format == NHWC:
            axis = 1
        else:
            axis = -1
        return axis

    def elementwise(self, node, op):
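        # Generic emitter for elementwise_<op>. Same-rank inputs that still
        # need broadcasting are handled by transposing size-1 dimensions to
        # the end, applying the op with trailing alignment, then transposing
        # the result back; inputs of different rank rely on Paddle's `axis`
        # argument instead.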
        data1 = node.inputs[0]
        data2 = node.inputs[1]
        axis = self.get_axis(data1, data2)
        shape1 = self.infer.get_tensor_shape(data1.layer)
        shape2 = self.infer.get_tensor_shape(data2.layer)

        op = "elementwise_" + op
        if shape2.shape[0] == shape1.shape[0]:
            if (shape1 == shape2).all():
                param_attr = {
                    'x': data1.ref_name,
                    'y': data2.ref_name,
                }
                node.code.add_layer(op, None, node.output_name, param_attr)
                return

            index1_not_one = list(numpy.argwhere(shape1 != 1).flatten())
            index1_one = list(numpy.argwhere(shape1 == 1).flatten())
            perm1 = range(shape1.shape[0])
            perm2 = range(shape1.shape[0])
            if len(index1_one) != 0:
                perm1 = index1_not_one + index1_one

            index2_not_one = list(numpy.argwhere(shape2 != 1).flatten())
            index2_one = list(numpy.argwhere(shape2 == 1).flatten())
            if len(index2_one) != 0:
                perm2 = index2_not_one + index2_one

            perm = list(numpy.array(perm1)[numpy.array(perm2)])
            if perm != range(shape1.shape[0]):
                param_attr = {"perm": perm}
                node.code.add_layer("transpose", data1.ref_name, "temp1",
                                    param_attr)
                node.code.add_layer("transpose", data2.ref_name, "temp2",
                                    param_attr)
                if len(index2_one) > len(index1_one):
                    param_attr = {"x": "temp1", "y": "temp2"}
                else:
                    param_attr = {"x": "temp2", "y": "temp1"}
                node.code.add_layer(op, None, node.output_name, param_attr)
                perm = sorted(range(len(perm)), key=lambda k: perm[k])
                param_attr = {"perm": perm}
                node.code.add_layer("transpose", node.output_name,
                                    node.output_name, param_attr)
            else:
                if len(index2_one) > len(index1_one):
                    param_attr = {"x": data1.ref_name, "y": data2.ref_name}
                else:
                    param_attr = {"x": data2.ref_name, "y": data1.ref_name}
                node.code.add_layer(op, None, node.output_name, param_attr)
        else:
            param_attr = {
                "x": data1.ref_name,
                "y": data2.ref_name,
                "axis": axis
            }
            if shape2.shape[0] > shape1.shape[0]:
                param_attr = {
                    "x": data2.ref_name,
                    "y": data1.ref_name,
                    "axis": axis
                }
            node.code.add_layer(op, None, node.output_name, param_attr)

    def export_weights(self, weight, paddle_var_name, dir):
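        # Dumps `weight` in the binary layout written below: an int32
        # version field, an unsigned long LoD field, another int32 version,
        # the int32 byte size of the serialized TensorDesc, the TensorDesc
        # protobuf itself, and finally the raw ndarray bytes.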
        self.save_var_set.add(paddle_var_name)
        numpy_dtype_map = {
            "int16": framework.VarType.INT16,
            "int32": framework.VarType.INT32,
            "int64": framework.VarType.INT64,
            "float16": framework.VarType.FP16,
            "float32": framework.VarType.FP32,
            "float64": framework.VarType.FP64
        }
        struct_write_format = {
            "int16": "h",
            "int32": "i",
            "int64": "q",
            "float16": "e",
            "float32": "f",
            "float64": "d"
        }
        shape = weight.shape
        filew = open(dir + "/" + paddle_var_name, "wb")
        filew.write(struct.pack('i', 0))
        filew.write(struct.pack('L', 0))
        filew.write(struct.pack('i', 0))
        tensor_desc = framework.VarType.TensorDesc()
        if str(weight.dtype) in numpy_dtype_map:
            tensor_desc.data_type = numpy_dtype_map[str(weight.dtype)]
        else:
            raise Exception("Unexpected array dtype [{}]".format(weight.dtype))
        tensor_desc.dims.extend(shape)
        desc_size = tensor_desc.ByteSize()
        filew.write(struct.pack('i', desc_size))
        filew.write(tensor_desc.SerializeToString())
        weight.tofile(filew)
        filew.close()

    @property
    def header_code(self):
        code = list()
        code.append("import paddle.fluid.layers as layers")
        code.append("import paddle.fluid as fluid")
        code.append("import numpy")
        code.append("")
        code.append("class Model(object):")
        code.append("    def build(self):")
        return code

    def add_codes(self, indent, codes):
        if isinstance(codes, _string_types):
            codes = codes.strip().split("\n")
        if not isinstance(codes, list):
            raise Exception("Unexpected error!")
        for code in codes:
            self.body_code += (self.tab * indent) + code + "\n"

    def run(self):
        node = self.graph.tf_graph.node[0]
        self.add_codes(0, self.header_code)

        self.save_var_set = set()

        # filter branch nodes, like 'split:1'
        translate_nodes = []
        for node in self.graph.topological_sort:
            if node.count(':') == 0:
                translate_nodes.append(node)

        # check whether the model contains any unsupported OPs
        if not self.check_op(translate_nodes):
            return

        # ref_name.info records the relationship between
        # paddle value names and tensorflow value names
        ref_name_recorder = open(self.save_dir + "/ref_name.info", 'w')

        total_nodes_num = len(translate_nodes)
        translated_nodes_count = 1
        for node in translate_nodes:
            logging.info("TotalNum:{},TraslatedNum:{},CurrentNode:{}".format(
                total_nodes_num, translated_nodes_count, node))
            current_node = self.graph.get_node(node)
            ref_name_recorder.write("{}\t{}\n".format(
                current_node.layer_name, current_node.output_name))
            translated_nodes_count += 1

            # skip isolated nodes
            if len(current_node.inputs) == 0 and len(
                    current_node.outputs) == 0:
                continue

            op = current_node.layer_type
            if hasattr(self, "emit_" + op):
                func = getattr(self, "emit_" + op)
                func(current_node)
            else:
                raise Exception("Unknow node op: {}".format(op))
        ref_name_recorder.close()

        # merge all the generated python codes
        for node in translate_nodes:
            codes = self.graph.get_node(node).code.gen_codes()
            self.add_codes(2, codes)

        # add return value codes
        outs = []
        for node in self.outputs:
            outs.append(self.graph.get_node(node).output_name)
            self.add_codes(
                2, "# {} : {}".format(
                    self.graph.get_node(node).output_name,
                    self.graph.get_node(node).layer_name))
        input_code = "self.inputs = {}".format([str(s) for s in self.inputs])
        output_code = "self.outputs = [{}]".format(", ".join(outs))
        self.add_codes(2, input_code)
        self.add_codes(2, output_code)

        # write python code to file "mymodel.py"
        filew = open(self.save_dir + "/mymodel.py", 'w')
        filew.write(self.body_code)
        filew.close()

        # file "save_var.list" records name of dumped variables
        filew = open(self.save_dir + "/save_var.list", 'w')
        for var in self.save_var_set:
            filew.write(var + '\n')
        filew.close()

        logging.info("Model translated!")
        return self.body_code

    def emit_placeholder(self, node):
        shape = list(self.infer.get_tensor_shape(node.layer))

        self.inputs_sample_data[node.layer_name] = []
        if shape[0] is None or shape[0] < 0:
            self.batch_node = node
            for i in range(1, 4):
                sample_data = numpy.random.random_sample([i] + shape[1:])
                self.inputs_sample_data[node.layer_name].append(sample_data)
        else:
            for i in range(1, 4):
                sample_data = numpy.random.random_sample(shape)
                self.inputs_sample_data[node.layer_name].append(sample_data)

        if node.data_format == NHWC and len(shape) == 4:
            shape = [shape[0], shape[3], shape[1], shape[2]]

        param_attr = {
            "name": "\'{}\'".format(node.ref_name),
            "shape": shape,
            "dtype": "\'{}\'".format(node.dtype),
            "append_batch_size": False
        }
        node.code.add_layer("data", None, node.output_name, param_attr)

    def emit_const(self, node):
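        # Single-element constants become fill_constant; larger constants are
        # emitted as create_parameter and their values are dumped via
        # export_weights so they can be loaded with the converted model.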
        value = self.infer.get_const_tensor_value(node.layer)
        shape = list(value.shape)

        try:
            dtype = node.dtype
        except:
            return []

        node.code.add_str("#{} {} {}".format(node.layer_name, node.ref_name,
                                             value.shape))
        if value.size == 1:
            param_attr = {
                "shape": [1],
                "value": value.flatten()[0],
                "dtype": "\'{}\'".format(dtype),
            }
            node.code.add_layer("fill_constant", None, node.output_name,
                                param_attr)
        else:
            param_attr = {
                "shape": shape,
                "name": "\'{}\'".format(node.ref_name),
                "dtype": "\'{}\'".format(dtype)
            }
            if node.dtype.startswith('int'):
                param_attr["default_initializer"] = \
                "fluid.initializer.Constant(0)"
            node.code.add_layer("create_parameter", None, node.output_name,
                                param_attr)
            self.export_weights(value, node.ref_name, self.save_dir)

    def emit_conv2d(self, node):
        data = node.inputs[0]
        kernel = node.inputs[1]

        if len(kernel.outputs) == 1:
            kernel.code.clear()

        padding_mode = node.get_attr("padding")
        strides = node.get_attr("strides")[2:4]
        k_shape = list(self.infer.get_tensor_shape(kernel.layer))
        input_shape = list(self.infer.get_tensor_shape(data.layer))
        input_h, input_w = input_shape[2:4]
        k_h, k_w, channel, kernel_num = k_shape
        if node.data_format == NHWC:
            input_h, input_w = input_shape[1:3]
            strides = node.get_attr("strides")[1:3]

        if kernel.layer_name in self.weights:
            weight = self.weights[kernel.layer_name]
            self.weights[kernel.layer_name] = numpy.transpose(
                weight, (3, 2, 0, 1))
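            # TensorFlow stores conv kernels as (h, w, in_channels, num_filters);
            # Paddle expects (num_filters, in_channels, h, w), hence the
            # (3, 2, 0, 1) transpose above.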
            self.export_weights(self.weights[kernel.layer_name],
                                kernel.ref_name, self.save_dir)

        conv2d_param = {
            "num_filters": kernel_num,
            "filter_size": [k_h, k_w],
            "stride": strides,
            "param_attr": "\'{}\'".format(kernel.ref_name),
            "bias_attr": False
        }

        if padding_mode == SAME:
            pad_h = self.compute_padding_size(input_h, k_h, strides[0])
            pad_w = self.compute_padding_size(input_w, k_w, strides[1])
            if len(set(pad_h)) == 1 and len(set(pad_w)) == 1:
                conv2d_param["padding"] = [pad_h[0], pad_w[0]]
                node.code.add_layer("conv2d", data.ref_name, node.output_name,
                                    conv2d_param)
            else:
                pad_param = {"paddings": pad_h + pad_w}
                node.code.add_layer("pad2d", data.ref_name, node.output_name,
                                    pad_param)
                node.code.add_layer("conv2d", node.output_name,
                                    node.output_name, conv2d_param)
        else:
            node.code.add_layer("conv2d", data.ref_name, node.output_name,
                                conv2d_param)

    def emit_variablev2(self, node):
        shape = list(self.infer.get_tensor_shape(node.layer))

        node.code.add_str("# variable[{}]:\t{}".format(node.output_name,
                                                       node.layer_name))

        if node.layer_name in self.weights:
            self.export_weights(self.weights[node.layer_name], node.ref_name,
                                self.save_dir)

        param_attr = {
            "name": "\'{}\'".format(node.ref_name),
            "shape": shape,
            "dtype": "\'{}\'".format(node.dtype)
        }
        if node.dtype.startswith('int'):
            param_attr["default_initializer"] = "fluid.initializer.Constant(0)"
        node.code.add_layer("create_parameter", None, node.output_name,
                            param_attr)

    def emit_biasadd(self, node):
        data = node.inputs[0]
        bias = node.inputs[1]

        if bias.layer_name in self.weights:
            self.export_weights(self.weights[bias.layer_name], bias.ref_name,
                                self.save_dir)

        self.emit_variablev2(bias)
        param_attr = {"x": data.ref_name, "y": bias.ref_name, "axis": 1}
        node.code.add_layer("elementwise_add", None, node.output_name,
                            param_attr)

    def emit_relu(self, node):
        data = node.inputs[0]
        node.code.add_layer("relu", data.ref_name, node.output_name)

    def emit_maxpool(self, node):
        data = node.inputs[0]
        padding_mode = node.get_attr("padding")
        input_shape = list(self.infer.get_tensor_shape(data.layer))
        input_h, input_w = input_shape[2:4]
        strides = node.get_attr("strides")[2:4]
        pool_size = node.get_attr("ksize")[2:4]
        if node.data_format == NHWC:
            input_h, input_w = input_shape[1:3]
            strides = node.get_attr("strides")[1:3]
            pool_size = node.get_attr("ksize")[1:3]

        pool_param = {
            "pool_size": pool_size,
            "pool_type": "\'max\'",
            "pool_stride": strides,
        }

        if padding_mode == SAME:
            pad_h = self.compute_padding_size(input_h, pool_size[0],
                                              strides[0])
            pad_w = self.compute_padding_size(input_w, pool_size[1],
                                              strides[1])
#            pad_right = pad_w[0] + pad_w[1]
#            pad_bottom = pad_h[0] + pad_h[1]
            if (pad_h[0] + pad_h[1]) % 2 != 0:
                pad_h[1] += pad_h[0]
                pad_h[0] = 0
            if (pad_w[0] + pad_w[1]) % 2 != 0:
                pad_w[1] += pad_w[0]
                pad_w[0] = 0
            #padding = [0, pad_bottom, 0, pad_right]
            padding = pad_h + pad_w
            pad_param = {"paddings": padding, "pad_value":-1000000.0}
            node.code.add_layer("pad2d", data.ref_name, node.output_name,
                                pad_param)
            node.code.add_layer("pool2d", node.output_name, node.output_name,
                                pool_param)
        else:
            node.code.add_layer("pool2d", data.ref_name, node.output_name,
                                pool_param)

    def emit_squeeze(self, node):
        data = node.inputs[0]
        axis = node.get_attr("squeeze_dims")
        input_shape_len = data.shape_dim_size
        if node.data_format == NHWC and input_shape_len == 4:
            for i in range(0, len(axis)):
                if axis[i] > 0:
                    axis[i] = (axis[i] + 1) % 4 + int((axis[i] + 1) / 4)
        param_attr = {"axes": axis}
        node.code.add_layer("squeeze", data.ref_name, node.output_name,
                            param_attr)

    def emit_add(self, node):
        return self.elementwise(node, "add")

    def emit_mean(self, node):
        data = node.inputs[0]
        reduce_idx = node.inputs[1]
        reduce_idx.code.clear()
        idxs = list(
            self.infer.get_const_tensor_value(reduce_idx.layer).flatten())
        data_shape_len = data.shape_dim_size
        keep_dims = node.layer.attr['keep_dims'].b
        if node.data_format == NHWC and data_shape_len == 4:
            for i in range(len(idxs)):
                if idxs[i] > 0:
                    idxs[i] = (idxs[i] + 1) % 4 + int((idxs[i] + 1) / 4)
        param_attr = {"dim": list(idxs), "keep_dim": keep_dims}
        node.code.add_layer("reduce_mean", data.ref_name, node.output_name,
                            param_attr)

    def emit_fusedbatchnorm(self, node):
        data = node.inputs[0]
        gamma = node.inputs[1]
        beta = node.inputs[2]
        moving_mean = node.inputs[3]
        moving_variance = node.inputs[4]
        if len(gamma.outputs) == 1:
            gamma.code.clear()
        if len(beta.outputs) == 1:
            beta.code.clear()
        if len(moving_mean.outputs) == 1:
            moving_mean.code.clear()
        if len(moving_variance.outputs) == 1:
            moving_variance.code.clear()

        epsilon = round(node.get_attr('epsilon'), 6)
        is_training = node.get_attr('is_training')

        if gamma.layer_name in self.weights:
            self.export_weights(self.weights[gamma.layer_name], gamma.ref_name,
                                self.save_dir)
        if beta.layer_name in self.weights:
            self.export_weights(self.weights[beta.layer_name], beta.ref_name,
                                self.save_dir)
        if moving_mean.layer_name in self.weights:
            self.export_weights(self.weights[moving_mean.layer_name],
                                moving_mean.ref_name, self.save_dir)
        if moving_variance.layer_name in self.weights:
            self.export_weights(self.weights[moving_variance.layer_name],
                                moving_variance.ref_name, self.save_dir)

        param_attr = {
            "epsilon": epsilon,
            "param_attr": "\'{}\'".format(gamma.ref_name),
            "bias_attr": "\'{}\'".format(beta.ref_name),
            "moving_mean_name": "\'{}\'".format(moving_mean.ref_name),
            "moving_variance_name": "\'{}\'".format(moving_variance.ref_name),
            "is_test": not is_training
        }
        node.code.add_layer("batch_norm", data.ref_name, node.output_name,
                            param_attr)

    def emit_concatv2(self, node):
        input_shape_len = node.inputs[0].shape_dim_size
        axis = node.inputs[-1]
        axis.code.clear()
        axis = self.infer.get_const_tensor_value(axis.layer)
        if axis < 0:
            axis = input_shape_len + axis
        if node.data_format == NHWC and input_shape_len == 4:
            if axis > 0:
                axis = (axis + 1) % 4 + int((axis + 1) / 4)
        num_tensor = len(node.inputs) - 1
        input_list = [input.ref_name for input in node.inputs[:num_tensor]]
        input_list = "[{}]".format(", ".join(input_list))
        param_attr = {"axis": axis}
        node.code.add_layer("concat", input_list, node.output_name, param_attr)

    def emit_avgpool(self, node):
        data = node.inputs[0]
        padding_mode = node.get_attr("padding")
        input_shape = list(self.infer.get_tensor_shape(data.layer))
        strides = node.get_attr("strides")[2:4]
        pool_size = node.get_attr("ksize")[2:4]
        input_h, input_w = input_shape[2:4]

        if node.data_format == NHWC:
            strides = node.get_attr("strides")[1:3]
            pool_size = node.get_attr("ksize")[1:3]
            input_h, input_w = input_shape[1:3]

        param_attr = {
            "pool_size": pool_size,
            "pool_stride": strides,
            "pool_type": "\'avg\'"
        }

        if padding_mode == SAME:
            pad_h = self.compute_padding_size(input_h, pool_size[0],
                                              strides[0])
            pad_w = self.compute_padding_size(input_w, pool_size[1],
                                              strides[1])
            if len(set(pad_h)) == 1 and len(set(pad_w)) == 1:
                padding = [pad_h[0], pad_w[0]]
                param_attr["pool_padding"] = padding
            else:
                pad_param = {"paddings": pad_h + pad_w}
                node.code.add_layer("pad2d", data.ref_name, node.output_name,
                                    pad_param)
                node.code.add_layer("pool2d", node.output_name,
                                    node.output_name, param_attr)
                return
        node.code.add_layer("pool2d", data.ref_name, node.output_name,
                                param_attr)

    def emit_rsqrt(self, node):
        data = node.inputs[0]
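        # rsqrt(x) is emitted as sqrt followed by pow with factor -1.0,
        # i.e. 1 / sqrt(x).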
        pow_param = {"factor": -1.0}
        node.code.add_layer("sqrt", data.ref_name, node.output_name)
        node.code.add_layer("pow", node.output_name, node.output_name,
                            pow_param)

    def emit_mul(self, node):
        return self.elementwise(node, "mul")

    def emit_sub(self, node):
        data1 = node.inputs[0]
        data2 = node.inputs[1]
        axis = self.get_axis(data1, data2)
        data1_shape = list(self.infer.get_tensor_shape(data1.layer))
        data2_shape = list(self.infer.get_tensor_shape(data2.layer))
        param_attr = {"x": data1.ref_name, "y": data2.ref_name, "axis": axis}
        if len(data1_shape) == 4 and len(data2_shape) == 4 \
            and node.data_format == NHWC:
            if data1_shape[-1] != data2_shape[-1]:
                node.code.add_layer("transpose", data1.ref_name, "temp1",
                                    {"perm": [0, 2, 3, 1]})
                node.code.add_layer("transpose", data2.ref_name, "temp2",
                                    {"perm": [0, 2, 3, 1]})
                param_attr = {"x": "temp1", "y": "temp2", "axis": -1}
                node.code.add_layer("elementwise_sub", None, node.output_name,
                                    param_attr)
                node.code.add_layer("transpose", node.output_name,
                                    node.output_name, {"perm": [0, 3, 1, 2]})
        else:
            node.code.add_layer("elementwise_sub", None, node.output_name,
                                param_attr)

    def emit_shape(self, node):
        data = node.inputs[0]
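        # For 4-D NHWC inputs the data is transposed back to NHWC before the
        # shape layer, so the emitted shape keeps TensorFlow's original
        # dimension order.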
        input_shape_len = data.shape_dim_size
        if input_shape_len == 4 and node.data_format == NHWC:
            param = {"perm": [0, 2, 3, 1]}
            node.code.add_layer("transpose", data.ref_name, node.output_name,
                                param)
            node.code.add_layer("shape", node.output_name, node.output_name)
        else:
            node.code.add_layer("shape", data.ref_name, node.output_name)
        param = {"dtype": "\'int32\'"}
        node.code.add_layer("cast", node.output_name, node.output_name, param)

    def emit_pad(self, node):
        data = node.inputs[0]
        padding = node.inputs[1]
        padding.code.clear()
        padding = padding.layer.attr['value'].tensor
        padding = tensor_util.MakeNdarray(padding).astype('int32')
        if node.data_format == NHWC and padding.shape[0] == 4:
            padding = padding[[0, 3, 1, 2]]
        param_attr = {"paddings": list(padding.flatten())}
        node.code.add_layer("pad", data.ref_name, node.output_name, param_attr)

    def emit_stridedslice(self, node):
        data = node.inputs[0]
        begin = node.inputs[1]
        end = node.inputs[2]
        strides = node.inputs[3]
        begin.code.clear()
        end.code.clear()
        strides.code.clear()

        begin = list(self.infer.get_const_tensor_value(begin.layer).flatten())
        end = list(self.infer.get_const_tensor_value(end.layer).flatten())
        strides = list(
            self.infer.get_const_tensor_value(strides.layer).flatten())

        for i in range(len(strides)):
            assert strides[i] == 1
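        # When every `end` index is 0, the op is emitted as a crop to the
        # statically inferred output shape; otherwise a slice layer with
        # explicit starts/ends is generated.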

        if len(set(end)) == 1 and end[0] == 0:
            output_shape = list(self.infer.get_tensor_shape(node.layer))
            if node.data_format == NHWC and len(output_shape) == 4:
                output_shape = [
                    output_shape[0], output_shape[3], output_shape[1],
                    output_shape[2]
                ]
                begin = [begin[0], begin[3], begin[1], begin[2]]
            param = {"shape": output_shape, "offsets": begin}
            node.code.add_layer("crop", data.ref_name, node.output_name,
                                param)
        else:
            param = {"axes": range(len(begin)), "starts": begin, "ends": end}
            node.code.add_layer("slice", data.ref_name, node.output_name,
                                param)

    def emit_resizenearestneighbor(self, node):
        data = node.inputs[0]
        resize_shape = node.inputs[1]
        resize_shape.code.clear()
        align_corners = node.get_attr('align_corners')

        resize_shape = list(self.infer.get_shape_tensor(resize_shape.layer))
        param_attr = {
            "align_corners": align_corners,
            "out_shape": resize_shape
        }
        node.code.add_layer("resize_nearest", data.ref_name, node.output_name,
                            param_attr)

    def emit_maximum(self, node):
        return self.elementwise(node, "max")

    def emit_minimum(self, node):
        return self.elementwise(node, "min")

    def emit_sigmoid(self, node):
        data = node.inputs[0]
        node.code.add_layer("sigmoid", data.ref_name, node.output_name)

    def emit_pack(self, node):
        inputs = [input.ref_name for input in node.inputs]
        inputs = "[{}]".format(", ".join(inputs))
        node.code.add_layer("stack", inputs, node.output_name)

    def emit_reshape(self, node):
        data = node.inputs[0]
        shape = node.inputs[1]
        input_shape_len = data.shape_dim_size
        output_shape = list(self.infer.get_tensor_shape(node.layer))

        shape = self.infer.get_shape_tensor(shape.layer, output_shape)

        reshape_param = {"shape": list(shape)}
        if node.data_format == NHWC and input_shape_len == 4:
            param_attr = {"perm": [0, 2, 3, 1]}
            node.code.add_layer("transpose", data.ref_name, node.output_name,
                                param_attr)
            node.code.add_layer("reshape", node.output_name, node.output_name,
                                reshape_param)
            if len(shape) == 4:
                param_attr = {"perm": [0, 3, 1, 2]}
                node.code.add_layer("transpose", node.output_name,
                                    node.output_name, param_attr)
        else:
            node.code.add_layer("reshape", data.ref_name, node.output_name,
                                reshape_param)

    def emit_conv2dbackpropinput(self, node):
        output_shape = node.inputs[0]
        kernel = node.inputs[1]
        data = node.inputs[2]
        if len(kernel.outputs) == 1:
            kernel.code.clear()
        output_shape.code.clear()
        padding_mode = node.get_attr("padding")
        strides = node.get_attr("strides")[2:4]
        k_shape = self.infer.get_tensor_shape(kernel.layer)
        k_h, k_w, k_num, channel = k_shape
        if node.data_format == NHWC:
            strides = node.get_attr("strides")[1:3]

        padding = [0, 0]
        if padding_mode == SAME:
            padding = [int(val) for val in [(k_h - strides[0]) / 2, \
                (k_w - strides[1]) / 2]]

        if kernel.layer_name in self.weights:
            weight = self.weights[kernel.layer_name]
            self.weights[kernel.layer_name] = numpy.transpose(
                weight, (3, 2, 0, 1))
            self.export_weights(self.weights[kernel.layer_name],
                                kernel.ref_name, self.save_dir)

        output_shape = list(self.infer.get_shape_tensor(output_shape.layer))
        if node.data_format == NHWC and len(output_shape) == 4:
            output_shape = [
                output_shape[0], output_shape[3], output_shape[1],
                output_shape[2]
            ]

        param_attr = {
            "num_filters": k_num,
            "filter_size": [k_h, k_w],
            "padding": padding,
            "stride": strides,
            "param_attr": "\'{}\'".format(kernel.ref_name),
            "bias_attr": False
        }
        node.code.add_layer("conv2d_transpose", data.ref_name,
                            node.output_name, param_attr)
        if padding_mode == SAME:
            param_attr = {"shape": list(output_shape)}
            node.code.add_layer("crop", node.output_name, node.output_name,
                                param_attr)

    def emit_depthwiseconv2dnative(self, node):
        data = node.inputs[0]
        kernel = node.inputs[1]
        if len(kernel.outputs) == 1:
            kernel.code.clear()

        padding_mode = node.get_attr("padding")
        strides = node.get_attr("strides")[2:4]
        k_shape = self.infer.get_tensor_shape(kernel.layer)
        input_shape = self.infer.get_tensor_shape(data.layer)
        input_h, input_w = input_shape[2:4]
        k_h, k_w, in_channels, channel_multiplier = k_shape
        if node.data_format == NHWC:
            strides = node.get_attr("strides")[1:3]
            input_h, input_w = input_shape[1:3]
        groups = channel_multiplier * in_channels
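        # Depthwise convolution is emitted as a grouped conv2d: for the usual
        # channel_multiplier == 1 case, groups == num_filters == in_channels,
        # so each filter sees a single input channel as in TensorFlow.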

        if kernel.layer_name in self.weights:
            weight = self.weights[kernel.layer_name]
            self.weights[kernel.layer_name] = numpy.transpose(
                weight, (2, 3, 0, 1))
            self.export_weights(self.weights[kernel.layer_name],
                                kernel.ref_name, self.save_dir)
        conv_param = {
            "num_filters": in_channels,
            "filter_size": [k_h, k_w],
            "stride": strides,
            "groups": groups,
            "param_attr": "\'{}\'".format(kernel.ref_name),
            "bias_attr": False
        }
        if padding_mode == SAME:
            pad_h = self.compute_padding_size(input_h, k_h, strides[0])
            pad_w = self.compute_padding_size(input_w, k_w, strides[1])
            if len(set(pad_h)) == 1 and len(set(pad_w)) == 1:
                padding = [pad_h[0], pad_w[0]]
                conv_param["padding"] = padding
                node.code.add_layer("conv2d", data.ref_name, node.output_name,
                                    conv_param)
            else:
                pad_param = {"paddings": pad_h + pad_w}
                node.code.add_layer("pad2d", data.ref_name, node.output_name,
                                    pad_param)
                node.code.add_layer("conv2d", node.output_name,
                                    node.output_name, conv_param)
        else:
            node.code.add_layer("conv2d", data.ref_name, node.output_name,
                                conv_param)

    def emit_softmax(self, node):
        data = node.inputs[0]
        node.code.add_layer("softmax", data.ref_name, node.output_name)

    def emit_matmul(self, node):
        data0 = node.inputs[0]
        data1 = node.inputs[1]
        transpose_a = node.get_attr('transpose_a')
        transpose_b = node.get_attr('transpose_b')
        param_attr = {
            "x": data0.ref_name,
            "y": data1.ref_name,
            "transpose_x": transpose_a,
            "transpose_y": transpose_b
        }
        node.code.add_layer("matmul", None, node.output_name, param_attr)

    def emit_transpose(self, node):
        data = node.inputs[0]
        perm = node.inputs[1]
        perm.code.clear()
        perm = list(self.infer.get_shape_tensor(perm.layer))
        if node.data_format == NHWC and len(perm) == 4:
            if perm == [0, 3, 1, 2]:
                self.graph.set_data_format(node, NCHW)
                node.code.add_str("{} = {}".format(node.output_name,
                                                   data.ref_name))
            else:
                raise Exception("Unexpected situation in OP transpose")
        elif node.data_format == NCHW and len(perm) == 4:
            if perm == [0, 2, 3, 1]:
                self.graph.set_data_format(node, NHWC)
                node.code.add_str("{} = {}".format(node.output_name,
                                                   data.ref_name))
            else:
                raise Exception("Unexpected situation in OP transpose")
        else:
            param_attr = {"perm": perm}
            node.code.add_layer("transpose", data.ref_name, node.output_name,
                                param_attr)

    def emit_randomuniform(self, node):
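        # A negative dimension here is the dynamic batch size; it cannot be
        # expressed statically, so uniform_random_batch_size_like is emitted
        # against the recorded batch placeholder (self.batch_node) instead of
        # a plain uniform_random.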
        shape = node.inputs[0]
        shape = self.infer.get_shape_tensor(shape.layer)
        if node.data_format == NHWC and len(shape) == 4:
            shape = shape[[0, 3, 1, 2]]
        batch_index = list(numpy.argwhere(shape < 0).flatten())
        shape = list(shape)
        param_attr = {
            "shape": shape,
            "dtype": "\'float32\'",
            "min": 0.00001,
            "max": 0.99999
        }
        if len(batch_index) > 1:
            raise Exception("More than one dimension value less than zero")
        if len(batch_index) == 0:
            node.code.add_layer("uniform_random", None, node.output_name,
                                param_attr)
        else:
            param_attr["input_dim_idx"] = batch_index[0]
            node.code.add_layer("uniform_random_batch_size_like",
                                self.batch_node.ref_name, node.output_name,
                                param_attr)

    def emit_floor(self, node):
        data = node.inputs[0]
        node.code.add_layer("floor", data.ref_name, node.output_name)

    def emit_exp(self, node):
        data = node.inputs[0]
        node.code.add_layer("exp", data.ref_name, node.output_name)

    def emit_floordiv(self, node):
        self.emit_div(node)
        param = {"dtype": "\'float32\'"}
        node.code.add_layer("cast", node.output_name, node.output_name, param)
        node.code.add_layer("floor", node.output_name, node.output_name)

    def emit_div(self, node):
        data1 = node.inputs[0]
        data2 = node.inputs[1]
        axis = self.get_axis(data1, data2)
        data1_shape = self.infer.get_tensor_shape(data1.layer)
        data2_shape = self.infer.get_tensor_shape(data2.layer)
        div_param = {"x": data1.ref_name, "y": data2.ref_name, "axis": axis}
        if len(data1_shape) == 4 and len(data2_shape) == 4 \
            and node.data_format == NHWC:
            if data1_shape[-1] != data2_shape[-1]:
                perm = {"perm": [0, 2, 3, 1]}
                node.code.add_layer("transpose", data1.ref_name, "temp1", perm)
                node.code.add_layer("transpose", data2.ref_name, "temp2", perm)
                div_param["x"] = "temp1"
                div_param["y"] = "temp2"
                div_param["axis"] = -1
        node.code.add_layer("elementwise_div", None, node.output_name,
                            div_param)

    def emit_realdiv(self, node):
        return self.emit_div(node)

    def emit_slice(self, node):
        data = node.inputs[0]
        begin = node.inputs[1]
        size = node.inputs[2]
        begin.code.clear()
        size.code.clear()
        begin = list(self.infer.get_shape_tensor(begin.layer))
        size = list(self.infer.get_shape_tensor(size.layer))

        input_shape = self.infer.get_tensor_shape(data.layer)
        if len(numpy.argwhere(input_shape < 0).flatten()) > 1:
            input_shape = list(self.infer.get_tensor_shape(data.layer))

        assert len(begin) == len(input_shape) and len(size) == len(input_shape)

        if node.data_format == NHWC and len(input_shape) == 4:
            begin = [begin[0], begin[3], begin[1], begin[2]]
            size = [size[0], size[3], size[1], size[2]]
            input_shape = [
                input_shape[0], input_shape[3], input_shape[1], input_shape[2]
            ]

        for i in range(len(size)):
            if size[i] < 0:
                size[i] = input_shape[i] - begin[i]
        param_attr = {"shape": size, "offsets": begin}
        node.code.add_layer("crop", data.ref_name, node.output_name,
                            param_attr)

    def emit_sum(self, node):
        data = node.inputs[0]
        reduce_idx = node.inputs[1]
        reduce_idx.code.clear()
        idxs = tensor_util.MakeNdarray(
            reduce_idx.layer.attr['value'].tensor).astype('int32').flatten()
        data_shape_len = data.shape_dim_size
        keep_dims = node.layer.attr['keep_dims'].b
        if node.data_format == NHWC and data_shape_len == 4:
            for i in range(idxs.shape[0]):
                if idxs[i] > 0:
                    idxs[i] = (idxs[i] + 1) % 4 + int((idxs[i] + 1) / 4)
        param = {"dim": list(idxs), "keep_dim": keep_dims}
        node.code.add_layer("reduce_sum", data.ref_name, node.output_name,
                            param)

    def emit_max(self, node):
        data = node.inputs[0]
        reduce_idx = node.inputs[1]
        reduce_idx.code.clear()
        idxs = tensor_util.MakeNdarray(
            reduce_idx.layer.attr['value'].tensor).astype('int32').flatten()
        data_shape_len = data.shape_dim_size
        keep_dims = node.layer.attr['keep_dims'].b
        if node.data_format == NHWC and data_shape_len == 4:
            for i in range(idxs.shape[0]):
                if idxs[i] > 0:
                    idxs[i] = (idxs[i] + 1) % 4 + int((idxs[i] + 1) / 4)
        param = {"dim": list(idxs), "keep_dim": keep_dims}
        node.code.add_layer("reduce_max", data.ref_name, node.output_name,
                            param)

    def emit_fill(self, node):
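        # As with random tensors, a dynamic batch dimension (shape[0] < 0) is
        # handled with fill_constant_batch_size_like against the batch
        # placeholder.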
        shape = node.inputs[0]
        shape.code.clear()
        value = node.inputs[1]
        value.code.clear()

        shape = list(self.infer.get_shape_tensor(shape.layer))
        value = list(self.infer.get_const_tensor_value(value.layer).flatten())
        assert len(value) == 1
        value = value[0]

        if node.data_format == NHWC and len(shape) == 4:
            shape = [shape[0], shape[3], shape[1], shape[2]]

        param = {
            "shape": shape,
            "dtype": "\'{}\'".format(value.dtype),
            "value": value
        }
        if shape[0] < 0:
            node.code.add_layer("fill_constant_batch_size_like",
                                self.batch_node.ref_name, node.output_name,
                                param)
        else:
            node.code.add_layer("fill_constant", None, node.output_name, param)

    def emit_range(self, node):
        start = node.inputs[0]
        end = node.inputs[1]
        delta = node.inputs[2]
        start.code.clear()
        end.code.clear()
        delta.code.clear()

        start = self.infer.get_const_tensor_value(start.layer)
        end = self.infer.get_const_tensor_value(end.layer)
        delta = self.infer.get_const_tensor_value(delta.layer)
        np_code = "np_array = numpy.arange({}, {}, {}).astype(\'{}\')".format(
            start, end, delta, delta.dtype)
        node.code.add_str(np_code)
        node.code.add_layer("assign", "np_array", node.output_name)

    def emit_tile(self, node):
        data = node.inputs[0]
        expand_times = node.inputs[1]
        expand_times.code.clear()
        expand_times = list(
            self.infer.get_const_tensor_value(expand_times.layer))
        param = {"expand_times": expand_times}
        node.code.add_layer("expand", data.ref_name, node.output_name, param)

    def emit_splitv(self, node):
        data = node.inputs[0]
        num_sections = node.inputs[1]
        num_sections.code.clear()
        split_dim = node.inputs[2]
        split_dim.code.clear()
        num_sections = self.infer.get_const_tensor_value(num_sections.layer)
        split_dim = self.infer.get_const_tensor_value(split_dim.layer)
        input_shape = self.infer.get_tensor_shape(data.layer)
        if split_dim < 0:
            split_dim += len(input_shape)

        index = numpy.argwhere(num_sections < 0).flatten()
        if index.shape[0] > 1:
            raise Exception("More than one dimension less than 0")
        if index.shape[0] == 1:
            num_sections[index[0]] = input_shape[split_dim] - numpy.sum(
                num_sections) + num_sections[index[0]]
        param = {"num_or_sections": list(num_sections), "dim": split_dim}
        node.code.add_layer("split", data.ref_name, node.output_name, param)

    def emit_expanddims(self, node):
        data = node.inputs[0]
        dim = node.inputs[1]
        dim.code.clear()
        dim = self.infer.get_const_tensor_value(dim.layer)
        param = {"axes":[dim]}
        node.code.add_layer("unsqueeze", data.ref_name, node.output_name, param)

    def emit_cast(self, node):
        data = node.inputs[0]
        dtype_map = {1: "float32", 3: "int32", 9: "int64"}
        dtype = node.get_attr("DstT")
        if dtype in dtype_map:
            dtype = dtype_map[dtype]
        else:
            raise Exception("Unknow dtype: {}".format(dtype))
        param = {"dtype":"\'{}\'".format(dtype)}
        node.code.add_layer("cast", data.ref_name, node.output_name, param)