paddle_emitter.py 42.8 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import tensor_util
J
jiangjiajun 已提交
17 18
from utils import *
from functools import *
J
jiangjiajun 已提交
19 20
from six import string_types as _string_types
import framework_pb2 as framework
J
jiangjiajun 已提交
21
import logging
J
jiangjiajun 已提交
22 23 24
import math
import struct
import numpy
J
jiangjiajun 已提交
25
logging.basicConfig(level=logging.DEBUG)
J
jiangjiajun 已提交
26 27


J
jiangjiajun 已提交
28
class PaddleEmitter(object):
J
jiangjiajun 已提交
29 30 31
    def __init__(self, parser, save_dir):
        self.graph = parser.tf_graph
        self.weights = parser.weights
J
jiangjiajun 已提交
32 33
        self.infer = parser.infer
        self.inputs_sample_data = dict()
J
jiangjiajun 已提交
34 35 36 37
        self.save_dir = save_dir
        self.body_code = ""
        self.tab = " " * 4

J
jiangjiajun 已提交
38 39 40 41 42 43 44 45 46 47 48 49
        self.outputs = parser.outputs
        self.inputs = parser.inputs
        outputs = list()
        for output in self.outputs:
            while True:
                if output in self.graph.identity_relation:
                    output = self.graph.identity_relation[output]
                else:
                    break
            outputs.append(output)
        self.outputs = outputs

J
jiangjiajun 已提交
50 51 52 53
    @staticmethod
    def compute_padding_size(in_size, filter_size, stride):
        new_size = int(math.ceil(in_size * 1.0 / stride))
        pad_size = (new_size - 1) * stride + filter_size - in_size
J
jiangjiajun 已提交
54 55 56
        pad_0 = int(pad_size / 2)
        pad_1 = pad_size - pad_0
        return [pad_0, pad_1]
J
jiangjiajun 已提交
57

J
modify  
jiangjiajun 已提交
58 59 60 61 62 63 64 65 66
    def check_op(self, node_name_list):
        uncovered_ops = set()
        for name in node_name_list:
            node = self.graph.get_node(name)
            if len(node.inputs) == 0 and len(node.outputs) == 0:
                continue
            if not hasattr(self, "emit_" + node.layer_type):
                uncovered_ops.add(node.layer_type)
        if len(uncovered_ops) > 0:
J
jiangjiajun 已提交
67
            logging.error("{} OP are not supported".format(len(uncovered_ops)))
J
modify  
jiangjiajun 已提交
68
            for op in uncovered_ops:
J
jiangjiajun 已提交
69 70 71
                logging.error("Unsupported OP: {}".format(op))
            return False
        return True
J
modify  
jiangjiajun 已提交
72

J
jiangjiajun 已提交
73
    # Workaround for NHWC layouts: choose the broadcast axis for elementwise ops.
J
jiangjiajun 已提交
74
    def get_axis(self, node1, node2):
        """Pick the ``axis`` argument for a Paddle elementwise op.

        Returns 1 when one operand is a 4-D tensor whose node is NHWC and the
        other operand is 1-D. Returns -1 in every other case.
        """
        rank1 = node1.shape_dim_size
        rank2 = node2.shape_dim_size
        nhwc_with_vector = (
            (rank1 == 4 and rank2 == 1 and node1.data_format == NHWC)
            or (rank2 == 4 and rank1 == 1 and node2.data_format == NHWC))
        return 1 if nhwc_with_vector else -1

J
jiangjiajun 已提交
85 86 87 88 89 90 91 92 93 94 95
    def elementwise(self, node, op):
        """Emit a Paddle ``elementwise_<op>`` for a binary TensorFlow node.

        Args:
            node: graph node whose first two inputs are the operands.
            op: suffix of the Paddle op, e.g. ``"add"`` or ``"mul"``.

        Paddle broadcasts ``y`` onto ``x``, so operands may be reordered and,
        for equal-rank inputs, transposed:

        * Equal rank, identical shapes: the op is emitted directly.
        * Equal rank, otherwise:
          - Both operands are permuted so their non-1 axes come first.
          - ``y`` is the operand with more size-1 axes (``inputs[0]`` on
            ties).
          - The result is permuted back.
          - The transposes are skipped when the permutation is the identity.
        * Different rank: the higher-rank operand becomes ``x`` and the
          axis comes from ``get_axis``.

        NOTE: ``x`` and ``y`` may be swapped, so this is only correct for
        commutative ops.
        """
        data1 = node.inputs[0]
        data2 = node.inputs[1]
        axis = self.get_axis(data1, data2)
        shape1 = self.infer.get_tensor_shape(data1.layer)
        shape2 = self.infer.get_tensor_shape(data2.layer)

        op = "elementwise_" + op
        if shape2.shape[0] == shape1.shape[0]:
            if (shape1 == shape2).all():
                param_attr = {
                    'x': data1.ref_name,
                    'y': data2.ref_name,
                }
                node.code.add_layer(op, None, node.output_name, param_attr)
                return

            # Use a list, not a range. On Python 3 a list never compares
            # equal to a range, so the identity check below always failed
            # and needless identity transposes were emitted.
            identity = list(range(shape1.shape[0]))

            index1_not_one = list(numpy.argwhere(shape1 != 1).flatten())
            index1_one = list(numpy.argwhere(shape1 == 1).flatten())
            perm1 = identity
            if len(index1_one) != 0:
                perm1 = index1_not_one + index1_one

            index2_not_one = list(numpy.argwhere(shape2 != 1).flatten())
            index2_one = list(numpy.argwhere(shape2 == 1).flatten())
            perm2 = identity
            if len(index2_one) != 0:
                perm2 = index2_not_one + index2_one

            perm = list(numpy.array(perm1)[numpy.array(perm2)])
            if perm != identity:
                param_attr = {"perm": perm}
                node.code.add_layer("transpose", data1.ref_name, "temp1",
                                    param_attr)
                node.code.add_layer("transpose", data2.ref_name, "temp2",
                                    param_attr)
                if len(index2_one) > len(index1_one):
                    param_attr = {"x": "temp1", "y": "temp2"}
                else:
                    param_attr = {"x": "temp2", "y": "temp1"}
                node.code.add_layer(op, None, node.output_name, param_attr)
                # The inverse permutation restores the original axis order.
                perm = sorted(range(len(perm)), key=lambda k: perm[k])
                param_attr = {"perm": perm}
                node.code.add_layer("transpose", node.output_name,
                                    node.output_name, param_attr)
            else:
                if len(index2_one) > len(index1_one):
                    param_attr = {"x": data1.ref_name, "y": data2.ref_name}
                else:
                    param_attr = {"x": data2.ref_name, "y": data1.ref_name}
                node.code.add_layer(op, None, node.output_name, param_attr)
        else:
            param_attr = {
                "x": data1.ref_name,
                "y": data2.ref_name,
                "axis": axis
            }
            if shape2.shape[0] > shape1.shape[0]:
                param_attr = {
                    "x": data2.ref_name,
                    "y": data1.ref_name,
                    "axis": axis
                }
            node.code.add_layer(op, None, node.output_name, param_attr)

J
jiangjiajun 已提交
150
    def export_weights(self, weight, paddle_var_name, dir):
        """Dump a numpy array to ``dir/paddle_var_name`` and record its name
        in ``self.save_var_set``.

        File layout (platform-native byte order):
          * uint32 version (0)
          * uint64 LoD level (0)
          * uint32 tensor version (0)
          * int32 byte size of the serialized ``TensorDesc``
          * the ``TensorDesc`` protobuf (dtype and dims)
          * the element data in C (row-major) order

        A 0-d array / numpy scalar is written with no dims and one element.

        Args:
            weight: numpy array (or numpy scalar) to dump.
            paddle_var_name: Paddle variable name; also the file name.
            dir: output directory.
        Raises:
            Exception: if ``weight.dtype`` is not one of int16/32/64 or
                float16/32/64. The check happens before any file is created.
        """
        numpy_dtype_map = {
            "int16": framework.VarType.INT16,
            "int32": framework.VarType.INT32,
            "int64": framework.VarType.INT64,
            "float16": framework.VarType.FP16,
            "float32": framework.VarType.FP32,
            "float64": framework.VarType.FP64
        }
        dtype_name = str(weight.dtype)
        if dtype_name not in numpy_dtype_map:
            raise Exception("Unexpected array dtype [{}]".format(weight.dtype))
        self.save_var_set.add(paddle_var_name)

        tensor_desc = framework.VarType.TensorDesc()
        tensor_desc.data_type = numpy_dtype_map[dtype_name]
        tensor_desc.dims.extend(weight.shape)

        with open(dir + "/" + paddle_var_name, "wb") as filew:
            filew.write(struct.pack('i', 0))
            # 'Q' is 8 bytes on every platform. Native 'L' is only 4 bytes on
            # Windows, which corrupted the uint64 LoD-level field there.
            filew.write(struct.pack('Q', 0))
            filew.write(struct.pack('i', 0))
            filew.write(struct.pack('i', tensor_desc.ByteSize()))
            filew.write(tensor_desc.SerializeToString())
            # One bulk write instead of one struct.pack per element. The
            # dtype check above guarantees native byte order, and flatten()
            # yields a C-ordered 1-D copy, also for 0-d input.
            filew.write(numpy.asarray(weight).flatten().tobytes())

    @property
    def header_code(self):
        code = list()
        code.append("import paddle.fluid.layers as layers")
        code.append("import paddle.fluid as fluid")
J
jiangjiajun 已提交
201
        code.append("import numpy")
J
jiangjiajun 已提交
202
        code.append("")
J
jiangjiajun 已提交
203 204
        code.append("class Model(object):")
        code.append("    def build(self):")
J
jiangjiajun 已提交
205 206 207 208 209
        return code

    def add_codes(self, indent, codes):
        """Append lines of code to ``self.body_code``.

        Args:
            indent: number of ``self.tab`` units prefixed to every line.
            codes: a string, which is stripped and split on newlines, or a
                list of lines.
        Raises:
            Exception: if ``codes`` is neither a string nor a list.
        """
        if isinstance(codes, _string_types):
            codes = codes.strip().split("\n")
        if not isinstance(codes, list):
            raise Exception("Unexpected error!")
        prefix = self.tab * indent
        for line in codes:
            self.body_code += prefix + line + "\n"

    def run(self):
        """Translate the whole graph and write the generated files.

        Steps:
          1. Start ``self.body_code`` with ``header_code`` and reset
             ``self.save_var_set``.
          2. Collect node names from ``self.graph.topological_sort``,
             skipping names containing ':' (e.g. 'split:1').
          3. Stop (returning None) if ``check_op`` reports unsupported ops.
          4. For every node, log progress and write
             "<layer_name>\\t<output_name>" to ``ref_name.info``. Then call
             ``emit_<layer_type>``, except for isolated nodes.
          5. Append each node's generated code, an output comment, and
             ``self.inputs`` / ``self.outputs`` assignments. Write
             ``mymodel.py`` and ``save_var.list`` into ``self.save_dir``.

        Returns:
            The generated source (``self.body_code``), or None when the graph
            contains unsupported ops.
        Raises:
            Exception: if a node's op has no ``emit_<op>`` method.
        """
        self.add_codes(0, self.header_code)
        self.save_var_set = set()

        # filter branch nodes, like 'split:1'
        translate_nodes = [
            name for name in self.graph.topological_sort if ':' not in name
        ]

        # check if exists unsupported OPs in model
        if not self.check_op(translate_nodes):
            return

        # ref_name.info records the relationship between the TensorFlow
        # layer name and the Paddle variable name. The `with` block closes
        # the file even when an emitter raises.
        total_nodes_num = len(translate_nodes)
        with open(self.save_dir + "/ref_name.info", 'w') as ref_name_recorder:
            for translated_nodes_count, node in enumerate(translate_nodes, 1):
                logging.info(
                    "TotalNum:{},TraslatedNum:{},CurrentNode:{}".format(
                        total_nodes_num, translated_nodes_count, node))
                current_node = self.graph.get_node(node)
                ref_name_recorder.write("{}\t{}\n".format(
                    current_node.layer_name, current_node.output_name))

                # skip isolated nodes
                if len(current_node.inputs) == 0 and len(
                        current_node.outputs) == 0:
                    continue

                op = current_node.layer_type
                if hasattr(self, "emit_" + op):
                    getattr(self, "emit_" + op)(current_node)
                else:
                    raise Exception("Unknow node op: {}".format(op))

        # merge all the generated python codes
        for node in translate_nodes:
            self.add_codes(2, self.graph.get_node(node).code.gen_codes())

        # add return value codes
        outs = []
        for node in self.outputs:
            out_node = self.graph.get_node(node)
            outs.append(out_node.output_name)
            self.add_codes(
                2, "# {} : {}".format(out_node.output_name,
                                      out_node.layer_name))
        input_code = "self.inputs = {}".format([str(s) for s in self.inputs])
        output_code = "self.outputs = [{}]".format(", ".join(outs))
        self.add_codes(2, input_code)
        self.add_codes(2, output_code)

        # write python code to file "mymodel.py"
        with open(self.save_dir + "/mymodel.py", 'w') as filew:
            filew.write(self.body_code)

        # file "save_var.list" records name of dumped variables
        with open(self.save_dir + "/save_var.list", 'w') as filew:
            for var in self.save_var_set:
                filew.write(var + '\n')

        logging.info("Model translated!")
        return self.body_code

    def emit_placeholder(self, node):
        """Emit a ``data`` layer for a TensorFlow placeholder.

        Side effects:
          * ``self.inputs_sample_data[node.layer_name]`` receives three
            arrays from ``numpy.random.random_sample``.
            - Leading dimension unknown (None or negative): the arrays have
              leading sizes 1, 2 and 3, and ``self.batch_node`` is set to
              ``node``.
            - Otherwise: all three use the inferred shape.
          * A 4-D NHWC shape is declared in NCHW order.
        """
        shape = list(self.infer.get_tensor_shape(node.layer))

        self.inputs_sample_data[node.layer_name] = []
        # Test for None first: `None < 0` raises TypeError on Python 3, so
        # the previous order crashed on an unknown (None) batch size.
        if shape[0] is None or shape[0] < 0:
            self.batch_node = node
            for i in range(1, 4):
                sample_data = numpy.random.random_sample([i] + shape[1:])
                self.inputs_sample_data[node.layer_name].append(sample_data)
        else:
            for i in range(1, 4):
                sample_data = numpy.random.random_sample(shape)
                self.inputs_sample_data[node.layer_name].append(sample_data)

        if node.data_format == NHWC and len(shape) == 4:
            shape = [shape[0], shape[3], shape[1], shape[2]]

        param_attr = {
            "name": "\'{}\'".format(node.ref_name),
            "shape": shape,
            "dtype": "\'{}\'".format(node.dtype),
            "append_batch_size": False
        }
        node.code.add_layer("data", None, node.output_name, param_attr)
J
jiangjiajun 已提交
314 315

    def emit_const(self, node):
        """Emit code for a TensorFlow constant.

        Every path starts with a comment line holding the layer name, ref
        name and shape.
          * Single-element constant: becomes ``fill_constant`` with shape
            [1].
          * Larger constant: becomes ``create_parameter``, and its value is
            dumped with ``export_weights``. Integer dtypes get a
            ``Constant(0)`` default initializer.

        Returns:
            ``[]`` without emitting anything when ``node.dtype`` cannot be
            read; otherwise None.
        """
        value = self.infer.get_const_tensor_value(node.layer)
        shape = list(value.shape)

        try:
            dtype = node.dtype
        except Exception:
            # Deliberate best effort: a constant whose dtype cannot be read
            # is skipped. The bare `except:` used here before also swallowed
            # KeyboardInterrupt/SystemExit.
            return []

        node.code.add_str("#{} {} {}".format(node.layer_name, node.ref_name,
                                             value.shape))
        if value.size == 1:
            param_attr = {
                "shape": [1],
                "value": value.flatten()[0],
                "dtype": "\'{}\'".format(dtype),
            }
            node.code.add_layer("fill_constant", None, node.output_name,
                                param_attr)
        else:
            param_attr = {
                "shape": shape,
                "name": "\'{}\'".format(node.ref_name),
                "dtype": "\'{}\'".format(dtype)
            }
            # Use the dtype read above rather than re-reading node.dtype.
            if dtype.startswith('int'):
                param_attr["default_initializer"] = \
                    "fluid.initializer.Constant(0)"
            node.code.add_layer("create_parameter", None, node.output_name,
                                param_attr)
            self.export_weights(value, node.ref_name, self.save_dir)

J
jiangjiajun 已提交
347 348 349
    def emit_conv2d(self, node):
        """Emit a Paddle ``conv2d`` without bias for a TensorFlow Conv2D.

        The kernel shape is read as ``[k_h, k_w, in_channels, num_filters]``.
        When its values are in ``self.weights``:
          * they are transposed with axes ``(3, 2, 0, 1)``;
          * the transposed array is stored back into ``self.weights``;
          * it is exported under ``kernel.ref_name``, which is also the
            conv's ``param_attr``.

        Symmetric 'SAME' padding is passed to conv2d directly. Asymmetric
        padding is applied first with an explicit ``pad2d``.
        """
        data, kernel = node.inputs[0], node.inputs[1]

        # Drop the kernel's own generated code when this conv is its only
        # consumer.
        if len(kernel.outputs) == 1:
            kernel.code.clear()

        padding_mode = node.get_attr("padding")
        # Position of the (H, W) entries in the TF attributes / input shape.
        spatial = slice(1, 3) if node.data_format == NHWC else slice(2, 4)
        strides = node.get_attr("strides")[spatial]
        k_h, k_w, _, kernel_num = list(
            self.infer.get_tensor_shape(kernel.layer))
        input_shape = list(self.infer.get_tensor_shape(data.layer))
        input_h, input_w = input_shape[spatial]

        if kernel.layer_name in self.weights:
            kernel_value = numpy.transpose(self.weights[kernel.layer_name],
                                           (3, 2, 0, 1))
            self.weights[kernel.layer_name] = kernel_value
            self.export_weights(kernel_value, kernel.ref_name, self.save_dir)

        conv2d_param = {
            "num_filters": kernel_num,
            "filter_size": [k_h, k_w],
            "stride": strides,
            "param_attr": "\'{}\'".format(kernel.ref_name),
            "bias_attr": False
        }

        conv_input = data.ref_name
        if padding_mode == SAME:
            pad_h = self.compute_padding_size(input_h, k_h, strides[0])
            pad_w = self.compute_padding_size(input_w, k_w, strides[1])
            if len(set(pad_h)) == 1 and len(set(pad_w)) == 1:
                # Symmetric: conv2d pads by itself.
                conv2d_param["padding"] = [pad_h[0], pad_w[0]]
            else:
                # Asymmetric: pad explicitly, then convolve the padded tensor.
                node.code.add_layer("pad2d", data.ref_name, node.output_name,
                                    {"paddings": pad_h + pad_w})
                conv_input = node.output_name
        node.code.add_layer("conv2d", conv_input, node.output_name,
                            conv2d_param)
J
jiangjiajun 已提交
395 396

    def emit_variablev2(self, node):
        """Emit ``create_parameter`` for a TensorFlow variable.

        A comment naming the variable is added first. The value is exported
        when ``node.layer_name`` is in ``self.weights``. Integer dtypes get
        a ``Constant(0)`` default initializer.
        """
        var_shape = list(self.infer.get_tensor_shape(node.layer))
        node.code.add_str("# variable[{}]:\t{}".format(node.output_name,
                                                       node.layer_name))

        known_value = node.layer_name in self.weights
        if known_value:
            self.export_weights(self.weights[node.layer_name], node.ref_name,
                                self.save_dir)

        attrs = {
            "name": "\'{}\'".format(node.ref_name),
            "shape": var_shape,
            "dtype": "\'{}\'".format(node.dtype)
        }
        if node.dtype.startswith('int'):
            attrs["default_initializer"] = "fluid.initializer.Constant(0)"
        node.code.add_layer("create_parameter", None, node.output_name, attrs)
J
jiangjiajun 已提交
415 416 417 418 419 420

    def emit_biasadd(self, node):
        """Emit ``elementwise_add`` (axis 1) of a bias onto ``inputs[0]``.

        The bias parameter is declared through ``emit_variablev2``, which
        also exports its value when ``bias.layer_name`` is in
        ``self.weights``.
        """
        data = node.inputs[0]
        bias = node.inputs[1]

        # emit_variablev2 already exports the bias weights. Exporting here
        # as well wrote the same file twice.
        self.emit_variablev2(bias)
        param_attr = {"x": data.ref_name, "y": bias.ref_name, "axis": 1}
        node.code.add_layer("elementwise_add", None, node.output_name,
                            param_attr)
J
jiangjiajun 已提交
428 429 430

    def emit_relu(self, node):
        """Emit ``relu`` applied to the node's first input."""
        source = node.inputs[0].ref_name
        node.code.add_layer("relu", source, node.output_name)
J
jiangjiajun 已提交
432 433 434 435

    def emit_maxpool(self, node):
        """Emit a Paddle max ``pool2d`` for a TensorFlow MaxPool.

        With 'SAME' padding, the input is first padded by ``pad2d`` with a
        fill value of -1e6, so padded cells never win the max. It is then
        pooled without padding.
        """
        data = node.inputs[0]
        padding_mode = node.get_attr("padding")
        input_shape = list(self.infer.get_tensor_shape(data.layer))
        # Position of the (H, W) entries in the TF attributes / input shape.
        spatial = slice(1, 3) if node.data_format == NHWC else slice(2, 4)
        input_h, input_w = input_shape[spatial]
        strides = node.get_attr("strides")[spatial]
        pool_size = node.get_attr("ksize")[spatial]

        pool_param = {
            "pool_size": pool_size,
            "pool_type": "\'max\'",
            "pool_stride": strides,
        }

        if padding_mode == SAME:
            pad_h = self.compute_padding_size(input_h, pool_size[0],
                                              strides[0])
            pad_w = self.compute_padding_size(input_w, pool_size[1],
                                              strides[1])
            # NOTE(review): for an odd total, all padding is moved to the
            # bottom/right edge. TensorFlow 'SAME' puts floor(total/2) before
            # and the rest after. The two agree for total == 1 but differ for
            # 3, 5, ... - confirm this is intended.
            for pads in (pad_h, pad_w):
                if (pads[0] + pads[1]) % 2 != 0:
                    pads[1] += pads[0]
                    pads[0] = 0
            pad_param = {"paddings": pad_h + pad_w, "pad_value": -1000000.0}
            node.code.add_layer("pad2d", data.ref_name, node.output_name,
                                pad_param)
            pool_input = node.output_name
        else:
            pool_input = data.ref_name
        node.code.add_layer("pool2d", pool_input, node.output_name,
                            pool_param)
J
jiangjiajun 已提交
474 475 476 477

    def emit_squeeze(self, node):
        """Emit ``squeeze`` for a TensorFlow Squeeze.

        For a 4-D NHWC input, ``squeeze_dims`` are converted to NCHW axis
        numbers (0->0, 1->2, 2->3, 3->1). Negative dims are normalised
        first. Otherwise the dims are passed through unchanged.
        """
        data = node.inputs[0]
        # Copy so the value returned by get_attr is never modified in place.
        axis = list(node.get_attr("squeeze_dims"))
        input_shape_len = data.shape_dim_size
        if node.data_format == NHWC and input_shape_len == 4:
            nhwc_to_nchw = [0, 2, 3, 1]
            # Negative dims (e.g. -1 = C) were previously left unmapped and
            # addressed the wrong NCHW axis.
            axis = [
                nhwc_to_nchw[a + input_shape_len if a < 0 else a]
                for a in axis
            ]
        param_attr = {"axes": axis}
        node.code.add_layer("squeeze", data.ref_name, node.output_name,
                            param_attr)
J
jiangjiajun 已提交
486 487

    def emit_add(self, node):
        """Translate a TensorFlow Add through the shared ``elementwise``."""
        result = self.elementwise(node, op="add")
        return result
J
jiangjiajun 已提交
489 490 491 492

    def emit_mean(self, node):
        """Emit ``reduce_mean`` for a TensorFlow Mean.

        The reduction indices come from the constant second input, whose own
        generated code is dropped. For 4-D NHWC data they are converted to
        NCHW axis numbers, with negative indices normalised first.
        ``keep_dims`` is forwarded as ``keep_dim``.
        """
        data = node.inputs[0]
        reduce_idx = node.inputs[1]
        reduce_idx.code.clear()
        idxs = list(
            self.infer.get_const_tensor_value(reduce_idx.layer).flatten())
        data_shape_len = data.shape_dim_size
        keep_dims = node.layer.attr['keep_dims'].b
        if node.data_format == NHWC and data_shape_len == 4:
            nhwc_to_nchw = [0, 2, 3, 1]
            # Negative indices (e.g. -1 = C) were previously left unmapped
            # and reduced the wrong NCHW dimension.
            idxs = [
                nhwc_to_nchw[i + data_shape_len if i < 0 else i]
                for i in idxs
            ]
        param_attr = {"dim": list(idxs), "keep_dim": keep_dims}
        node.code.add_layer("reduce_mean", data.ref_name, node.output_name,
                            param_attr)
J
jiangjiajun 已提交
505 506 507 508 509 510 511 512

    def emit_fusedbatchnorm(self, node):
        """Emit Paddle ``batch_norm`` for a TensorFlow FusedBatchNorm.

        Inputs 1-4 (gamma, beta, moving mean, moving variance) are bound by
        name through ``param_attr``, ``bias_attr`` and ``moving_*_name``:
          * an input that feeds only this node has its own code dropped;
          * an input whose value is in ``self.weights`` is exported.

        ``epsilon`` is rounded to 6 decimals, and ``is_test`` is the negation
        of ``is_training``.
        """
        data = node.inputs[0]
        gamma, beta, moving_mean, moving_variance = (
            node.inputs[i] for i in range(1, 5))
        stats = (gamma, beta, moving_mean, moving_variance)

        for stat in stats:
            if len(stat.outputs) == 1:
                stat.code.clear()

        epsilon = round(node.get_attr('epsilon'), 6)
        is_training = node.get_attr('is_training')

        for stat in stats:
            if stat.layer_name in self.weights:
                self.export_weights(self.weights[stat.layer_name],
                                    stat.ref_name, self.save_dir)

        quoted = "\'{}\'".format
        param_attr = {
            "epsilon": epsilon,
            "param_attr": quoted(gamma.ref_name),
            "bias_attr": quoted(beta.ref_name),
            "moving_mean_name": quoted(moving_mean.ref_name),
            "moving_variance_name": quoted(moving_variance.ref_name),
            "is_test": not is_training
        }
        node.code.add_layer("batch_norm", data.ref_name, node.output_name,
                            param_attr)
J
jiangjiajun 已提交
547 548

    def emit_concatv2(self, node):
        """Emit ``concat`` for a TensorFlow ConcatV2.

        The last input holds the constant axis; its own code is dropped.
        * A negative axis is normalised by the rank of the first input.
        * For 4-D NHWC inputs a positive axis is converted to NCHW order.

        All other inputs are concatenated in order.
        """
        rank = node.inputs[0].shape_dim_size
        axis_node = node.inputs[-1]
        axis_node.code.clear()
        axis = self.infer.get_const_tensor_value(axis_node.layer)
        if axis < 0:
            axis = rank + axis
        if node.data_format == NHWC and rank == 4 and axis > 0:
            axis = (axis + 1) % 4 + int((axis + 1) / 4)

        tensors = node.inputs[:len(node.inputs) - 1]
        names = ", ".join(tensor.ref_name for tensor in tensors)
        node.code.add_layer("concat", "[" + names + "]", node.output_name,
                            {"axis": axis})
J
jiangjiajun 已提交
563 564 565 566

    def emit_avgpool(self, node):
        """Emit a Paddle average ``pool2d`` for a TensorFlow AvgPool.

        With 'SAME' padding:
          * symmetric padding is passed as ``pool_padding``;
          * asymmetric padding is applied with an explicit ``pad2d``, and the
            pooling then runs on the padded tensor.
        """
        data = node.inputs[0]
        padding_mode = node.get_attr("padding")
        input_shape = list(self.infer.get_tensor_shape(data.layer))
        strides = node.get_attr("strides")[2:4]
        pool_size = node.get_attr("ksize")[2:4]
        input_h, input_w = input_shape[2:4]

        if node.data_format == NHWC:
            strides = node.get_attr("strides")[1:3]
            pool_size = node.get_attr("ksize")[1:3]
            input_h, input_w = input_shape[1:3]

        param_attr = {
            "pool_size": pool_size,
            "pool_stride": strides,
            "pool_type": "\'avg\'"
        }

        if padding_mode == SAME:
            pad_h = self.compute_padding_size(input_h, pool_size[0],
                                              strides[0])
            # Width padding must use the width stride (was strides[0]).
            pad_w = self.compute_padding_size(input_w, pool_size[1],
                                              strides[1])
            if len(set(pad_h)) == 1 and len(set(pad_w)) == 1:
                param_attr["pool_padding"] = [pad_h[0], pad_w[0]]
            else:
                # NOTE(review): pad2d fills with zeros, and those zeros count
                # toward the average. TensorFlow's 'SAME' average pooling
                # excludes padded cells - confirm whether this matters.
                pad_param = {"paddings": pad_h + pad_w}
                node.code.add_layer("pad2d", data.ref_name, node.output_name,
                                    pad_param)
                node.code.add_layer("pool2d", node.output_name,
                                    node.output_name, param_attr)
                # Stop here. Previously execution fell through and appended
                # a second pool2d on the unpadded input, which overwrote
                # this result.
                return
        node.code.add_layer("pool2d", data.ref_name, node.output_name,
                            param_attr)
J
jiangjiajun 已提交
599 600 601

    def emit_rsqrt(self, node):
        """Emit ``1 / sqrt(x)`` as ``sqrt`` followed by ``pow(-1.0)``."""
        source = node.inputs[0].ref_name
        target = node.output_name
        node.code.add_layer("sqrt", source, target)
        node.code.add_layer("pow", target, target, {"factor": -1.0})
J
jiangjiajun 已提交
606 607

    def emit_mul(self, node):
        """Translate a TensorFlow Mul through the shared ``elementwise``."""
        result = self.elementwise(node, op="mul")
        return result
J
jiangjiajun 已提交
609 610 611 612 613

    def emit_sub(self, node):
        """Emit ``elementwise_sub`` computing ``inputs[0] - inputs[1]``.

        Transpose path, taken when both operands are 4-D, the node is NHWC,
        and their last dimensions differ:
          * both operands are transposed with perm [0, 2, 3, 1];
          * they are subtracted with axis -1;
          * the result is transposed back with [0, 3, 1, 2].

        Otherwise the subtraction is emitted directly, using the axis from
        ``get_axis``.
        """
        data1 = node.inputs[0]
        data2 = node.inputs[1]
        axis = self.get_axis(data1, data2)
        data1_shape = list(self.infer.get_tensor_shape(data1.layer))
        data2_shape = list(self.infer.get_tensor_shape(data2.layer))
        needs_transpose = (len(data1_shape) == 4 and len(data2_shape) == 4
                           and node.data_format == NHWC
                           and data1_shape[-1] != data2_shape[-1])
        if needs_transpose:
            node.code.add_layer("transpose", data1.ref_name, "temp1",
                                {"perm": [0, 2, 3, 1]})
            node.code.add_layer("transpose", data2.ref_name, "temp2",
                                {"perm": [0, 2, 3, 1]})
            param_attr = {"x": "temp1", "y": "temp2", "axis": -1}
            node.code.add_layer("elementwise_sub", None, node.output_name,
                                param_attr)
            node.code.add_layer("transpose", node.output_name,
                                node.output_name, {"perm": [0, 3, 1, 2]})
        else:
            # Previously 4-D NHWC operands with equal last dimensions fell
            # through both branches, so no code was emitted for the output.
            param_attr = {
                "x": data1.ref_name,
                "y": data2.ref_name,
                "axis": axis
            }
            node.code.add_layer("elementwise_sub", None, node.output_name,
                                param_attr)
J
jiangjiajun 已提交
632 633 634

    def emit_shape(self, node):
        """Emit ``shape`` of the first input, cast to int32.

        For a 4-D NHWC input the tensor is first transposed with perm
        [0, 2, 3, 1], presumably so the reported shape follows TensorFlow's
        NHWC order (TODO confirm). The shape is taken of that transposed
        tensor.
        """
        data = node.inputs[0]
        shape_source = data.ref_name
        if data.shape_dim_size == 4 and node.data_format == NHWC:
            node.code.add_layer("transpose", shape_source, node.output_name,
                                {"perm": [0, 2, 3, 1]})
            shape_source = node.output_name
        node.code.add_layer("shape", shape_source, node.output_name)
        node.code.add_layer("cast", node.output_name, node.output_name,
                            {"dtype": "\'int32\'"})
J
jiangjiajun 已提交
645 646 647 648

    def emit_pad(self, node):
        """Emit a Paddle ``pad`` for a TF ``Pad`` node.

        The paddings come from the second input's constant ``value``
        attribute, and that input's generated code is cleared. For an NHWC
        node with 4 padding rows, the rows are reordered with [0, 3, 1, 2].
        The rows are then flattened into Paddle's ``paddings`` list.
        """
        data = node.inputs[0]
        paddings_node = node.inputs[1]
        paddings_node.code.clear()

        const_tensor = paddings_node.layer.attr['value'].tensor
        pad_rows = tensor_util.MakeNdarray(const_tensor).astype('int32')
        if node.data_format == NHWC and pad_rows.shape[0] == 4:
            # Reorder the per-dimension rows from NHWC to NCHW order.
            pad_rows = pad_rows[[0, 3, 1, 2]]

        node.code.add_layer("pad", data.ref_name, node.output_name,
                            {"paddings": list(pad_rows.flatten())})

    def emit_stridedslice(self, node):
        """Emit a Paddle ``slice`` for a TF ``StridedSlice`` node.

        ``begin``, ``end`` and ``strides`` must be constant, and their
        generated code is cleared. Only unit strides are supported. For a 4-D
        slice on an NHWC node, ``begin``/``end`` are reordered from NHWC to
        NCHW. This is the same reordering ``emit_slice`` applies to its
        ``begin``/``size``.

        NOTE(review): the StridedSlice mask attributes (begin_mask, end_mask,
        ellipsis_mask, new_axis_mask, shrink_axis_mask) are not consulted;
        graphs that set them may be converted incorrectly.

        Raises:
            Exception: if any stride is not 1.
        """
        data = node.inputs[0]
        begin = node.inputs[1]
        end = node.inputs[2]
        strides = node.inputs[3]
        begin.code.clear()
        end.code.clear()
        strides.code.clear()

        # Plain ints keep the generated code free of numpy scalar reprs.
        begin = [int(v) for v in
                 self.infer.get_const_tensor_value(begin.layer).flatten()]
        end = [int(v) for v in
               self.infer.get_const_tensor_value(end.layer).flatten()]
        strides = [int(v) for v in
                   self.infer.get_const_tensor_value(strides.layer).flatten()]

        if any(s != 1 for s in strides):
            raise Exception(
                "Only stride 1 is supported in OP stridedslice, got {}".format(
                    strides))

        if node.data_format == NHWC and len(begin) == 4:
            begin = [begin[0], begin[3], begin[1], begin[2]]
            end = [end[0], end[3], end[1], end[2]]

        # On Python 3 range() is lazy; materialise it so the generated code
        # receives a real list of axes.
        param_attr = {
            "axes": list(range(len(begin))),
            "starts": begin,
            "ends": end
        }
        node.code.add_layer("slice", data.ref_name, node.output_name,
                            param_attr)

    def emit_resizenearestneighbor(self, node):
        data = node.inputs[0]
J
jiangjiajun 已提交
679 680
        resize_shape = node.inputs[1]
        resize_shape.code.clear()
J
jiangjiajun 已提交
681 682
        align_corners = node.get_attr('align_corners')

J
jiangjiajun 已提交
683 684
        resize_shape = list(self.infer.get_shape_tensor(resize_shape.layer))
        param_attr = {
J
jiangjiajun 已提交
685 686 687 688 689
            "align_corners": align_corners,
            "out_shape": resize_shape
        }
        node.code.add_layer("resize_nearest", data.ref_name, node.output_name,
                            param_attr)
J
jiangjiajun 已提交
690 691

    def emit_maximum(self, node):
        """Emit Paddle code for a TF ``Maximum`` node.

        Forwards to ``self.elementwise`` with the operation name ``"max"`` and
        returns whatever that helper returns.
        """
        op_name = "max"
        return self.elementwise(node, op_name)

    def emit_minimum(self, node):
        """Emit Paddle code for a TF ``Minimum`` node.

        Forwards to ``self.elementwise`` with the operation name ``"min"`` and
        returns whatever that helper returns.
        """
        op_name = "min"
        return self.elementwise(node, op_name)

    def emit_sigmoid(self, node):
        """Emit a Paddle ``sigmoid`` applied to the node's first input."""
        source_name = node.inputs[0].ref_name
        node.code.add_layer("sigmoid", source_name, node.output_name)

    def emit_pack(self, node):
        """Emit a Paddle ``stack`` for a TF ``Pack`` node.

        The inputs' ref names are joined into a literal list expression. The
        node's ``axis`` attribute is forwarded to ``stack`` when it is
        non-zero. For axis 0 (or a missing attribute) the call is emitted
        without it, which leaves ``stack`` at its default axis.
        """
        names = ", ".join(inp.ref_name for inp in node.inputs)
        inputs = "[{}]".format(names)
        axis = node.get_attr("axis")
        if axis:
            # Packing along a non-leading axis; without this the result
            # would always be stacked on axis 0.
            node.code.add_layer("stack", inputs, node.output_name,
                                {"axis": int(axis)})
        else:
            node.code.add_layer("stack", inputs, node.output_name)

    def emit_reshape(self, node):
        """Emit Paddle code for a TF ``Reshape`` node.

        The target shape is resolved via ``self.infer.get_shape_tensor``,
        given the node's inferred output shape. For a 4-D input on an NHWC
        node:

        * the input is transposed with perm [0, 2, 3, 1] before reshaping;
        * if the target is also 4-D, the result is transposed back with
          perm [0, 3, 1, 2].

        Otherwise the input is reshaped directly.

        NOTE(review): a non-4-D input reshaped to 4-D on an NHWC node takes the
        direct path with no trailing transpose; confirm that layout is intended.
        """
        data = node.inputs[0]
        shape_node = node.inputs[1]
        in_rank = data.shape_dim_size
        inferred_out = list(self.infer.get_tensor_shape(node.layer))
        target = self.infer.get_shape_tensor(shape_node.layer, inferred_out)
        reshape_param = {"shape": list(target)}

        if not (node.data_format == NHWC and in_rank == 4):
            node.code.add_layer("reshape", data.ref_name, node.output_name,
                                reshape_param)
            return

        node.code.add_layer("transpose", data.ref_name, node.output_name,
                            {"perm": [0, 2, 3, 1]})
        node.code.add_layer("reshape", node.output_name, node.output_name,
                            reshape_param)
        if len(target) == 4:
            node.code.add_layer("transpose", node.output_name,
                                node.output_name, {"perm": [0, 3, 1, 2]})

    def emit_conv2dbackpropinput(self, node):
        """Emit a Paddle ``conv2d_transpose`` for a TF ``Conv2DBackpropInput``.

        Inputs are (output_shape, kernel, data):

        * The kernel's generated code is cleared when it has a single
          consumer; the output-shape input's code is always cleared.
        * Known kernel weights are transposed with axes (3, 2, 0, 1) and
          exported under the kernel's ref name.
        * With SAME padding, each spatial padding is
          ``int((k - stride) / 2)``, and the result is cropped to the
          requested output shape. For a 4-D NHWC node that shape is first
          reordered to NCHW.
        """
        output_shape_node = node.inputs[0]
        kernel = node.inputs[1]
        data = node.inputs[2]
        if len(kernel.outputs) == 1:
            kernel.code.clear()
        output_shape_node.code.clear()

        padding_mode = node.get_attr("padding")
        strides = node.get_attr("strides")[2:4]
        k_h, k_w, k_num, _channel = self.infer.get_tensor_shape(kernel.layer)
        if node.data_format == NHWC:
            strides = node.get_attr("strides")[1:3]

        if padding_mode == SAME:
            padding = [int((k_h - strides[0]) / 2),
                       int((k_w - strides[1]) / 2)]
        else:
            padding = [0, 0]

        weights_key = kernel.layer_name
        if weights_key in self.weights:
            self.weights[weights_key] = numpy.transpose(
                self.weights[weights_key], (3, 2, 0, 1))
            self.export_weights(self.weights[weights_key], kernel.ref_name,
                                self.save_dir)

        out_shape = list(self.infer.get_shape_tensor(output_shape_node.layer))
        if node.data_format == NHWC and len(out_shape) == 4:
            out_shape = [out_shape[i] for i in (0, 3, 1, 2)]

        conv_param = {
            "num_filters": k_num,
            "filter_size": [k_h, k_w],
            "padding": padding,
            "stride": strides,
            "param_attr": "\'{}\'".format(kernel.ref_name),
            "bias_attr": False
        }
        node.code.add_layer("conv2d_transpose", data.ref_name,
                            node.output_name, conv_param)
        if padding_mode == SAME:
            node.code.add_layer("crop", node.output_name, node.output_name,
                                {"shape": list(out_shape)})

    def emit_depthwiseconv2dnative(self, node):
        """Emit a grouped Paddle ``conv2d`` for a TF ``DepthwiseConv2dNative``.

        The kernel shape is unpacked as [k_h, k_w, in_channels,
        channel_multiplier], and the depthwise op produces
        in_channels * channel_multiplier output channels. In Paddle this is a
        ``conv2d`` with ``groups=in_channels`` and
        ``num_filters=in_channels * channel_multiplier``, whose filter has
        shape [in_channels * channel_multiplier, 1, k_h, k_w].
        For channel_multiplier == 1 this emits exactly the same code and
        weights as before.

        With SAME padding, symmetric paddings are passed to ``conv2d``
        directly. Asymmetric paddings are applied with a preceding ``pad2d``.
        """
        data = node.inputs[0]
        kernel = node.inputs[1]
        if len(kernel.outputs) == 1:
            kernel.code.clear()

        padding_mode = node.get_attr("padding")
        strides = node.get_attr("strides")[2:4]
        k_shape = self.infer.get_tensor_shape(kernel.layer)
        input_shape = self.infer.get_tensor_shape(data.layer)
        input_h, input_w = input_shape[2:4]
        k_h, k_w, in_channels, channel_multiplier = k_shape
        if node.data_format == NHWC:
            strides = node.get_attr("strides")[1:3]
            input_h, input_w = input_shape[1:3]

        # One group per input channel, each producing channel_multiplier
        # outputs. (groups = in_channels * channel_multiplier is invalid once
        # the multiplier exceeds 1.)
        groups = int(in_channels)
        num_filters = int(in_channels) * int(channel_multiplier)

        if kernel.layer_name in self.weights:
            weight = self.weights[kernel.layer_name]
            # [k_h, k_w, C, M] -> [C, M, k_h, k_w] -> [C * M, 1, k_h, k_w].
            # Row-major flattening of (C, M) gives output channel c * M + m,
            # matching TF's depthwise output ordering.
            transposed = numpy.transpose(weight, (2, 3, 0, 1))
            self.weights[kernel.layer_name] = transposed.reshape(
                (num_filters, 1) + transposed.shape[2:])
            self.export_weights(self.weights[kernel.layer_name],
                                kernel.ref_name, self.save_dir)

        conv_param = {
            "num_filters": num_filters,
            "filter_size": [k_h, k_w],
            "stride": strides,
            "groups": groups,
            "param_attr": "\'{}\'".format(kernel.ref_name),
            "bias_attr": False
        }
        if padding_mode == SAME:
            pad_h = self.compute_padding_size(input_h, k_h, strides[0])
            pad_w = self.compute_padding_size(input_w, k_w, strides[1])
            if len(set(pad_h)) == 1 and len(set(pad_w)) == 1:
                conv_param["padding"] = [pad_h[0], pad_w[0]]
                node.code.add_layer("conv2d", data.ref_name, node.output_name,
                                    conv_param)
            else:
                pad_param = {"paddings": pad_h + pad_w}
                node.code.add_layer("pad2d", data.ref_name, node.output_name,
                                    pad_param)
                node.code.add_layer("conv2d", node.output_name,
                                    node.output_name, conv_param)
        else:
            node.code.add_layer("conv2d", data.ref_name, node.output_name,
                                conv_param)

    def emit_softmax(self, node):
        """Emit a Paddle ``softmax`` applied to the node's first input."""
        source_name = node.inputs[0].ref_name
        node.code.add_layer("softmax", source_name, node.output_name)

    def emit_matmul(self, node):
        """Emit a Paddle ``matmul`` for a TF ``MatMul`` node.

        TF's ``transpose_a``/``transpose_b`` attributes are passed through as
        Paddle's ``transpose_x``/``transpose_y``.
        """
        lhs, rhs = node.inputs[0], node.inputs[1]
        param_attr = {"x": lhs.ref_name, "y": rhs.ref_name}
        param_attr["transpose_x"] = node.get_attr('transpose_a')
        param_attr["transpose_y"] = node.get_attr('transpose_b')
        node.code.add_layer("matmul", None, node.output_name, param_attr)

    def emit_transpose(self, node):
        """Emit Paddle code for a TF ``Transpose`` node.

        Two 4-D cases become a plain alias of the input, with the node's data
        format switched via ``self.graph.set_data_format``:

        * an NHWC node with perm [0, 3, 1, 2] is switched to NCHW;
        * an NCHW node with perm [0, 2, 3, 1] is switched to NHWC.

        Any other 4-D perm on such nodes raises. All remaining cases emit an
        ordinary ``transpose``.

        Raises:
            Exception: for an unexpected 4-D perm on an NHWC/NCHW node.
        """
        data = node.inputs[0]
        perm_node = node.inputs[1]
        perm_node.code.clear()
        perm = list(self.infer.get_shape_tensor(perm_node.layer))

        # (current format, perm that can be folded away, format afterwards)
        layout_swaps = ((NHWC, [0, 3, 1, 2], NCHW),
                        (NCHW, [0, 2, 3, 1], NHWC))
        if len(perm) == 4:
            for current_fmt, foldable_perm, new_fmt in layout_swaps:
                if node.data_format != current_fmt:
                    continue
                if perm != foldable_perm:
                    raise Exception("Unexpected situation in OP transpose")
                self.graph.set_data_format(node, new_fmt)
                node.code.add_str("{} = {}".format(node.output_name,
                                                   data.ref_name))
                return

        node.code.add_layer("transpose", data.ref_name, node.output_name,
                            {"perm": perm})

    def emit_randomuniform(self, node):
        """Emit a Paddle uniform random op for a TF ``RandomUniform`` node.

        The op is emitted with dtype ``'float32'``, min 0.00001 and max
        0.99999. A 4-D shape on an NHWC node is reordered to NCHW.

        * With no negative dimension, ``uniform_random`` is emitted.
        * With exactly one negative dimension,
          ``uniform_random_batch_size_like`` is emitted against
          ``self.batch_node``, with ``input_dim_idx`` set to that dimension.

        Raises:
            Exception: if more than one dimension is negative.
        """
        shape = self.infer.get_shape_tensor(node.inputs[0].layer)
        if node.data_format == NHWC and len(shape) == 4:
            shape = shape[[0, 3, 1, 2]]
        negative_dims = list(numpy.argwhere(shape < 0).flatten())
        if len(negative_dims) > 1:
            raise Exception("More than one dimension value less than zero")

        param_attr = {
            "shape": list(shape),
            "dtype": "\'float32\'",
            "min": 0.00001,
            "max": 0.99999
        }
        if negative_dims:
            param_attr["input_dim_idx"] = negative_dims[0]
            node.code.add_layer("uniform_random_batch_size_like",
                                self.batch_node.ref_name, node.output_name,
                                param_attr)
        else:
            node.code.add_layer("uniform_random", None, node.output_name,
                                param_attr)

    def emit_floor(self, node):
        """Emit a Paddle ``floor`` applied to the node's first input."""
        source_name = node.inputs[0].ref_name
        node.code.add_layer("floor", source_name, node.output_name)

    def emit_exp(self, node):
        """Emit a Paddle ``exp`` applied to the node's first input."""
        source_name = node.inputs[0].ref_name
        node.code.add_layer("exp", source_name, node.output_name)

    def emit_floordiv(self, node):
        """Emit Paddle code for a TF ``FloorDiv`` node.

        Emits the division via ``self.emit_div``, then casts the result to
        ``'float32'`` and applies ``floor`` in place.

        NOTE(review): the output dtype is float32 regardless of the input
        dtype; confirm downstream consumers expect that.
        """
        self.emit_div(node)
        out = node.output_name
        node.code.add_layer("cast", out, out, {"dtype": "\'float32\'"})
        node.code.add_layer("floor", out, out)

    def emit_div(self, node):
        """Emit a Paddle ``elementwise_div`` for a TF ``Div``-style node.

        Normally one ``elementwise_div`` is emitted using the broadcast axis
        from ``self.get_axis``. When both operands are 4-D, the node is NHWC
        and their last dimensions differ, the operands are transposed with
        perm [0, 2, 3, 1] into ``temp1``/``temp2`` and divided along axis -1.
        The result is then transposed back with perm [0, 3, 1, 2], mirroring
        ``emit_sub``. Without that step the output would be left in the
        transposed layout.
        """
        data1 = node.inputs[0]
        data2 = node.inputs[1]
        axis = self.get_axis(data1, data2)
        data1_shape = self.infer.get_tensor_shape(data1.layer)
        data2_shape = self.infer.get_tensor_shape(data2.layer)

        div_param = {"x": data1.ref_name, "y": data2.ref_name, "axis": axis}
        channel_broadcast = (len(data1_shape) == 4 and len(data2_shape) == 4
                             and node.data_format == NHWC
                             and data1_shape[-1] != data2_shape[-1])
        if channel_broadcast:
            perm = {"perm": [0, 2, 3, 1]}
            node.code.add_layer("transpose", data1.ref_name, "temp1", perm)
            node.code.add_layer("transpose", data2.ref_name, "temp2", perm)
            div_param["x"] = "temp1"
            div_param["y"] = "temp2"
            div_param["axis"] = -1

        node.code.add_layer("elementwise_div", None, node.output_name,
                            div_param)

        if channel_broadcast:
            # Undo the operand transpose so the output keeps the same layout
            # as the inputs had before this node.
            node.code.add_layer("transpose", node.output_name,
                                node.output_name, {"perm": [0, 3, 1, 2]})

    def emit_realdiv(self, node):
        """Emit Paddle code for a TF ``RealDiv`` node via ``self.emit_div``."""
        emitted = self.emit_div(node)
        return emitted

    def emit_slice(self, node):
        """Emit a Paddle ``crop`` for a TF ``Slice`` node.

        ``begin`` and ``size`` are resolved via ``self.infer.get_shape_tensor``,
        and their generated code is cleared. For a 4-D NHWC node, ``begin``,
        ``size`` and the input shape are reordered to NCHW. A negative size
        entry ("to the end of the dimension") is replaced by
        ``input_shape[i] - begin[i]``.

        NOTE(review): if ``input_shape[i]`` is itself unknown (negative), that
        replacement yields a meaningless size; confirm such inputs cannot
        reach this path.

        Raises:
            Exception: if ``begin``/``size`` do not match the input rank.
        """
        data = node.inputs[0]
        begin = node.inputs[1]
        size = node.inputs[2]
        begin.code.clear()
        size.code.clear()
        begin = list(self.infer.get_shape_tensor(begin.layer))
        size = list(self.infer.get_shape_tensor(size.layer))

        input_shape = list(self.infer.get_tensor_shape(data.layer))
        if len(begin) != len(input_shape) or len(size) != len(input_shape):
            raise Exception(
                "Slice begin/size rank mismatch: begin={}, size={}, input "
                "shape={}".format(begin, size, input_shape))

        if node.data_format == NHWC and len(input_shape) == 4:
            nchw_order = (0, 3, 1, 2)
            begin = [begin[i] for i in nchw_order]
            size = [size[i] for i in nchw_order]
            input_shape = [input_shape[i] for i in nchw_order]

        size = [input_shape[i] - begin[i] if s < 0 else s
                for i, s in enumerate(size)]

        param_attr = {"shape": size, "offsets": begin}
        node.code.add_layer("crop", data.ref_name, node.output_name,
                            param_attr)

    def emit_sum(self, node):
        """Emit a Paddle ``reduce_sum`` for a TF ``Sum`` node.

        The reduction indices come from the second input's constant ``value``
        attribute, and that input's generated code is cleared. ``keep_dims``
        is read from the node attribute. For a 4-D input on an NHWC node,
        each axis is mapped NHWC -> NCHW (0->0, 1->2, 2->3, 3->1). Negative
        axes are normalised first, so -1 (C) maps to 1 rather than being left
        pointing at W.
        """
        data = node.inputs[0]
        reduce_idx = node.inputs[1]
        reduce_idx.code.clear()
        idxs = tensor_util.MakeNdarray(
            reduce_idx.layer.attr['value'].tensor).astype('int32').flatten()
        data_shape_len = data.shape_dim_size
        keep_dims = node.layer.attr['keep_dims'].b

        if node.data_format == NHWC and data_shape_len == 4:
            nhwc_to_nchw = (0, 2, 3, 1)
            # Python's % lands in [0, 4), which normalises negative axes.
            dims = [nhwc_to_nchw[int(i) % 4] for i in idxs]
        else:
            dims = [int(i) for i in idxs]

        param = {"dim": dims, "keep_dim": keep_dims}
        node.code.add_layer("reduce_sum", data.ref_name, node.output_name,
                            param)

    def emit_max(self, node):
        """Emit a Paddle ``reduce_max`` for a TF ``Max`` node.

        The reduction indices come from the second input's constant ``value``
        attribute, and that input's generated code is cleared. ``keep_dims``
        is read from the node attribute. For a 4-D input on an NHWC node,
        each axis is mapped NHWC -> NCHW (0->0, 1->2, 2->3, 3->1). Negative
        axes are normalised first, so -1 (C) maps to 1 rather than being left
        pointing at W.
        """
        data = node.inputs[0]
        reduce_idx = node.inputs[1]
        reduce_idx.code.clear()
        idxs = tensor_util.MakeNdarray(
            reduce_idx.layer.attr['value'].tensor).astype('int32').flatten()
        data_shape_len = data.shape_dim_size
        keep_dims = node.layer.attr['keep_dims'].b

        if node.data_format == NHWC and data_shape_len == 4:
            nhwc_to_nchw = (0, 2, 3, 1)
            # Python's % lands in [0, 4), which normalises negative axes.
            dims = [nhwc_to_nchw[int(i) % 4] for i in idxs]
        else:
            dims = [int(i) for i in idxs]

        param = {"dim": dims, "keep_dim": keep_dims}
        node.code.add_layer("reduce_max", data.ref_name, node.output_name,
                            param)

    def emit_fill(self, node):
        """Emit a Paddle constant fill for a TF ``Fill`` node.

        The shape and value inputs must be constant; their generated code is
        cleared, and the value must hold exactly one element. A 4-D shape on
        an NHWC node is reordered to NCHW.

        * If the first dimension is negative, the fill is emitted as
          ``fill_constant_batch_size_like`` against ``self.batch_node``.
        * Otherwise it is emitted as ``fill_constant``.
        """
        shape_node = node.inputs[0]
        shape_node.code.clear()
        value_node = node.inputs[1]
        value_node.code.clear()

        shape = list(self.infer.get_shape_tensor(shape_node.layer))
        values = list(
            self.infer.get_const_tensor_value(value_node.layer).flatten())
        assert len(values) == 1
        fill_value = values[0]

        if node.data_format == NHWC and len(shape) == 4:
            shape = [shape[i] for i in (0, 3, 1, 2)]

        param = {
            "shape": shape,
            "dtype": "\'{}\'".format(fill_value.dtype),
            "value": fill_value
        }
        if shape[0] < 0:
            node.code.add_layer("fill_constant_batch_size_like",
                                self.batch_node.ref_name, node.output_name,
                                param)
        else:
            node.code.add_layer("fill_constant", None, node.output_name, param)

    def emit_range(self, node):
        start = node.inputs[0]
        end = node.inputs[1]
        delta = node.inputs[2]
        start.code.clear()
        end.code.clear()
        delta.code.clear()

        start = self.infer.get_const_tensor_value(start.layer)
        end = self.infer.get_const_tensor_value(end.layer)
        delta = self.infer.get_const_tensor_value(delta.layer)
        np_code = "np_array = numpy.arange({}, {}, {}).astype(\'{}\')".format(
            start, end, delta, delta.dtype)
        node.code.add_str(np_code)
        node.code.add_layer("assign", "np_array", node.output_name)

    def emit_tile(self, node):
        """Emit a Paddle ``expand`` for a TF ``Tile`` node.

        The multiples input must be constant, and its generated code is
        cleared. For a 4-D NHWC node the multiples are reordered from NHWC to
        NCHW (indices [0, 3, 1, 2]). This is the same reordering the emitter
        applies to 4-D shapes elsewhere, e.g. in ``emit_fill``.
        """
        data = node.inputs[0]
        expand_times = node.inputs[1]
        expand_times.code.clear()

        # Plain ints keep the generated code free of numpy scalar reprs.
        expand_times = [
            int(t)
            for t in self.infer.get_const_tensor_value(expand_times.layer)
        ]
        if node.data_format == NHWC and len(expand_times) == 4:
            expand_times = [expand_times[i] for i in (0, 3, 1, 2)]

        param = {"expand_times": expand_times}
        node.code.add_layer("expand", data.ref_name, node.output_name, param)

    def emit_splitv(self, node):
        """Emit a Paddle ``split`` for a TF ``SplitV`` node.

        ``size_splits`` and ``split_dim`` must be constant, and their
        generated code is cleared.

        * A negative ``split_dim`` is counted from the end of the input rank.
        * At most one size may be negative; it is replaced by whatever
          remains of the split dimension.
        * For a 4-D NHWC node the split axis is then mapped NHWC -> NCHW
          (0->0, 1->2, 2->3, 3->1), like the reduction-axis mapping used by
          ``emit_sum``.

        Raises:
            Exception: if more than one size is negative.
        """
        data = node.inputs[0]
        num_sections = node.inputs[1]
        num_sections.code.clear()
        split_dim = node.inputs[2]
        split_dim.code.clear()

        # Copy into plain ints so the array returned by ``self.infer`` is not
        # modified in place when the unknown size is filled in below.
        sections = [
            int(s) for s in numpy.asarray(
                self.infer.get_const_tensor_value(num_sections.layer)).flatten()
        ]
        split_dim = int(self.infer.get_const_tensor_value(split_dim.layer))
        input_shape = self.infer.get_tensor_shape(data.layer)
        if split_dim < 0:
            split_dim += len(input_shape)

        unknown = [i for i, s in enumerate(sections) if s < 0]
        if len(unknown) > 1:
            raise Exception("More than one dimension less than 0")
        if unknown:
            known_total = sum(s for s in sections if s >= 0)
            # input_shape is indexed with the TF (pre-mapping) axis.
            sections[unknown[0]] = int(input_shape[split_dim]) - known_total

        if node.data_format == NHWC and len(input_shape) == 4:
            split_dim = (0, 2, 3, 1)[split_dim]

        param = {"num_or_sections": sections, "dim": split_dim}
        node.code.add_layer("split", data.ref_name, node.output_name, param)

    def emit_expanddims(self, node):
        data = node.inputs[0]
        dim = node.inputs[1]
        dim.code.clear()
        dim = self.infer.get_const_tensor_value(dim.layer)
        param = {"axes":[dim]}
        node.code.add_layer("unsqueeze", data.ref_name, node.output_name, param)
J
jiangjiajun 已提交
1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077

    def emit_cast(self, node):
        """Emit a Paddle ``cast`` for a TF ``Cast`` node.

        The ``DstT`` attribute is a TF DataType enum value. It is mapped to
        the Paddle dtype name, which is emitted as a quoted string.

        Raises:
            Exception: if ``DstT`` has no Paddle equivalent in the map below.
        """
        # TF DataType enum value -> Paddle dtype name.
        dtype_map = {
            1: "float32",  # DT_FLOAT
            2: "float64",  # DT_DOUBLE
            3: "int32",  # DT_INT32
            4: "uint8",  # DT_UINT8
            6: "int8",  # DT_INT8
            9: "int64",  # DT_INT64
            10: "bool",  # DT_BOOL
            19: "float16",  # DT_HALF
        }
        data = node.inputs[0]
        dtype = node.get_attr("DstT")
        if dtype not in dtype_map:
            raise Exception("Unknown dtype: {}".format(dtype))
        param = {"dtype": "\'{}\'".format(dtype_map[dtype])}
        node.code.add_layer("cast", data.ref_name, node.output_name, param)