# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
from enum import Enum

import numpy as np
import six
import tensorflow as tf
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.tools.graph_transforms import TransformGraph

from python.py_proto import mace_pb2
from . import base_converter
from .base_converter import PoolingType
from .base_converter import PaddingMode
from .base_converter import ActivationType
from .base_converter import EltwiseType
from .base_converter import PadType
from .base_converter import FrameworkType
from .base_converter import ReduceType
from .base_converter import DataFormat
from .base_converter import MaceOp
from .base_converter import MaceKeyword
from .base_converter import ConverterUtil
from python.utils.util import mace_check

# Names of the TensorFlow NodeDef attributes read by the converter.
tf_padding_str = 'padding'
tf_strides_str = 'strides'
tf_dilations_str = 'dilations'
tf_data_format_str = 'data_format'
tf_kernel_str = 'ksize'
tf_epsilon_str = 'epsilon'
tf_alpha_str = 'alpha'
tf_is_training_str = 'is_training'
tf_align_corners = 'align_corners'
tf_block_size = 'block_size'
tf_squeeze_dims = 'squeeze_dims'
tf_axis = 'axis'

# Every TensorFlow op type the converter accepts, in TFOpType member order.
TFSupportedOps = [
    'Conv2D', 'DepthwiseConv2dNative', 'Conv2DBackpropInput', 'BiasAdd',
    'Add', 'Sub', 'Mul', 'Div', 'Min', 'Minimum', 'Max', 'Maximum',
    'Neg', 'Abs', 'Pow', 'RealDiv', 'Square', 'SquaredDifference',
    'Rsqrt', 'Sum', 'Equal', 'Relu', 'LeakyRelu', 'Relu6', 'Tanh',
    'Sigmoid', 'Fill', 'FusedBatchNorm', 'AvgPool', 'MaxPool',
    'ExpandDims', 'Squeeze', 'MatMul', 'BatchMatMul', 'Identity',
    'Reshape', 'Shape', 'Transpose', 'Softmax', 'ResizeBicubic',
    'ResizeBilinear', 'ResizeNearestNeighbor', 'Placeholder',
    'SpaceToBatchND', 'BatchToSpaceND', 'DepthToSpace', 'SpaceToDepth',
    'Pad', 'ConcatV2', 'Mean', 'Prod', 'Const', 'Gather', 'GatherV2',
    'StridedSlice', 'Slice', 'ReverseV2', 'Stack', 'Pack', 'Unstack',
    'Unpack', 'Cast', 'ArgMax', 'Split', 'FakeQuantWithMinMaxVars',
    'FakeQuantWithMinMaxArgs', 'FloorDiv', 'Sqrt', 'MirrorPad', 'Cumsum',
    'OneHot',
]

# Functional Enum API: one str-mixin member per op whose name equals its
# value, so `TFOpType.X.name` is the TF op type string 'X'.
TFOpType = Enum('TFOpType', list(zip(TFSupportedOps, TFSupportedOps)),
                type=str)

# NOTE: from here on TFSupportedOps holds byte strings (via six.b).
TFSupportedOps = list(map(six.b, TFSupportedOps))

# Passes handed to TransformGraph before conversion.
TFTransformGraphOptions = [
    'strip_unused_nodes',
    'remove_nodes(op=Identity, op=CheckNumerics)',
    'fold_constants(ignore_errors=true)',
    'fold_batch_norms',
    'fold_old_batch_norms',
    'remove_control_dependencies',
    # Second pruning pass, after the folding/cleanup passes above.
    'strip_unused_nodes',
    'sort_by_execution_order',
]


class TensorflowConverter(base_converter.ConverterInterface):
    """A class for convert tensorflow frozen model to mace model.
    We use tensorflow engine to infer op output shapes, since they are of
    too many types."""

    padding_mode = {
        'VALID': PaddingMode.VALID,
        'SAME': PaddingMode.SAME,
        'FULL': PaddingMode.FULL
    }
156 157
    padding_mode = {six.b(k): v for k, v in six.iteritems(padding_mode)}

158
    pooling_type_mode = {
159 160
        TFOpType.AvgPool.name: PoolingType.AVG,
        TFOpType.MaxPool.name: PoolingType.MAX
161
    }
162

163
    eltwise_type = {
164 165 166 167
        TFOpType.Add.name: EltwiseType.SUM,
        TFOpType.Sub.name: EltwiseType.SUB,
        TFOpType.Mul.name: EltwiseType.PROD,
        TFOpType.Div.name: EltwiseType.DIV,
L
liutuo 已提交
168
        TFOpType.Minimum.name: EltwiseType.MIN,
L
liutuo 已提交
169
        TFOpType.Maximum.name: EltwiseType.MAX,
170 171
        TFOpType.Neg.name: EltwiseType.NEG,
        TFOpType.Abs.name: EltwiseType.ABS,
L
liutuo 已提交
172
        TFOpType.Pow.name: EltwiseType.POW,
173
        TFOpType.RealDiv.name: EltwiseType.DIV,
W
w-adamski 已提交
174
        TFOpType.FloorDiv.name: EltwiseType.FLOOR_DIV,
175
        TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF,
176
        TFOpType.Square.name: EltwiseType.POW,
李寅 已提交
177
        TFOpType.Rsqrt.name: EltwiseType.POW,
178
        TFOpType.Sqrt.name: EltwiseType.POW,
李寅 已提交
179
        TFOpType.Equal.name: EltwiseType.EQUAL,
180
    }
181

182
    activation_type = {
183 184 185
        TFOpType.Relu.name: ActivationType.RELU,
        TFOpType.Relu6.name: ActivationType.RELUX,
        TFOpType.Tanh.name: ActivationType.TANH,
186 187
        TFOpType.Sigmoid.name: ActivationType.SIGMOID,
        TFOpType.LeakyRelu.name: ActivationType.LEAKYRELU,
188 189
    }

190 191 192 193 194
    reduce_math_type = {
        TFOpType.Min.name: ReduceType.MIN,
        TFOpType.Max.name: ReduceType.MAX,
        TFOpType.Mean.name: ReduceType.MEAN,
        TFOpType.Prod.name: ReduceType.PROD,
L
liutuo 已提交
195
        TFOpType.Sum.name: ReduceType.SUM,
196 197
    }

198
    pad_type = {
W
Wiktor Adamski 已提交
199 200
        'CONSTANT':  PadType.CONSTANT,
        'REFLECT':   PadType.REFLECT,
201 202 203
        'SYMMETRIC': PadType.SYMMETRIC
    }

204 205
    def __init__(self, option, src_model_file):
        """Load, optimise and import the TensorFlow graph to be converted.

        Args:
            option: converter options. ``input_nodes`` / ``output_nodes``
                are dict-like and keyed by node name; ``data_type`` is the
                MACE data type used for float ops.
            src_model_file: path of the binary frozen ``GraphDef`` file.
        """
        # Converter per TF op type, grouped by the converter that handles it.
        grouped_converters = [
            (self.convert_conv2d, [TFOpType.Conv2D,
                                   TFOpType.DepthwiseConv2dNative,
                                   TFOpType.Conv2DBackpropInput]),
            (self.convert_biasadd, [TFOpType.BiasAdd]),
            (self.convert_add, [TFOpType.Add]),
            (self.convert_elementwise, [
                TFOpType.Sub, TFOpType.Mul, TFOpType.Div, TFOpType.Minimum,
                TFOpType.Maximum, TFOpType.Neg, TFOpType.Abs, TFOpType.Pow,
                TFOpType.RealDiv, TFOpType.SquaredDifference,
                TFOpType.Square, TFOpType.Rsqrt, TFOpType.Equal,
                TFOpType.FloorDiv, TFOpType.Sqrt]),
            (self.convert_reduce, [TFOpType.Min, TFOpType.Max, TFOpType.Mean,
                                   TFOpType.Prod, TFOpType.Sum]),
            (self.convert_activation, [TFOpType.Relu, TFOpType.LeakyRelu,
                                       TFOpType.Relu6, TFOpType.Tanh,
                                       TFOpType.Sigmoid]),
            (self.convert_fill, [TFOpType.Fill]),
            (self.convert_fused_batchnorm, [TFOpType.FusedBatchNorm]),
            (self.convert_pooling, [TFOpType.AvgPool, TFOpType.MaxPool]),
            (self.convert_matmul, [TFOpType.MatMul, TFOpType.BatchMatMul]),
            (self.convert_identity, [TFOpType.Identity]),
            (self.convert_reshape, [TFOpType.Reshape]),
            (self.convert_shape, [TFOpType.Shape]),
            (self.convert_expand_dims, [TFOpType.ExpandDims]),
            (self.convert_squeeze, [TFOpType.Squeeze]),
            (self.convert_transpose, [TFOpType.Transpose]),
            (self.convert_softmax, [TFOpType.Softmax]),
            (self.convert_resize_bicubic, [TFOpType.ResizeBicubic]),
            (self.convert_resize_bilinear, [TFOpType.ResizeBilinear]),
            (self.convert_resize_nearest_neighbor,
             [TFOpType.ResizeNearestNeighbor]),
            (self.convert_nop, [TFOpType.Placeholder, TFOpType.Const]),
            (self.convert_space_batch, [TFOpType.SpaceToBatchND,
                                        TFOpType.BatchToSpaceND]),
            (self.convert_space_depth, [TFOpType.DepthToSpace,
                                        TFOpType.SpaceToDepth]),
            (self.convert_pad, [TFOpType.Pad, TFOpType.MirrorPad]),
            (self.convert_concat, [TFOpType.ConcatV2]),
            (self.convert_gather, [TFOpType.Gather, TFOpType.GatherV2]),
            (self.convert_stridedslice, [TFOpType.StridedSlice]),
            (self.convert_slice, [TFOpType.Slice]),
            (self.convert_reverse, [TFOpType.ReverseV2]),
            (self.convert_stack, [TFOpType.Pack, TFOpType.Stack]),
            (self.convert_unstack, [TFOpType.Unpack, TFOpType.Unstack]),
            (self.convert_cast, [TFOpType.Cast]),
            (self.convert_argmax, [TFOpType.ArgMax]),
            (self.convert_split, [TFOpType.Split]),
            (self.convert_fake_quantize, [TFOpType.FakeQuantWithMinMaxVars,
                                          TFOpType.FakeQuantWithMinMaxArgs]),
            (self.convert_cumsum, [TFOpType.Cumsum]),
            (self.convert_one_hot, [TFOpType.OneHot]),
        ]
        self._op_converters = {
            op_type.name: converter
            for converter, op_types in grouped_converters
            for op_type in op_types
        }

        self._option = option
        self._mace_net_def = mace_pb2.NetDef()
        ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO)
        ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC)

        # Read the frozen graph from disk.
        tf_graph_def = tf.GraphDef()
        with tf.gfile.Open(src_model_file, 'rb') as model_file:
            tf_graph_def.ParseFromString(model_file.read())

        self._placeholders = {}
        self._skip_tensor = set()
        self._output_shape = {}

        print("Run transform_graph: %s" % TFTransformGraphOptions)
        try:
            print("output keys: ", option.output_nodes.keys())
            transformed_graph_def = TransformGraph(
                tf_graph_def,
                option.input_nodes.keys(),
                option.output_nodes.keys(),
                TFTransformGraphOptions)
        except Exception as ex:
            # Best effort: fall back to the untransformed graph.
            print("Failed to transform graph using tf tool: %s" % ex)
            transformed_graph_def = tf_graph_def

        # To inspect the optimised graph, dump it with e.g.:
        # tf.io.write_graph(transformed_graph_def, ".",
        #                   os.path.basename(src_model_file)[:-3] + "_opt.pb",
        #                   as_text=False)

        self.add_shape_info(transformed_graph_def)

        # First import: run the graph once to record every output shape.
        with tf.Session() as session:
            with session.graph.as_default() as graph:
                tf.import_graph_def(transformed_graph_def, name='')
                self._tf_graph = graph
                self.update_output_shapes(session)

        # update_output_shapes() added tf.shape ops to that graph, so reset
        # the default graph and import the model again into a clean one.
        tf.reset_default_graph()
        with tf.Session() as session:
            with session.graph.as_default() as graph:
                tf.import_graph_def(transformed_graph_def, name='')
                self._tf_graph = graph

    def run(self):
        """Convert every imported op and return the resulting MACE NetDef.

        Conversion runs inside a fresh ``tf.Session`` (needed by the
        ``eval()`` calls in the converters); afterwards the ``:0`` suffix is
        stripped from graph input/output tensor names.
        """
        with tf.Session() as sess:
            self.convert_ops(sess)
        self.replace_input_output_tensor_name()
        return self._mace_net_def

    def replace_input_output_tensor_name(self):
        """Drop the ``:0`` suffix from tensors naming graph inputs/outputs.

        An op input is renamed when its bare name is a configured input or
        output node; an op output only when it is a configured output node.
        """
        input_nodes = self._option.input_nodes
        output_nodes = self._option.output_nodes
        for mace_op in self._mace_net_def.op:
            for idx, name in enumerate(list(mace_op.input)):
                bare = name[:-2]
                if name.endswith(':0') and (bare in input_nodes
                                            or bare in output_nodes):
                    mace_op.input[idx] = bare
            for idx, name in enumerate(list(mace_op.output)):
                bare = name[:-2]
                if name.endswith(':0') and bare in output_nodes:
                    mace_op.output[idx] = bare

    def add_shape_info(self, tf_graph_def):
        """Pin the configured input shapes onto the matching graph nodes.

        For each node whose name (with or without ``:0``) equals a configured
        input node's name, the node's ``shape`` attr is overwritten with that
        input's shape, and a zero-filled feed array of the same shape is
        registered in ``self._placeholders``. 4-D shapes of OIHW inputs are
        reordered to HWIO first.
        """
        for node in tf_graph_def.node:
            for input_node in self._option.input_nodes.values():
                if input_node.name not in (node.name, node.name + ':0'):
                    continue
                shape = input_node.shape
                if (input_node.data_format == DataFormat.OIHW
                        and len(shape) == 4):
                    # OIHW -> HWIO
                    shape = [shape[2], shape[3], shape[1], shape[0]]
                dims = node.attr['shape'].shape.dim
                del dims[:]
                dims.extend([tensor_shape_pb2.TensorShapeProto.Dim(size=d)
                             for d in shape])
                self._placeholders[node.name + ':0'] = np.zeros(
                    shape=shape, dtype=float)

    @staticmethod
    def get_scope(tensor_name):
        """Return everything before the last ``/`` in *tensor_name*.

        Names without any ``/`` are returned unchanged.
        """
        scope, separator, _ = tensor_name.rpartition('/')
        return scope if separator else tensor_name

    def update_output_shapes(self, sess):
        """Record the runtime shape of every op output in the graph.

        One ``tf.shape`` op is created per output tensor (this mutates the
        graph), and all of them are evaluated in a single ``sess.run`` fed
        with ``self._placeholders``. Results go into ``self._output_shape``,
        keyed by tensor name.
        """
        names = []
        shape_ops = []
        for tf_op in self._tf_graph.get_operations():
            for tensor in tf_op.outputs:
                names.append(tensor.name)
                shape_ops.append(tf.shape(tensor))

        shapes = sess.run(shape_ops, feed_dict=self._placeholders)
        self._output_shape.update(zip(names, shapes))

    def convert_ops(self, sess):
        """Dispatch every graph op to its converter, then emit constants.

        ``sess`` is part of the interface but is not used in this method.
        """
        for tf_op in self._tf_graph.get_operations():
            op_type = tf_op.type
            mace_check(op_type in self._op_converters,
                       "Mace does not support tensorflow op type %s yet"
                       % op_type)
            converter = self._op_converters[op_type]
            converter(tf_op)

        self.convert_tensors()

    def convert_tensors(self):
        """Emit each Const op's value as a MACE tensor.

        Constants listed in ``self._skip_tensor`` (already folded into op
        arguments) are omitted. Only float32 and int32 constants are
        supported; any other dtype fails ``mace_check``.
        """
        for tf_op in self._tf_graph.get_operations():
            if tf_op.type != TFOpType.Const.name:
                continue
            const_out = tf_op.outputs[0]
            if const_out.name in self._skip_tensor:
                continue

            tensor = self._mace_net_def.tensors.add()
            tensor.name = const_out.name
            value = const_out.eval()
            tensor.dims.extend(list(value.shape))

            tf_dt = tf_op.get_attr('dtype')
            if tf_dt == tf.float32:
                tensor.data_type = mace_pb2.DT_FLOAT
                tensor.float_data.extend(value.astype(np.float32).flat)
            elif tf_dt == tf.int32:
                tensor.data_type = mace_pb2.DT_INT32
                tensor.int32_data.extend(value.astype(np.int32).flat)
            else:
                mace_check(False,
                           "Not supported tensor type: %s" % tf_dt.name)

    def add_tensor(self, name, shape, data_type, value):
        """Append a tensor named *name* to the NetDef.

        The elements of ``value`` (an array exposing ``.flat``) are always
        written to ``float_data``, whatever ``data_type`` is.
        """
        new_tensor = self._mace_net_def.tensors.add()
        new_tensor.name = name
        new_tensor.data_type = data_type
        new_tensor.dims.extend(list(shape))
        new_tensor.float_data.extend(value.flat)

    # Best-effort shape inference: some dimensions may be undefined because
    # the input length can vary.
    def infer_tensor_shape(self, tensor, output_shape=None):
        """Return the shape of ``tensor``, optionally recording it.

        The concrete shape captured by ``update_output_shapes`` is preferred.
        If none was captured, the static graph shape (``as_list()``, which
        may contain ``None`` for unknown dimensions) is used.

        Args:
            tensor: a ``tf.Tensor``.
            output_shape: optional output-shape message (as created by
                ``op.output_shape.add()``) whose ``dims`` receive the shape.

        Returns:
            The shape as a sequence of dimension sizes.
        """
        if tensor.name in self._output_shape:
            shape = self._output_shape[tensor.name]
        else:
            shape = tensor.shape.as_list()

        # Explicit None check: the message passed in is freshly created and
        # still empty, so it must not be tested for truthiness.
        if output_shape is not None:
            output_shape.dims.extend(shape)

        return shape

    def convert_nop(self, tf_op):
        """Emit nothing for ``tf_op`` (used for Placeholder and Const)."""
        return None

    def convert_general_op(self, tf_op):
        """Create a MACE op mirroring ``tf_op`` and return it.

        The new op copies the TF name, type, input/output tensor names and
        output shapes. It gets three standard arguments: ``T`` (data type),
        the framework type (TENSORFLOW) and the NHWC data format. Callers
        then set the MACE ``type`` and add op-specific arguments.
        """
        mace_op = self._mace_net_def.op.add()
        mace_op.name = tf_op.name
        mace_op.type = tf_op.type
        mace_op.input.extend([t.name for t in tf_op.inputs])
        mace_op.output.extend([t.name for t in tf_op.outputs])
        for out_tensor in tf_op.outputs:
            self.infer_tensor_shape(out_tensor, mace_op.output_shape.add())

        # 'T': int32 stays int32, float32 follows the configured data type.
        # Without a 'T' attr, fall back to 'SrcT' (bool also maps to int32),
        # and failing that to the configured data type.
        dtype_arg = mace_op.arg.add()
        dtype_arg.name = 'T'
        try:
            t_dtype = tf_op.get_attr('T')
            if t_dtype == tf.int32:
                dtype_arg.i = mace_pb2.DT_INT32
            elif t_dtype == tf.float32:
                dtype_arg.i = self._option.data_type
            else:
                mace_check(False, "data type %s not supported" % t_dtype)
        except ValueError:
            try:
                src_dtype = tf_op.get_attr('SrcT')
                if src_dtype == tf.int32 or src_dtype == tf.bool:
                    dtype_arg.i = mace_pb2.DT_INT32
                elif src_dtype == tf.float32:
                    dtype_arg.i = self._option.data_type
                else:
                    mace_check(False,
                               "data type %s not supported" % src_dtype)
            except ValueError:
                dtype_arg.i = self._option.data_type

        framework_arg = mace_op.arg.add()
        framework_arg.name = MaceKeyword.mace_framework_type_str
        framework_arg.i = FrameworkType.TENSORFLOW.value

        ConverterUtil.add_data_format_arg(mace_op, DataFormat.NHWC)
        return mace_op

    def convert_identity(self, tf_op):
        """Emit ``tf_op`` as a MACE Identity op."""
        self.convert_general_op(tf_op).type = 'Identity'

    def convert_conv2d(self, tf_op):
        """Convert Conv2D, DepthwiseConv2dNative and Conv2DBackpropInput.

        Padding mode and strides[1:3] become op arguments. Regular and
        depthwise convolutions also get dilations[1:3] (default [1, 1]).
        Conv2DBackpropInput becomes Deconv2D. It requires unit dilation and
        at least three inputs, and its inputs are passed in the order
        inputs[2], inputs[1], inputs[0].
        """
        op = self.convert_general_op(tf_op)
        mace_type_for = {
            TFOpType.DepthwiseConv2dNative.name: MaceOp.DepthwiseConv2d.name,
            TFOpType.Conv2DBackpropInput.name: MaceOp.Deconv2D.name,
        }
        op.type = mace_type_for.get(tf_op.type, MaceOp.Conv2D.name)

        padding_arg = op.arg.add()
        padding_arg.name = MaceKeyword.mace_padding_str
        padding_arg.i = self.padding_mode[tf_op.get_attr(tf_padding_str)].value
        strides_arg = op.arg.add()
        strides_arg.name = MaceKeyword.mace_strides_str
        strides_arg.ints.extend(tf_op.get_attr(tf_strides_str)[1:3])

        try:
            dilations = tf_op.get_attr(tf_dilations_str)[1:3]
        except ValueError:
            # Attribute absent: treat as no dilation.
            dilations = [1, 1]

        if op.type != MaceOp.Deconv2D.name:
            dilation_arg = op.arg.add()
            dilation_arg.name = MaceKeyword.mace_dilations_str
            dilation_arg.ints.extend(dilations)
            return

        mace_check(dilations[0] == 1 and dilations[1] == 1,
                   "Mace only supports dilation == 1 conv2d_transpose.")
        mace_check(len(tf_op.inputs) >= 3,
                   "deconv should have (>=) 3 inputs.")
        del op.input[:]
        op.input.extend([tf_op.inputs[idx].name for idx in (2, 1, 0)])

    def convert_elementwise(self, tf_op):
        """Convert a TF element-wise op to a MACE Eltwise or ScalarMath op.

        * The op becomes ScalarMath when every one of its (one or two)
          inputs has rank 0, and Eltwise otherwise.
        * Square, Rsqrt and Sqrt (all mapped to POW) get their exponent
          (2.0, -0.5, 0.5) as the scalar-input argument.
        * For every type except NEG and ABS, a rank-0 Const operand is folded
          into the scalar-input / scalar-input-index arguments. It is then
          removed from the op inputs and excluded from the emitted tensors.
          The second operand is always eligible; the first only for
          commutative types (SUM, PROD, MAX, MIN).
        """
        op = self.convert_general_op(tf_op)

        type_arg = op.arg.add()
        type_arg.name = MaceKeyword.mace_element_type_str
        type_arg.i = self.eltwise_type[tf_op.type].value

        def is_scalar(tensor):
            return len(self.infer_tensor_shape(tensor)) == 0

        def is_scalar_const(tensor):
            return (is_scalar(tensor) and
                    tensor.op.type == TFOpType.Const.name)

        num_inputs = len(tf_op.inputs)
        if num_inputs in (1, 2) and all(is_scalar(t) for t in tf_op.inputs):
            op.type = MaceOp.ScalarMath.name
        else:
            op.type = MaceOp.Eltwise.name

        # Unary power ops carry their fixed exponent as the scalar input.
        # Compare by `.name` like every other op-type check in this class.
        unary_exponents = {
            TFOpType.Square.name: 2.0,
            TFOpType.Rsqrt.name: -0.5,
            TFOpType.Sqrt.name: 0.5,
        }
        if tf_op.type in unary_exponents:
            exponent_arg = op.arg.add()
            exponent_arg.name = MaceKeyword.mace_scalar_input_str
            exponent_arg.f = unary_exponents[tf_op.type]

        if type_arg.i in (EltwiseType.NEG.value, EltwiseType.ABS.value):
            return

        def is_commutative(eltwise_type):
            return EltwiseType(eltwise_type) in [
                EltwiseType.SUM, EltwiseType.PROD,
                EltwiseType.MAX, EltwiseType.MIN]

        def fold_scalar_operand(index):
            # Replace inputs[index] (a rank-0 Const) by scalar arguments.
            const_tensor = tf_op.inputs[index]
            scalar = const_tensor.eval().astype(np.float32)
            value_arg = op.arg.add()
            value_arg.name = MaceKeyword.mace_scalar_input_str
            value_arg.f = scalar
            index_arg = op.arg.add()
            index_arg.name = MaceKeyword.mace_scalar_input_index_str
            index_arg.i = index
            self._skip_tensor.add(const_tensor.name)
            del op.input[index]

        try:
            if num_inputs > 1 and is_scalar_const(tf_op.inputs[1]):
                fold_scalar_operand(1)
            elif (is_scalar_const(tf_op.inputs[0]) and
                    is_commutative(type_arg.i)):
                fold_scalar_operand(0)
        except tf.errors.InvalidArgumentError:
            # Evaluating the constant failed; it stays a regular input.
            pass

    def convert_biasadd(self, tf_op):
        """Emit ``tf_op`` as a MACE BiasAdd op."""
        mace_op = self.convert_general_op(tf_op)
        mace_op.type = MaceOp.BiasAdd.name

    def convert_one_hot(self, tf_op):
        """Convert OneHot, folding its constant inputs into arguments.

        inputs[1] becomes the int ``depth`` argument; inputs[2] and
        inputs[3] become the float ``on_value`` / ``off_value`` arguments.
        ``axis`` is copied from the TF attribute. Every input after the
        first is removed from the op and skipped as a tensor.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.OneHot.name

        depth_arg = op.arg.add()
        depth_arg.name = 'depth'
        depth_arg.i = tf_op.inputs[1].eval().astype(np.int32)

        for arg_name, input_index in (('on_value', 2), ('off_value', 3)):
            value_arg = op.arg.add()
            value_arg.name = arg_name
            value_arg.f = tf_op.inputs[input_index].eval().astype(np.float32)

        axis_arg = op.arg.add()
        axis_arg.name = tf_axis
        axis_arg.i = tf_op.get_attr(tf_axis)

        folded_names = [t.name for t in tf_op.inputs][1:]
        self._skip_tensor.update(folded_names)
        del op.input[1:]

    def convert_add(self, tf_op):
        """Convert Add.

        Two-input adds go through ``convert_elementwise``; any other input
        count becomes a MACE AddN op.
        """
        if len(tf_op.inputs) != 2:
            self.convert_general_op(tf_op).type = MaceOp.AddN.name
            return
        self.convert_elementwise(tf_op)

    def convert_activation(self, tf_op):
        """Convert Relu/Relu6/LeakyRelu/Tanh/Sigmoid to a MACE Activation.

        Relu6 gets a max-limit argument of 6.0. LeakyRelu gets its ``alpha``
        attribute as the leaky-relu coefficient.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Activation.name

        activation_arg = op.arg.add()
        activation_arg.name = MaceKeyword.mace_activation_type_str
        activation_name = self.activation_type[tf_op.type].name
        activation_arg.s = six.b(activation_name)

        if tf_op.type == TFOpType.LeakyRelu.name:
            coefficient_arg = op.arg.add()
            coefficient_arg.name = \
                MaceKeyword.mace_activation_leakyrelu_coefficient_str
            coefficient_arg.f = tf_op.get_attr(tf_alpha_str)
        elif tf_op.type == TFOpType.Relu6.name:
            max_limit_arg = op.arg.add()
            max_limit_arg.name = MaceKeyword.mace_activation_max_limit_str
            max_limit_arg.f = 6.0

    def convert_fill(self, tf_op):
        """Emit ``tf_op`` as a MACE Fill op."""
        self.convert_general_op(tf_op).type = MaceOp.Fill.name

    def convert_fused_batchnorm(self, tf_op):
        """Convert an inference-mode FusedBatchNorm to a MACE BatchNorm.

        inputs[1..4] (gamma, beta, mean, variance) are evaluated and folded
        into two new tensors, ``<scope>/scale:0`` and ``<scope>/offset:0``::

            scale  = gamma / sqrt(variance + epsilon)
            offset = beta - mean * scale

        The op keeps inputs[0] plus those two tensors, and only its first
        output. Fails through ``mace_check`` if ``is_training`` is set.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.BatchNorm.name

        is_training = tf_op.get_attr(tf_is_training_str)
        # mace_check instead of `assert`: asserts are stripped under -O.
        mace_check(is_training is False,
                   'Only support batch normalization with is_training False, '
                   'but got %s' % is_training)

        gamma_value = tf_op.inputs[1].eval().astype(np.float32)
        beta_value = tf_op.inputs[2].eval().astype(np.float32)
        mean_value = tf_op.inputs[3].eval().astype(np.float32)
        var_value = tf_op.inputs[4].eval().astype(np.float32)
        epsilon_value = tf_op.get_attr(tf_epsilon_str)

        scope = self.get_scope(tf_op.name)
        scale_name = scope + '/scale:0'
        offset_name = scope + '/offset:0'
        scale_value = (
            (1.0 / np.vectorize(math.sqrt)(var_value + epsilon_value))
            * gamma_value)
        offset_value = (-mean_value * scale_value) + beta_value
        self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
                        scale_value)
        self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT,
                        offset_value)
        self._skip_tensor.update([inp.name for inp in tf_op.inputs][1:])

        del op.input[1:]
        op.input.extend([scale_name, offset_name])
        del op.output[1:]
        del op.output_shape[1:]

    def convert_pooling(self, tf_op):
        """Convert AvgPool/MaxPool to a MACE Pooling op.

        Pooling type, padding mode, strides[1:3] and ksize[1:3] become op
        arguments (presumably the H and W entries of NHWC-ordered attrs).
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Pooling.name

        pooling_type_arg = op.arg.add()
        pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
        pooling_type_arg.i = self.pooling_type_mode[tf_op.type].value

        padding_arg = op.arg.add()
        padding_arg.name = MaceKeyword.mace_padding_str
        padding_arg.i = self.padding_mode[tf_op.get_attr(tf_padding_str)].value

        window_attrs = (
            (MaceKeyword.mace_strides_str, tf_strides_str),
            (MaceKeyword.mace_kernel_str, tf_kernel_str),
        )
        for mace_key, tf_key in window_attrs:
            window_arg = op.arg.add()
            window_arg.name = mace_key
            window_arg.ints.extend(tf_op.get_attr(tf_key)[1:3])

    def convert_softmax(self, tf_op):
        """Emit ``tf_op`` as a MACE Softmax op."""
        self.convert_general_op(tf_op).type = MaceOp.Softmax.name

    def convert_resize_bicubic(self, tf_op):
        """Convert ResizeBicubic.

        The size input (inputs[1]) is evaluated into the resize-size
        argument and dropped from the op; ``align_corners`` is copied over.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.ResizeBicubic.name
        del op.input[1:]

        size_tensor = tf_op.inputs[1]
        size_arg = op.arg.add()
        size_arg.name = MaceKeyword.mace_resize_size_str
        size_arg.ints.extend(size_tensor.eval().astype(np.int32))
        self._skip_tensor.add(size_tensor.name)

        align_corners_arg = op.arg.add()
        align_corners_arg.name = MaceKeyword.mace_align_corners_str
        align_corners_arg.i = tf_op.get_attr(tf_align_corners)

    def convert_resize_bilinear(self, tf_op):
        """Convert ResizeBilinear.

        The size input (inputs[1]) is evaluated into the resize-size
        argument and dropped from the op; ``align_corners`` is copied over.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.ResizeBilinear.name
        del op.input[1:]

        size_tensor = tf_op.inputs[1]
        size_arg = op.arg.add()
        size_arg.name = MaceKeyword.mace_resize_size_str
        size_arg.ints.extend(size_tensor.eval().astype(np.int32))
        self._skip_tensor.add(size_tensor.name)

        align_corners_arg = op.arg.add()
        align_corners_arg.name = MaceKeyword.mace_align_corners_str
        align_corners_arg.i = tf_op.get_attr(tf_align_corners)

    def convert_resize_nearest_neighbor(self, tf_op):
        """Convert TF ResizeNearestNeighbor.

        Unlike the bicubic/bilinear converters, no inputs are removed here;
        only the ``align_corners`` flag is attached as an argument.
        """
        resize_op = self.convert_general_op(tf_op)
        resize_op.type = MaceOp.ResizeNearestNeighbor.name

        corners_arg = resize_op.arg.add()
        corners_arg.name = MaceKeyword.mace_align_corners_str
        corners_arg.i = tf_op.get_attr(tf_align_corners)

    def convert_space_batch(self, tf_op):
        """Convert SpaceToBatchND / BatchToSpaceND.

        ``inputs[1]`` (block shape) and ``inputs[2]`` (crops for
        BatchToSpaceND, paddings otherwise) are constant-folded into
        arguments, removed from the op inputs and skipped as tensors.
        """
        op = self.convert_general_op(tf_op)
        del op.input[1:]

        block_shape_arg = op.arg.add()
        block_shape_arg.name = MaceKeyword.mace_space_batch_block_shape_str
        block_shape_arg.ints.extend(
            tf_op.inputs[1].eval().astype(np.int32))

        pad_or_crop_arg = op.arg.add()
        # op.type presumably still holds the TF type string here (set by
        # convert_general_op); it selects the direction and the arg name.
        if op.type != TFOpType.BatchToSpaceND.name:
            op.type = MaceOp.SpaceToBatchND.name
            pad_or_crop_arg.name = MaceKeyword.mace_paddings_str
        else:
            op.type = MaceOp.BatchToSpaceND.name
            pad_or_crop_arg.name = MaceKeyword.mace_batch_to_space_crops_str
        pad_or_crop_arg.ints.extend(
            tf_op.inputs[2].eval().astype(np.int32).flatten())

        for folded in tf_op.inputs[1:3]:
            self._skip_tensor.add(folded.name)
761 762 763

    def convert_space_depth(self, tf_op):
        """Convert SpaceToDepth / DepthToSpace, carrying ``block_size``."""
        op = self.convert_general_op(tf_op)
        is_space_to_depth = op.type == TFOpType.SpaceToDepth.name
        op.type = (MaceOp.SpaceToDepth.name if is_space_to_depth
                   else MaceOp.DepthToSpace.name)

        block_arg = op.arg.add()
        block_arg.name = MaceKeyword.mace_space_depth_block_size_str
        block_arg.i = tf_op.get_attr(tf_block_size)

    def convert_pad(self, tf_op):
        """Convert a TF Pad / MirrorPad op into a MACE Pad op.

        ``inputs[1]`` (paddings) is folded into the ``paddings`` argument.
        For ``Pad`` with a third input, that value is folded into the
        ``constant_value`` argument and the pad type is CONSTANT; for
        ``MirrorPad`` the pad type is mapped from the ``mode`` attribute.
        Folded inputs are removed from the op and skipped as tensors.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Pad.name
        del op.input[1:]

        paddings_arg = op.arg.add()
        paddings_arg.name = MaceKeyword.mace_paddings_str
        paddings_value = tf_op.inputs[1].eval().astype(np.int32).flat
        paddings_arg.ints.extend(paddings_value)
        self._skip_tensor.add(tf_op.inputs[1].name)

        pad_type_arg = op.arg.add()
        pad_type_arg.name = MaceKeyword.mace_pad_type_str

        # tf_op.type is a plain string: compare it with the enum member's
        # ``.name`` like the other converters here do.  Comparing against the
        # member itself only matches if TFOpType is a str-mixin enum, and
        # would otherwise leave pad_type unset.
        if tf_op.type == TFOpType.Pad.name:
            if len(tf_op.inputs) == 3:
                constant_value_arg = op.arg.add()
                constant_value_arg.name = MaceKeyword.mace_constant_value_str
                # NOTE(review): the pad constant is truncated to int32 and
                # stored in ``.i``; a fractional value would lose precision.
                # Confirm this matches what the runtime Pad op reads.
                constant_value = tf_op.inputs[2].eval().astype(np.int32) \
                    .flat[0]
                constant_value_arg.i = constant_value
                self._skip_tensor.add(tf_op.inputs[2].name)

            pad_type_arg.i = PadType.CONSTANT.value

        elif tf_op.type == TFOpType.MirrorPad.name:
            pad_type_arg.i = self.pad_type[tf_op.get_attr('mode')].value
800 801 802 803 804 805 806 807 808

    def convert_concat(self, tf_op):
        """Convert a TF concat whose axis is the last (constant) input.

        The axis is folded into an ``axis`` argument.  A negative axis is
        normalised against the rank of the first output shape.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Concat.name
        del op.input[-1]

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_tensor = tf_op.inputs[-1]
        axis = axis_tensor.eval().astype(np.int32)
        if axis < 0:
            axis = axis + len(op.output_shape[0].dims)
        axis_arg.i = axis

        self._skip_tensor.add(axis_tensor.name)
813

李寅 已提交
814 815 816 817
    def convert_matmul(self, tf_op):
        """Convert (Batch)MatMul into a MACE MatMul op.

        Each transpose flag is read from ``adj_x``/``adj_y`` when present,
        falling back to ``transpose_a``/``transpose_b``.  A flag argument is
        emitted only if one of its attributes exists.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.MatMul.name

        flag_sources = (
            (MaceKeyword.mace_transpose_a_str, ('adj_x', 'transpose_a')),
            (MaceKeyword.mace_transpose_b_str, ('adj_y', 'transpose_b')),
        )
        for mace_name, tf_names in flag_sources:
            for tf_name in tf_names:
                try:
                    flag = tf_op.get_attr(tf_name)
                except ValueError:
                    # Attribute absent on this op: try the next spelling.
                    continue
                transpose_arg = op.arg.add()
                transpose_arg.name = mace_name
                transpose_arg.i = int(flag)
                break
845

846 847 848 849 850
    def convert_shape(self, tf_op):
        """Convert TF Shape; its output is always tagged as int32."""
        shape_op = self.convert_general_op(tf_op)
        shape_op.type = MaceOp.Shape.name
        shape_op.output_type.extend([mace_pb2.DT_INT32])

851 852 853 854
    def convert_reshape(self, tf_op):
        """Convert TF Reshape; inputs are passed through unchanged."""
        reshape_op = self.convert_general_op(tf_op)
        reshape_op.type = MaceOp.Reshape.name

855 856 857 858
    def convert_expand_dims(self, tf_op):
        """Convert ExpandDims; the constant axis input becomes an ``axis`` arg.

        The axis input is dropped from the op inputs (it is not added to the
        skip list here).
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.ExpandDims.name

        expand_axis = tf_op.inputs[1].eval().astype(np.int32)
        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = expand_axis

        del op.input[1]
864

865 866 867 868 869 870 871 872 873 874 875 876 877 878
    def convert_squeeze(self, tf_op):
        """Convert Squeeze, taking axes from ``squeeze_dims`` or ``axis``.

        If neither attribute exists the ``axis`` argument is left empty.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Squeeze.name

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str

        squeeze_axes = []
        for attr_name in ('squeeze_dims', 'axis'):
            try:
                squeeze_axes = tf_op.get_attr(attr_name)
            except ValueError:
                continue
            break
        axis_arg.ints.extend(squeeze_axes)
879

李寅 已提交
880
    def convert_transpose(self, tf_op):
        """Convert Transpose; a sorted (identity) permutation becomes Identity.

        For a real permutation the order is stored in the ``dims`` argument.
        For an identity permutation the perm input is dropped and skipped.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Transpose.name

        perm_tensor = tf_op.inputs[1]
        perm = perm_tensor.eval().astype(np.int32)

        if not np.array_equal(perm, np.sort(perm)):
            dims_arg = op.arg.add()
            dims_arg.name = MaceKeyword.mace_dims_str
            dims_arg.ints.extend(perm)
            return

        # Already-sorted permutation: the transpose is a no-op.
        op.type = MaceOp.Identity.name
        del op.input[1:]
        self._skip_tensor.add(perm_tensor.name)
李寅 已提交
895

896
    def convert_reduce(self, tf_op):
        """Convert a TF reduction op into a MACE Reduce op.

        The reduce type is looked up from the TF op type.  Reduction axes
        come from the second (constant) input when there is one, otherwise
        from an ``axis`` or ``reduction_indices`` attribute (empty if neither
        exists).  The keep-dims flag comes from ``keepdims``/``keep_dims``
        and defaults to 0.
        """
        op = self.convert_general_op(tf_op)
        del op.input[1:]

        op.type = MaceOp.Reduce.name

        reduce_type_arg = op.arg.add()
        reduce_type_arg.name = MaceKeyword.mace_reduce_type_str
        reduce_type_arg.i = self.reduce_math_type[tf_op.type].value

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        has_axis_input = len(tf_op.inputs) > 1
        if has_axis_input:
            reduce_dims = tf_op.inputs[1].eval()
        else:
            try:
                reduce_dims = tf_op.get_attr('axis')
            except ValueError:
                try:
                    reduce_dims = tf_op.get_attr('reduction_indices')
                except ValueError:
                    reduce_dims = []
        # Normalise scalar / list / ndarray (including 0-d arrays, which
        # cannot be iterated by extend) into a flat list of Python ints.
        axis_arg.ints.extend(
            np.asarray(reduce_dims, dtype=np.int32).reshape(-1).tolist())

        keep_dims_arg = op.arg.add()
        keep_dims_arg.name = MaceKeyword.mace_keepdims_str
        try:
            keep_dims = tf_op.get_attr('keepdims')
        except ValueError:
            try:
                keep_dims = tf_op.get_attr('keep_dims')
            except ValueError:
                keep_dims = 0
        keep_dims_arg.i = keep_dims

        # Only skip the axis tensor when it exists; indexing inputs[1]
        # unconditionally raised IndexError for single-input reductions.
        if has_axis_input:
            self._skip_tensor.add(tf_op.inputs[1].name)
934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973

    def convert_gather(self, tf_op):
        """Convert Gather; an optional third input supplies the axis."""
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Gather.name

        if len(tf_op.inputs) < 3:
            return

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = tf_op.inputs[2].eval()

    def convert_stridedslice(self, tf_op):
        """Convert StridedSlice, copying its five mask attributes.

        Each MACE keyword string is also used as the TF attribute name, so
        every mask is read and written under the same name, in the order
        begin, end, ellipsis, new-axis, shrink-axis.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.StridedSlice.name

        mask_names = (
            MaceKeyword.mace_begin_mask_str,
            MaceKeyword.mace_end_mask_str,
            MaceKeyword.mace_ellipsis_mask_str,
            MaceKeyword.mace_new_axis_mask_str,
            MaceKeyword.mace_shrink_axis_mask_str,
        )
        for mask_name in mask_names:
            mask_arg = op.arg.add()
            mask_arg.name = mask_name
            mask_arg.i = tf_op.get_attr(mask_name)

    def convert_slice(self, tf_op):
        """Convert Slice as a StridedSlice op flagged with ``slice=1``."""
        slice_op = self.convert_general_op(tf_op)
        slice_op.type = MaceOp.StridedSlice.name

        slice_flag = slice_op.arg.add()
        slice_flag.name = 'slice'
        slice_flag.i = 1
977

978 979 980 981
    def convert_reverse(self, tf_op):
        """Convert a TF reverse op; inputs are passed through unchanged."""
        reverse_op = self.convert_general_op(tf_op)
        reverse_op.type = MaceOp.Reverse.name

982 983 984 985 986 987 988 989 990 991
    def convert_stack(self, tf_op):
        """Convert a stack (pack) op; ``axis`` falls back to 0 on ValueError."""
        stack_op = self.convert_general_op(tf_op)
        stack_op.type = MaceOp.Stack.name

        axis_arg = stack_op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        try:
            axis_arg.i = tf_op.get_attr(MaceKeyword.mace_axis_str)
        except ValueError:
            # Attribute lookup failed: default to the first axis.
            axis_arg.i = 0
李寅 已提交
992

Y
yejianwu 已提交
993 994 995 996 997 998 999 1000 1001 1002 1003
    def convert_unstack(self, tf_op):
        """Convert an unstack (unpack) op; ``axis`` falls back to 0."""
        unstack_op = self.convert_general_op(tf_op)
        unstack_op.type = MaceOp.Unstack.name

        axis_arg = unstack_op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        try:
            axis_arg.i = tf_op.get_attr(MaceKeyword.mace_axis_str)
        except ValueError:
            # Attribute lookup failed: default to the first axis.
            axis_arg.i = 0

李寅 已提交
1004 1005 1006 1007 1008 1009 1010
    def convert_cast(self, tf_op):
        """Convert Cast, tagging the output with the destination dtype.

        ``DstT`` int32 maps to DT_INT32 and float32 to DT_FLOAT.  Any other
        destination is rejected via ``mace_check``.  If a ValueError is
        raised (e.g. no ``DstT`` attribute), the output is tagged DT_FLOAT.
        """
        cast_op = self.convert_general_op(tf_op)
        cast_op.type = MaceOp.Cast.name

        try:
            dst_dtype = tf_op.get_attr('DstT')
            if dst_dtype == tf.int32:
                cast_op.output_type.extend([mace_pb2.DT_INT32])
            elif dst_dtype == tf.float32:
                cast_op.output_type.extend([mace_pb2.DT_FLOAT])
            else:
                mace_check(False, "data type %s not supported" % dst_dtype)
        except ValueError:
            cast_op.output_type.extend([mace_pb2.DT_FLOAT])
李寅 已提交
1018 1019 1020 1021 1022

    def convert_argmax(self, tf_op):
        """Convert ArgMax; the index output is tagged as int32."""
        argmax_op = self.convert_general_op(tf_op)
        argmax_op.type = MaceOp.ArgMax.name
        argmax_op.output_type.extend([mace_pb2.DT_INT32])
L
liuqi 已提交
1023 1024 1025

    def convert_split(self, tf_op):
        """Convert Split (constant axis as first input, ``num_split`` attr).

        A one-way split is an Identity.  Otherwise the axis (normalised
        against the first output's rank if negative) and the split count
        become arguments.  In both cases the axis input is dropped and
        skipped as a tensor.
        """
        op = self.convert_general_op(tf_op)
        num_split = tf_op.get_attr('num_split')
        axis_tensor = tf_op.inputs[0]

        if num_split == 1:
            op.type = MaceOp.Identity.name
        else:
            op.type = MaceOp.Split.name
            split_axis = axis_tensor.eval().astype(np.int32)
            if split_axis < 0:
                split_axis = len(op.output_shape[0].dims) + split_axis

            axis_arg = op.arg.add()
            axis_arg.name = MaceKeyword.mace_axis_str
            axis_arg.i = split_axis

            count_arg = op.arg.add()
            count_arg.name = MaceKeyword.mace_num_split_str
            count_arg.i = num_split

        del op.input[0]
        self._skip_tensor.add(axis_tensor.name)
李寅 已提交
1043 1044 1045 1046 1047 1048 1049

    def convert_fake_quantize(self, tf_op):
        """Convert FakeQuantWithMinMaxVars / FakeQuantWithMinMaxArgs.

        The range goes into ``min``/``max`` float args.  For the Vars variant
        it is read from inputs 1 and 2, which are then skipped as tensors
        (they are not removed from the op inputs here).  For the Args
        variant it is read from attributes.  ``narrow_range`` and
        ``num_bits`` are copied as ints.  The op type is left as produced
        by ``convert_general_op``.
        """
        op = self.convert_general_op(tf_op)
        min_arg = op.arg.add()
        min_arg.name = 'min'
        max_arg = op.arg.add()
        max_arg.name = 'max'

        range_from_inputs = \
            tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name
        if range_from_inputs:
            min_arg.f = tf_op.inputs[1].eval()
            max_arg.f = tf_op.inputs[2].eval()
        elif tf_op.type == TFOpType.FakeQuantWithMinMaxArgs.name:
            min_arg.f = float(tf_op.get_attr('min'))
            max_arg.f = float(tf_op.get_attr('max'))

        for attr_name in ('narrow_range', 'num_bits'):
            int_arg = op.arg.add()
            int_arg.name = attr_name
            int_arg.i = int(tf_op.get_attr(attr_name))

        if range_from_inputs:
            for range_tensor in tf_op.inputs[1:3]:
                self._skip_tensor.add(range_tensor.name)
1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085

    def convert_cumsum(self, tf_op):
        """Convert Cumsum: the axis input becomes an arg, plus two flags.

        The constant axis (``inputs[1]``) is removed from the op inputs; the
        ``exclusive`` and ``reverse`` attributes are copied as ints.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Cumsum.name

        cumsum_axis = tf_op.inputs[1].eval().astype(np.int32)
        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = cumsum_axis
        del op.input[1]

        flag_attrs = (
            ('exclusive', MaceKeyword.mace_exclusive_str),
            ('reverse', MaceKeyword.mace_reverse_str),
        )
        for tf_name, mace_name in flag_attrs:
            flag_value = tf_op.get_attr(tf_name)
            flag_arg = op.arg.add()
            flag_arg.name = mace_name
            flag_arg.i = int(flag_value)