tensorflow_converter.py 40.7 KB
Newer Older
L
Liangliang He 已提交
1
# Copyright 2018 The MACE Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
liutuo 已提交
15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

19 20
import math
import numpy as np
21
import six
22
import tensorflow as tf
23
from enum import Enum
24

L
liutuo 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38
from python.py_proto import mace_pb2
from . import base_converter
from .base_converter import PoolingType
from .base_converter import PaddingMode
from .base_converter import ActivationType
from .base_converter import EltwiseType
from .base_converter import PadType
from .base_converter import FrameworkType
from .base_converter import ReduceType
from .base_converter import DataFormat
from .base_converter import MaceOp
from .base_converter import MaceKeyword
from .base_converter import ConverterUtil
from python.utils.util import mace_check
39 40

from tensorflow.core.framework import tensor_shape_pb2
41
from tensorflow.tools.graph_transforms import TransformGraph
42 43 44 45 46 47 48

# Names of TensorFlow node attributes read by the converters.
tf_padding_str = 'padding'
tf_strides_str = 'strides'
tf_dilations_str = 'dilations'
tf_data_format_str = 'data_format'
tf_kernel_str = 'ksize'
tf_epsilon_str = 'epsilon'
tf_alpha_str = 'alpha'
tf_is_training_str = 'is_training'
tf_align_corners = 'align_corners'
tf_block_size = 'block_size'
tf_squeeze_dims = 'squeeze_dims'
tf_axis = 'axis'

# TensorFlow op types known to this converter.  They define the TFOpType
# enum below; the handler for each type is registered in
# TensorflowConverter._op_converters.
TFSupportedOps = [
    'Conv2D',
    'DepthwiseConv2dNative',
    'Conv2DBackpropInput',
    'BiasAdd',
    'Add',
    'Sub',
    'Mul',
    'Div',
    'Min',
    'Minimum',
    'Max',
    'Maximum',
    'Neg',
    'Abs',
    'Pow',
    'RealDiv',
    'Square',
    'SquaredDifference',
    'Rsqrt',
    'Sum',
    'Equal',
    'Relu',
    'LeakyRelu',
    'Relu6',
    'Tanh',
    'Sigmoid',
    'Fill',
    'FusedBatchNorm',
    'AvgPool',
    'MaxPool',
    'ExpandDims',
    'Squeeze',
    'MatMul',
    'BatchMatMul',
    'Identity',
    'Reshape',
    'Shape',
    'Transpose',
    'Softmax',
    'ResizeBicubic',
    'ResizeBilinear',
    'ResizeNearestNeighbor',
    'Placeholder',
    'SpaceToBatchND',
    'BatchToSpaceND',
    'DepthToSpace',
    'SpaceToDepth',
    'Pad',
    'PadV2',
    'ConcatV2',
    'Mean',
    'Prod',
    'Const',
    'Gather',
    'GatherV2',
    'StridedSlice',
    'Slice',
    'ReverseV2',
    'Stack',
    'Pack',
    'Unstack',
    'Unpack',
    'Cast',
    'ArgMax',
    'Split',
    'FakeQuantWithMinMaxVars',
    'FakeQuantWithMinMaxArgs',
    'FloorDiv',
    'Sqrt',
    'MirrorPad',
    'Cumsum',
    'OneHot',
]

# str-valued enum: each member's name and value are the TF op type string.
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)

# Once the enum exists, the list is re-bound to byte strings (six.b).
TFSupportedOps = [six.b(op) for op in TFSupportedOps]

# Passes handed to TensorFlow's TransformGraph tool before conversion.
# 'strip_unused_nodes' intentionally appears twice: once up front and once
# after the folding passes.
TFTransformGraphOptions = [
    'strip_unused_nodes',
    'remove_nodes(op=Identity, op=CheckNumerics)',
    'fold_constants(ignore_errors=true)',
    'fold_batch_norms',
    'fold_old_batch_norms',
    'remove_control_dependencies',
    'strip_unused_nodes',
    'sort_by_execution_order'
]
145

146 147 148 149 150 151 152 153 154 155 156

class TensorflowConverter(base_converter.ConverterInterface):
    """Convert a frozen TensorFlow model into a MACE model.

    TensorFlow itself is used to infer op output shapes, since there are
    too many op types to infer them by hand."""

    # TF 'padding' attr value (six.b byte string) -> MACE padding mode.
    padding_mode = {
        six.b('VALID'): PaddingMode.VALID,
        six.b('SAME'): PaddingMode.SAME,
        six.b('FULL'): PaddingMode.FULL,
    }

    # TF pooling op type -> MACE pooling type.
    pooling_type_mode = {
        TFOpType.AvgPool.name: PoolingType.AVG,
        TFOpType.MaxPool.name: PoolingType.MAX
    }

    # TF element-wise op type -> MACE EltwiseType.  Square, Rsqrt and Sqrt
    # all map to POW; convert_elementwise supplies their exponent.
    eltwise_type = {
        TFOpType.Add.name: EltwiseType.SUM,
        TFOpType.Sub.name: EltwiseType.SUB,
        TFOpType.Mul.name: EltwiseType.PROD,
        TFOpType.Div.name: EltwiseType.DIV,
        TFOpType.Minimum.name: EltwiseType.MIN,
        TFOpType.Maximum.name: EltwiseType.MAX,
        TFOpType.Neg.name: EltwiseType.NEG,
        TFOpType.Abs.name: EltwiseType.ABS,
        TFOpType.Pow.name: EltwiseType.POW,
        TFOpType.RealDiv.name: EltwiseType.DIV,
        TFOpType.FloorDiv.name: EltwiseType.FLOOR_DIV,
        TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF,
        TFOpType.Square.name: EltwiseType.POW,
        TFOpType.Rsqrt.name: EltwiseType.POW,
        TFOpType.Sqrt.name: EltwiseType.POW,
        TFOpType.Equal.name: EltwiseType.EQUAL,
    }

    # TF activation op type -> MACE ActivationType (Relu6 is RELUX with a
    # max limit added by convert_activation).
    activation_type = {
        TFOpType.Relu.name: ActivationType.RELU,
        TFOpType.Relu6.name: ActivationType.RELUX,
        TFOpType.Tanh.name: ActivationType.TANH,
        TFOpType.Sigmoid.name: ActivationType.SIGMOID,
        TFOpType.LeakyRelu.name: ActivationType.LEAKYRELU,
    }

    # TF reduction op type -> MACE ReduceType.
    reduce_math_type = {
        TFOpType.Min.name: ReduceType.MIN,
        TFOpType.Max.name: ReduceType.MAX,
        TFOpType.Mean.name: ReduceType.MEAN,
        TFOpType.Prod.name: ReduceType.PROD,
        TFOpType.Sum.name: ReduceType.SUM,
    }

    # Pad mode name (plain str keys) -> MACE PadType.
    pad_type = {
        'CONSTANT': PadType.CONSTANT,
        'REFLECT': PadType.REFLECT,
        'SYMMETRIC': PadType.SYMMETRIC
    }

205 206
    def __init__(self, option, src_model_file):
        """Load a frozen TensorFlow graph and prepare it for conversion.

        Steps: register per-op converters, read the GraphDef from
        ``src_model_file``, run TensorFlow's TransformGraph (falling back to
        the raw graph if that fails), pin input shapes, run the graph once
        to record every tensor's shape, then reload a clean copy of the
        graph for the conversion itself.

        Args:
            option: converter options; ``input_nodes``, ``output_nodes`` and
                ``data_type`` are read from it.
            src_model_file: path of the frozen ``.pb`` model.
        """
        self._op_converters = {
            TFOpType.Conv2D.name: self.convert_conv2d,
            TFOpType.DepthwiseConv2dNative.name: self.convert_conv2d,
            TFOpType.Conv2DBackpropInput.name: self.convert_conv2d,
            TFOpType.BiasAdd.name: self.convert_biasadd,
            TFOpType.Add.name: self.convert_add,
            TFOpType.Sub.name: self.convert_elementwise,
            TFOpType.Mul.name: self.convert_elementwise,
            TFOpType.Div.name: self.convert_elementwise,
            TFOpType.Minimum.name: self.convert_elementwise,
            TFOpType.Maximum.name: self.convert_elementwise,
            TFOpType.Neg.name: self.convert_elementwise,
            TFOpType.Abs.name: self.convert_elementwise,
            TFOpType.Pow.name: self.convert_elementwise,
            TFOpType.RealDiv.name: self.convert_elementwise,
            TFOpType.SquaredDifference.name: self.convert_elementwise,
            TFOpType.Square.name: self.convert_elementwise,
            TFOpType.Rsqrt.name: self.convert_elementwise,
            TFOpType.Equal.name: self.convert_elementwise,
            TFOpType.Min.name: self.convert_reduce,
            TFOpType.Max.name: self.convert_reduce,
            TFOpType.Mean.name: self.convert_reduce,
            TFOpType.Prod.name: self.convert_reduce,
            TFOpType.Relu.name: self.convert_activation,
            TFOpType.LeakyRelu.name: self.convert_activation,
            TFOpType.Relu6.name: self.convert_activation,
            TFOpType.Tanh.name: self.convert_activation,
            TFOpType.Sigmoid.name: self.convert_activation,
            TFOpType.Fill.name: self.convert_fill,
            TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm,
            TFOpType.AvgPool.name: self.convert_pooling,
            TFOpType.MaxPool.name: self.convert_pooling,
            TFOpType.MatMul.name: self.convert_matmul,
            TFOpType.BatchMatMul.name: self.convert_matmul,
            TFOpType.Identity.name: self.convert_identity,
            TFOpType.Reshape.name: self.convert_reshape,
            TFOpType.Shape.name: self.convert_shape,
            TFOpType.ExpandDims.name: self.convert_expand_dims,
            TFOpType.Squeeze.name: self.convert_squeeze,
            TFOpType.Transpose.name: self.convert_transpose,
            TFOpType.Softmax.name: self.convert_softmax,
            TFOpType.ResizeBicubic.name: self.convert_resize_bicubic,
            TFOpType.ResizeBilinear.name: self.convert_resize_bilinear,
            TFOpType.ResizeNearestNeighbor.name: self.convert_resize_nearest_neighbor,  # noqa
            TFOpType.Placeholder.name: self.convert_nop,
            TFOpType.SpaceToBatchND.name: self.convert_space_batch,
            TFOpType.BatchToSpaceND.name: self.convert_space_batch,
            TFOpType.DepthToSpace.name: self.convert_space_depth,
            TFOpType.SpaceToDepth.name: self.convert_space_depth,
            TFOpType.Pad.name: self.convert_pad,
            TFOpType.PadV2.name: self.convert_pad,
            TFOpType.ConcatV2.name: self.convert_concat,
            TFOpType.Const.name: self.convert_nop,
            TFOpType.Gather.name: self.convert_gather,
            TFOpType.GatherV2.name: self.convert_gather,
            TFOpType.StridedSlice.name: self.convert_stridedslice,
            TFOpType.Slice.name: self.convert_slice,
            TFOpType.ReverseV2.name: self.convert_reverse,
            TFOpType.Pack.name: self.convert_stack,
            TFOpType.Stack.name: self.convert_stack,
            TFOpType.Unpack.name: self.convert_unstack,
            TFOpType.Unstack.name: self.convert_unstack,
            TFOpType.Cast.name: self.convert_cast,
            TFOpType.ArgMax.name: self.convert_argmax,
            TFOpType.Split.name: self.convert_split,
            TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
            TFOpType.FakeQuantWithMinMaxArgs.name: self.convert_fake_quantize,
            TFOpType.FloorDiv.name: self.convert_elementwise,
            TFOpType.Sqrt.name: self.convert_elementwise,
            TFOpType.MirrorPad.name: self.convert_pad,
            TFOpType.Cumsum.name: self.convert_cumsum,
            TFOpType.OneHot.name: self.convert_one_hot,
            TFOpType.Sum.name: self.convert_reduce,
        }
        self._option = option
        self._mace_net_def = mace_pb2.NetDef()
        ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO)
        ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC)

        # import tensorflow graph
        tf_graph_def = tf.GraphDef()
        with tf.gfile.Open(src_model_file, 'rb') as f:
            tf_graph_def.ParseFromString(f.read())

        # Feed values for the shape-inference run, keyed by '<name>:0'.
        self._placeholders = {}
        # Const tensors already folded into op args; not emitted as tensors.
        self._skip_tensor = set()
        # Tensor name -> shape observed in the shape-inference run.
        self._output_shape = {}

        print("Run transform_graph: %s" % TFTransformGraphOptions)
        try:
            print("output keys: ", option.output_nodes.keys())
            transformed_graph_def = TransformGraph(tf_graph_def,
                                                   option.input_nodes.keys(),
                                                   option.output_nodes.keys(),
                                                   TFTransformGraphOptions)
        except Exception as ex:
            # Best effort: convert the untransformed graph if the tool fails.
            print("Failed to transform graph using tf tool: %s" % ex)
            transformed_graph_def = tf_graph_def

        # To check optimized model, uncomment following code
        # (it also needs `import os`).
        # tf.io.write_graph(
        #     transformed_graph_def,
        #     ".",
        #     os.path.basename(src_model_file)[:-3] + "_opt.pb",
        #     as_text=False
        # )

        self.add_shape_info(transformed_graph_def)

        # reset default graph to clear earlier import
        tf.reset_default_graph()
        with tf.Session() as session:
            with session.graph.as_default() as graph:
                tf.import_graph_def(transformed_graph_def, name='')
                self._tf_graph = graph
                self.update_output_shapes(session)

        # we have polluted graph with 'shape' ops, so reset it and reload it
        # again
        tf.reset_default_graph()
        with tf.Session() as session:
            with session.graph.as_default() as graph:
                tf.import_graph_def(transformed_graph_def, name='')
                self._tf_graph = graph
330

331 332
    def run(self):
        """Convert the loaded graph and return the resulting MACE NetDef.

        Conversion happens inside a fresh ``tf.Session`` because converters
        evaluate constant tensors with ``Tensor.eval()``, which uses the
        default session.  Afterwards the ':0' suffix is stripped from the
        names of model input/output tensors.
        """
        with tf.Session() as session:
            self.convert_ops(session)

        self.replace_input_output_tensor_name()
        return self._mace_net_def

    def replace_input_output_tensor_name(self):
        """Drop the ':0' suffix from names of model input/output tensors.

        Op inputs are renamed when the bare name is a configured input or
        output node; op outputs only when it is a configured output node.
        Other tensor names are left untouched.
        """
        input_nodes = self._option.input_nodes
        output_nodes = self._option.output_nodes
        for op in self._mace_net_def.op:
            for i in six.moves.range(len(op.input)):
                name = op.input[i]
                if name.endswith(':0'):
                    op_name = name[:-2]
                    if op_name in input_nodes or op_name in output_nodes:
                        op.input[i] = op_name
            for i in six.moves.range(len(op.output)):
                name = op.output[i]
                if name.endswith(':0'):
                    op_name = name[:-2]
                    if op_name in output_nodes:
                        op.output[i] = op_name

    def add_shape_info(self, tf_graph_def):
        """Pin the configured input shapes onto matching graph nodes.

        A node matches a configured input when its name equals the input
        name with or without a ':0' suffix.  The node's 'shape' attr is
        replaced by the configured shape; for 4-D inputs declared as OIHW
        the shape is permuted to HWIO first.  A zero-filled array of that
        shape is stored in self._placeholders as the feed value for the
        shape-inference run.

        Args:
            tf_graph_def: GraphDef that is modified in place.
        """
        for node in tf_graph_def.node:
            for input_node in self._option.input_nodes.values():
                if node.name != input_node.name \
                        and node.name + ':0' != input_node.name:
                    continue
                input_shape = input_node.shape
                if input_node.data_format == DataFormat.OIHW \
                        and len(input_shape) == 4:
                    # OIHW -> HWIO
                    input_shape = [input_shape[2], input_shape[3],
                                   input_shape[1], input_shape[0]]
                del node.attr['shape'].shape.dim[:]
                node.attr['shape'].shape.dim.extend([
                    tensor_shape_pb2.TensorShapeProto.Dim(size=dim)
                    for dim in input_shape
                ])
                self._placeholders[node.name + ':0'] = \
                    np.zeros(shape=input_shape, dtype=float)
370 371 372 373 374 375 376 377 378

    @staticmethod
    def get_scope(tensor_name):
        """Return the part of ``tensor_name`` before its last '/'.

        A name containing no '/' is returned unchanged.
        """
        scope, slash, _ = tensor_name.rpartition('/')
        return scope if slash else tensor_name

379
    def update_output_shapes(self, sess):
        """Record the concrete shape of every op output in the graph.

        One ``tf.shape`` op is added to the current default graph per op
        output.  All of them are evaluated in a single ``sess.run`` fed with
        self._placeholders, and the results are stored in
        self._output_shape keyed by tensor name.

        Args:
            sess: session whose graph is self._tf_graph.
        """
        tensors = []
        shape_tensors = []
        for tf_op in self._tf_graph.get_operations():
            for output in tf_op.outputs:
                tensors.append(output.name)
                shape_tensors.append(tf.shape(output))

        tensor_shapes = sess.run(shape_tensors,
                                 feed_dict=self._placeholders)
        self._output_shape.update(zip(tensors, tensor_shapes))
391 392

    def convert_ops(self, sess):
        """Convert each TF op with its registered converter, then tensors.

        An op type with no entry in self._op_converters is reported through
        mace_check.  After all ops, remaining Const outputs are emitted by
        convert_tensors.

        Args:
            sess: not used by this method; kept for interface compatibility.
        """
        for tf_op in self._tf_graph.get_operations():
            converter = self._op_converters.get(tf_op.type)
            mace_check(converter is not None,
                       "Mace does not support tensorflow op type %s yet"
                       % tf_op.type)
            converter(tf_op)

        self.convert_tensors()
400

401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421
    def convert_tensors(self):
        """Emit every Const op output as a tensor of the MACE net.

        Outputs listed in self._skip_tensor (already folded into op args)
        are left out.  Only float32 and int32 constants are accepted; any
        other dtype is reported through mace_check.
        """
        const_type = TFOpType.Const.name
        for tf_op in self._tf_graph.get_operations():
            if tf_op.type != const_type:
                continue
            const_output = tf_op.outputs[0]
            if const_output.name in self._skip_tensor:
                continue

            tensor = self._mace_net_def.tensors.add()
            tensor.name = const_output.name
            value = const_output.eval()
            tensor.dims.extend(list(value.shape))

            tf_dt = tf_op.get_attr('dtype')
            if tf_dt == tf.float32:
                tensor.data_type = mace_pb2.DT_FLOAT
                tensor.float_data.extend(value.astype(np.float32).flat)
            elif tf_dt == tf.int32:
                tensor.data_type = mace_pb2.DT_INT32
                tensor.int32_data.extend(value.astype(np.int32).flat)
            else:
                mace_check(False,
                           "Not supported tensor type: %s" % tf_dt.name)
422 423 424 425 426 427 428 429

    def add_tensor(self, name, shape, data_type, value):
        """Append a tensor named ``name`` to the MACE net.

        ``value`` must support ``.flat`` (e.g. a numpy array).  Its elements
        are always written to ``float_data``, whatever ``data_type`` is.
        """
        new_tensor = self._mace_net_def.tensors.add()
        new_tensor.name, new_tensor.data_type = name, data_type
        new_tensor.dims.extend(list(shape))
        new_tensor.float_data.extend(value.flat)

430 431
    def infer_tensor_shape(self, tensor, output_shape=None):
        """Return the best known shape of ``tensor``.

        The shape observed in the shape-inference run (self._output_shape)
        is preferred.  Otherwise the static ``tensor.shape.as_list()`` is
        used, in which some dimensions may be undefined because input
        lengths vary.

        Args:
            tensor: tf.Tensor whose shape is wanted.
            output_shape: optional output-shape message (e.g. from
                ``op.output_shape.add()``); when given, the shape is appended
                to its ``dims``.

        Returns:
            The shape: the recorded shape-run result or a Python list.
        """
        if tensor.name in self._output_shape:
            shape = self._output_shape[tensor.name]
        else:
            shape = tensor.shape.as_list()

        # Compare with None explicitly: protobuf messages should not be
        # judged by truthiness.
        if output_shape is not None:
            output_shape.dims.extend(shape)

        return shape
443

444 445 446 447 448 449 450 451 452 453 454
    def convert_nop(self, tf_op):
        """Emit no MACE op for ``tf_op`` (registered for Placeholder/Const)."""

    def convert_general_op(self, tf_op):
        """Add a MACE op that mirrors ``tf_op`` and return it.

        Copies the op name, type, input and output tensor names, and one
        inferred output shape per output.  It also adds three args:
        - 'T' (data type);
        - the framework-type arg (TENSORFLOW);
        - the NHWC data-format arg.
        Specific converters call this first and then adjust the returned op.
        """
        op = self._mace_net_def.op.add()
        op.name = tf_op.name
        op.type = tf_op.type
        op.input.extend([tf_input.name for tf_input in tf_op.inputs])
        op.output.extend([tf_output.name for tf_output in tf_op.outputs])
        for tf_output in tf_op.outputs:
            output_shape = op.output_shape.add()
            self.infer_tensor_shape(tf_output, output_shape)

        # Data type: taken from the 'T' attr.  int32 stays int32; float32
        # maps to the configured data type.
        data_type_arg = op.arg.add()
        data_type_arg.name = 'T'
        try:
            dtype = tf_op.get_attr('T')
            if dtype == tf.int32:
                data_type_arg.i = mace_pb2.DT_INT32
            elif dtype == tf.float32:
                data_type_arg.i = self._option.data_type
            else:
                mace_check(False, "data type %s not supported" % dtype)
        except ValueError:
            # No 'T' attr: fall back to 'SrcT' (which also maps bool to
            # int32).  Ops with neither attr get the configured data type.
            try:
                dtype = tf_op.get_attr('SrcT')
                if dtype == tf.int32 or dtype == tf.bool:
                    data_type_arg.i = mace_pb2.DT_INT32
                elif dtype == tf.float32:
                    data_type_arg.i = self._option.data_type
                else:
                    mace_check(False, "data type %s not supported" % dtype)
            except ValueError:
                data_type_arg.i = self._option.data_type

        framework_type_arg = op.arg.add()
        framework_type_arg.name = MaceKeyword.mace_framework_type_str
        framework_type_arg.i = FrameworkType.TENSORFLOW.value

        ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)

        return op

    def convert_identity(self, tf_op):
        """Emit a MACE 'Identity' op for ``tf_op``."""
        self.convert_general_op(tf_op).type = 'Identity'

    def convert_conv2d(self, tf_op):
        """Convert Conv2D, DepthwiseConv2dNative and Conv2DBackpropInput.

        Padding and strides (entries [1:3] of the TF attr) become MACE args.
        Regular and depthwise convolutions also get a dilations arg, which
        defaults to [1, 1] when the op has no 'dilations' attr.
        Deconvolutions must have dilation 1 and have their first three
        inputs re-ordered as (inputs[2], inputs[1], inputs[0]).
        """
        op = self.convert_general_op(tf_op)
        if tf_op.type == TFOpType.DepthwiseConv2dNative.name:
            op.type = MaceOp.DepthwiseConv2d.name
        elif tf_op.type == TFOpType.Conv2DBackpropInput.name:
            op.type = MaceOp.Deconv2D.name
        else:
            op.type = MaceOp.Conv2D.name

        padding_arg = op.arg.add()
        padding_arg.name = MaceKeyword.mace_padding_str
        padding_arg.i = self.padding_mode[tf_op.get_attr(tf_padding_str)].value
        strides_arg = op.arg.add()
        strides_arg.name = MaceKeyword.mace_strides_str
        strides_arg.ints.extend(tf_op.get_attr(tf_strides_str)[1:3])

        # get_attr raises ValueError when the op has no 'dilations' attr.
        try:
            dilation_val = tf_op.get_attr(tf_dilations_str)[1:3]
        except ValueError:
            dilation_val = [1, 1]

        if op.type != MaceOp.Deconv2D.name:
            dilation_arg = op.arg.add()
            dilation_arg.name = MaceKeyword.mace_dilations_str
            dilation_arg.ints.extend(dilation_val)
        else:
            mace_check(dilation_val[0] == 1 and dilation_val[1] == 1,
                       "Mace only supports dilation == 1 conv2d_transpose.")
            mace_check(len(tf_op.inputs) >= 3,
                       "deconv should have (>=) 3 inputs.")
            del op.input[:]
            op.input.extend([tf_op.inputs[2].name,
                             tf_op.inputs[1].name,
                             tf_op.inputs[0].name])
527 528 529 530 531 532 533 534 535

    def convert_elementwise(self, tf_op):
        """Convert a TF element-wise op into a MACE Eltwise/ScalarMath op.

        - The MACE element type comes from ``eltwise_type``.
        - The op becomes ScalarMath when all of its (one or two) inputs are
          rank-0, and Eltwise otherwise.
        - Square, Rsqrt and Sqrt are POW with exponent 2, -0.5 and 0.5,
          passed as the scalar-input arg.
        - For ops other than NEG/ABS, a rank-0 Const operand is folded into
          scalar-input / scalar-input-index args and removed from the op
          inputs.  Input 1 is always eligible; input 0 only for commutative
          types.  The folded constant is added to self._skip_tensor.
        """
        op = self.convert_general_op(tf_op)

        type_arg = op.arg.add()
        type_arg.name = MaceKeyword.mace_element_type_str
        type_arg.i = self.eltwise_type[tf_op.type].value

        def check_is_scalar(tf_op):
            # Only unary and binary ops qualify; all inputs must be rank-0.
            if len(tf_op.inputs) not in (1, 2):
                return False
            return all(len(self.infer_tensor_shape(inp)) == 0
                       for inp in tf_op.inputs)

        if check_is_scalar(tf_op):
            op.type = MaceOp.ScalarMath.name
        else:
            op.type = MaceOp.Eltwise.name

        # Unary power ops carry their exponent as the scalar-input arg.
        pow_exponents = {
            TFOpType.Square.name: 2.0,
            TFOpType.Rsqrt.name: -0.5,
            TFOpType.Sqrt.name: 0.5,
        }
        if tf_op.type in pow_exponents:
            value_arg = op.arg.add()
            value_arg.name = MaceKeyword.mace_scalar_input_str
            value_arg.f = pow_exponents[tf_op.type]

        def is_commutative(eltwise_type):
            return EltwiseType(eltwise_type) in [
                EltwiseType.SUM, EltwiseType.PROD,
                EltwiseType.MAX, EltwiseType.MIN]

        def is_scalar_const(index):
            tf_input = tf_op.inputs[index]
            return (len(self.infer_tensor_shape(tf_input)) == 0 and
                    tf_input.op.type == TFOpType.Const.name)

        def fold_scalar_input(index):
            # Move the constant input `index` into args and drop the input.
            const_input = tf_op.inputs[index]
            scalar = const_input.eval().astype(np.float32)
            value_arg = op.arg.add()
            value_arg.name = MaceKeyword.mace_scalar_input_str
            value_arg.f = scalar
            value_index_arg = op.arg.add()
            value_index_arg.name = MaceKeyword.mace_scalar_input_index_str
            value_index_arg.i = index
            self._skip_tensor.add(const_input.name)
            del op.input[index]

        if type_arg.i != EltwiseType.NEG.value \
                and type_arg.i != EltwiseType.ABS.value:
            try:
                if len(tf_op.inputs) > 1 and is_scalar_const(1):
                    fold_scalar_input(1)
                elif is_scalar_const(0) and is_commutative(type_arg.i):
                    fold_scalar_input(0)
            except tf.errors.InvalidArgumentError:
                # Best effort: keep the constant as a regular input.
                pass
L
liutuo 已提交
597

598 599 600 601
    def convert_biasadd(self, tf_op):
        """Emit a MACE BiasAdd op; inputs are passed through unchanged."""
        bias_op = self.convert_general_op(tf_op)
        bias_op.type = MaceOp.BiasAdd.name

W
Wiktor Adamski 已提交
602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624
    def convert_one_hot(self, tf_op):
        """Convert OneHot into a MACE OneHot op.

        Inputs 1..3 (depth, on value, off value) are evaluated and stored as
        args, together with the 'axis' attr.  Those inputs are then added to
        self._skip_tensor and removed, so only input 0 remains.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.OneHot.name

        # (arg name, input index, protobuf field, numpy dtype)
        const_inputs = (
            ('depth', 1, 'i', np.int32),
            ('on_value', 2, 'f', np.float32),
            ('off_value', 3, 'f', np.float32),
        )
        for arg_name, index, field, dtype in const_inputs:
            arg = op.arg.add()
            arg.name = arg_name
            setattr(arg, field, tf_op.inputs[index].eval().astype(dtype))

        axis_arg = op.arg.add()
        axis_arg.name = tf_axis
        axis_arg.i = tf_op.get_attr(tf_axis)

        self._skip_tensor.update(inp.name for inp in list(tf_op.inputs)[1:])
        del op.input[1:]

625 626 627 628 629 630 631 632 633 634 635 636 637
    def convert_add(self, tf_op):
        """Convert Add: binary adds are element-wise, others become AddN."""
        if len(tf_op.inputs) != 2:
            add_n = self.convert_general_op(tf_op)
            add_n.type = MaceOp.AddN.name
            return
        self.convert_elementwise(tf_op)

    def convert_activation(self, tf_op):
        """Convert Relu/Relu6/Tanh/Sigmoid/LeakyRelu to a MACE Activation.

        The activation type name is stored as a byte string.  Relu6 also
        gets a max-limit arg of 6.0; LeakyRelu gets its 'alpha' attr as the
        leaky-relu coefficient.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Activation.name

        type_arg = op.arg.add()
        type_arg.name = MaceKeyword.mace_activation_type_str
        type_arg.s = six.b(self.activation_type[tf_op.type].name)

        if tf_op.type == TFOpType.Relu6.name:
            limit_arg = op.arg.add()
            limit_arg.name = MaceKeyword.mace_activation_max_limit_str
            limit_arg.f = 6.0
        elif tf_op.type == TFOpType.LeakyRelu.name:
            alpha_arg = op.arg.add()
            alpha_arg.name = \
                MaceKeyword.mace_activation_leakyrelu_coefficient_str
            alpha_arg.f = tf_op.get_attr(tf_alpha_str)
649

Y
yejianwu 已提交
650 651 652 653
    def convert_fill(self, tf_op):
        """Emit a MACE Fill op; inputs are passed through unchanged."""
        fill_op = self.convert_general_op(tf_op)
        fill_op.type = MaceOp.Fill.name

654 655
    def convert_fused_batchnorm(self, tf_op):
        """Fold an inference-mode FusedBatchNorm into a MACE BatchNorm op.

        Inputs 1..4 (gamma, beta, mean, variance) are evaluated and folded
        into per-channel tensors:
            scale  = gamma / sqrt(variance + epsilon)
            offset = beta - mean * scale
        These are added to the net as '<scope>/scale:0' and
        '<scope>/offset:0' and replace inputs 1..4 of the op.  Only the
        first output (and its shape) is kept.  Ops with is_training set are
        rejected through mace_check.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.BatchNorm.name

        is_training = tf_op.get_attr(tf_is_training_str)
        # mace_check rather than assert: asserts are stripped under -O.
        mace_check(is_training is False,
                   'Only support batch normalization with is_training '
                   'False, but got %s' % is_training)

        gamma_value = tf_op.inputs[1].eval().astype(np.float32)
        beta_value = tf_op.inputs[2].eval().astype(np.float32)
        mean_value = tf_op.inputs[3].eval().astype(np.float32)
        var_value = tf_op.inputs[4].eval().astype(np.float32)
        epsilon_value = tf_op.get_attr(tf_epsilon_str)

        scope = self.get_scope(tf_op.name)
        scale_name = scope + '/scale:0'
        offset_name = scope + '/offset:0'
        # The square root is taken in float64 (per-element math.sqrt
        # precision), vectorized, so empty arrays are also handled.
        std_value = np.sqrt((var_value + epsilon_value).astype(np.float64))
        scale_value = (1.0 / std_value) * gamma_value
        offset_value = (-mean_value * scale_value) + beta_value
        self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
                        scale_value)
        self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT,
                        offset_value)
        self._skip_tensor.update([inp.name for inp in tf_op.inputs][1:])

        del op.input[1:]
        op.input.extend([scale_name, offset_name])
        # Keep only the first output and its shape.
        del op.output[1:]
        del op.output_shape[1:]

    def convert_pooling(self, tf_op):
        """Convert AvgPool/MaxPool into a MACE Pooling op.

        Strides and kernel size keep entries [1:3] of the TF attrs
        (presumably H and W of an NHWC layout — TODO confirm for NCHW).
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Pooling.name

        pool_type = op.arg.add()
        pool_type.name = MaceKeyword.mace_pooling_type_str
        pool_type.i = self.pooling_type_mode[tf_op.type].value

        padding = op.arg.add()
        padding.name = MaceKeyword.mace_padding_str
        padding.i = self.padding_mode[tf_op.get_attr(tf_padding_str)].value

        spatial_args = (
            (MaceKeyword.mace_strides_str, tf_strides_str),
            (MaceKeyword.mace_kernel_str, tf_kernel_str),
        )
        for mace_name, tf_attr in spatial_args:
            arg = op.arg.add()
            arg.name = mace_name
            arg.ints.extend(tf_op.get_attr(tf_attr)[1:3])

    def convert_softmax(self, tf_op):
        """Emit a MACE Softmax op; inputs are passed through unchanged."""
        softmax_op = self.convert_general_op(tf_op)
        softmax_op.type = MaceOp.Softmax.name

赵奇可 已提交
706 707 708 709 710 711 712 713 714 715 716 717 718 719
    def convert_resize_bicubic(self, tf_op):
        """Convert ResizeBicubic; the constant size input becomes an arg.

        Only input 0 stays an op input; input 1 is evaluated into the
        resize-size arg and skipped when tensors are emitted.  The
        'align_corners' attr is copied as an arg.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.ResizeBicubic.name
        del op.input[1:]

        size_tensor = tf_op.inputs[1]
        size_arg = op.arg.add()
        size_arg.name = MaceKeyword.mace_resize_size_str
        size_arg.ints.extend(size_tensor.eval().astype(np.int32))
        self._skip_tensor.add(size_tensor.name)

        corners_arg = op.arg.add()
        corners_arg.name = MaceKeyword.mace_align_corners_str
        corners_arg.i = tf_op.get_attr(tf_align_corners)

720 721 722 723 724 725 726 727 728
    def convert_resize_bilinear(self, tf_op):
        """Convert ResizeBilinear; the constant size input becomes an arg.

        Only input 0 stays an op input; input 1 is evaluated into the
        resize-size arg and skipped when tensors are emitted.  The
        'align_corners' attr is copied as an arg.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.ResizeBilinear.name
        del op.input[1:]

        size_arg = op.arg.add()
        size_arg.name = MaceKeyword.mace_resize_size_str
        size_value = tf_op.inputs[1].eval().astype(np.int32)
        size_arg.ints.extend(size_value)
        self._skip_tensor.add(tf_op.inputs[1].name)
        align_corners_arg = op.arg.add()
        align_corners_arg.name = MaceKeyword.mace_align_corners_str
        align_corners_arg.i = tf_op.get_attr(tf_align_corners)

L
lichao18 已提交
734 735 736
    def convert_resize_nearest_neighbor(self, tf_op):
        """Translate a TF ResizeNearestNeighbor op.

        The TF inputs are left as produced by ``convert_general_op``; the
        ``size`` input is not folded into an argument. Only ``align_corners``
        is added as a MACE argument.
        """
        nn_op = self.convert_general_op(tf_op)
        nn_op.type = MaceOp.ResizeNearestNeighbor.name

        corners_arg = nn_op.arg.add()
        corners_arg.name = MaceKeyword.mace_align_corners_str
        corners_arg.i = tf_op.get_attr(tf_align_corners)

    def convert_space_batch(self, tf_op):
        """Translate TF SpaceToBatchND / BatchToSpaceND into MACE.

        Inputs #1 and #2 are evaluated as int32 constants and become op
        arguments:

        * input #1 is the block shape;
        * input #2 holds paddings for SpaceToBatchND or crops for
          BatchToSpaceND, flattened.

        Both tensors are added to ``self._skip_tensor``; only input #0
        remains an op input.
        """
        sb_op = self.convert_general_op(tf_op)
        del sb_op.input[1:]

        block_arg = sb_op.arg.add()
        block_arg.name = MaceKeyword.mace_space_batch_block_shape_str
        block_arg.ints.extend(tf_op.inputs[1].eval().astype(np.int32))

        crops_or_pads_arg = sb_op.arg.add()
        # NOTE(review): relies on convert_general_op having copied the TF op
        # type into sb_op.type — confirm.
        if sb_op.type != TFOpType.BatchToSpaceND.name:
            sb_op.type = MaceOp.SpaceToBatchND.name
            crops_or_pads_arg.name = MaceKeyword.mace_paddings_str
        else:
            sb_op.type = MaceOp.BatchToSpaceND.name
            crops_or_pads_arg.name = \
                MaceKeyword.mace_batch_to_space_crops_str
        flat_margins = tf_op.inputs[2].eval().astype(np.int32).flat
        crops_or_pads_arg.ints.extend(flat_margins)

        for folded_tensor in (tf_op.inputs[1], tf_op.inputs[2]):
            self._skip_tensor.add(folded_tensor.name)
764 765 766

    def convert_space_depth(self, tf_op):
        """Translate TF SpaceToDepth / DepthToSpace into MACE.

        A SpaceToDepth op maps to MACE SpaceToDepth. Any other op type routed
        here maps to DepthToSpace. The TF ``block_size`` attribute is
        forwarded as an int argument.
        """
        sd_op = self.convert_general_op(tf_op)
        to_depth = sd_op.type == TFOpType.SpaceToDepth.name
        sd_op.type = (MaceOp.SpaceToDepth.name if to_depth
                      else MaceOp.DepthToSpace.name)

        block_arg = sd_op.arg.add()
        block_arg.name = MaceKeyword.mace_space_depth_block_size_str
        block_arg.i = tf_op.get_attr(tf_block_size)

    def convert_pad(self, tf_op):
        """Translate TF Pad / PadV2 / MirrorPad into a MACE Pad op.

        Input #1 (paddings) is evaluated to int32 and flattened into the
        ``mace_paddings_str`` argument. It is removed from the op inputs and
        added to ``self._skip_tensor``.

        * Pad / PadV2: the pad type is CONSTANT. If a third input (the
          constant value) is present, it is stored as a float or int
          argument depending on its dtype; other dtypes are rejected via
          ``mace_check``.
        * MirrorPad: the pad type comes from the op's ``mode`` attribute.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Pad.name
        del op.input[1:]

        paddings_arg = op.arg.add()
        paddings_arg.name = MaceKeyword.mace_paddings_str
        paddings_value = tf_op.inputs[1].eval().astype(np.int32).flat
        paddings_arg.ints.extend(paddings_value)
        self._skip_tensor.add(tf_op.inputs[1].name)

        pad_type_arg = op.arg.add()
        pad_type_arg.name = MaceKeyword.mace_pad_type_str

        # tf_op.type is a string: compare against the enum member *names*,
        # as every other converter in this class does. Comparing against
        # the bare members only works if TFOpType happens to be a str enum.
        if tf_op.type in (TFOpType.Pad.name, TFOpType.PadV2.name):
            if len(tf_op.inputs) == 3:
                constant_value_arg = op.arg.add()
                constant_value_arg.name = MaceKeyword.mace_constant_value_str
                constant_value = tf_op.inputs[2].eval().flat[0]
                tf_dt = tf_op.inputs[2].dtype
                if tf_dt == tf.float32:
                    constant_value_arg.f = float(constant_value)
                elif tf_dt == tf.int32:
                    constant_value_arg.i = int(constant_value)
                else:
                    mace_check(False,
                               "Unsupported data type: %s" % tf_dt.name)
                self._skip_tensor.add(tf_op.inputs[2].name)

            pad_type_arg.i = PadType.CONSTANT.value

        elif tf_op.type == TFOpType.MirrorPad.name:
            pad_type_arg.i = self.pad_type[tf_op.get_attr('mode')].value
809 810 811 812 813 814 815 816 817

    def convert_concat(self, tf_op):
        """Translate a TF concat op into a MACE Concat op.

        The last TF input holds the concat axis. It is evaluated, and a
        negative value is normalised against the rank of the first output
        shape. The result is stored as the ``mace_axis_str`` argument, and
        the axis tensor is removed from the op inputs and added to
        ``self._skip_tensor``.
        """
        concat_op = self.convert_general_op(tf_op)
        concat_op.type = MaceOp.Concat.name
        del concat_op.input[-1]

        axis_tensor = tf_op.inputs[-1]
        axis_arg = concat_op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis = axis_tensor.eval().astype(np.int32)
        if axis < 0:
            # Negative axes count from the end of the output shape.
            axis += len(concat_op.output_shape[0].dims)
        axis_arg.i = axis

        self._skip_tensor.add(axis_tensor.name)
822

李寅 已提交
823 824 825 826
    def convert_matmul(self, tf_op):
        """Translate a TF matmul-family op into a MACE MatMul op.

        Each transpose flag is read from ``adj_x`` / ``adj_y`` when the op
        has that attribute, otherwise from ``transpose_a`` / ``transpose_b``.
        A flag argument is added only if one of its attribute names exists
        on the op (``get_attr`` raising ValueError means "absent").
        """
        matmul_op = self.convert_general_op(tf_op)
        matmul_op.type = MaceOp.MatMul.name

        flag_sources = (
            (MaceKeyword.mace_transpose_a_str, ('adj_x', 'transpose_a')),
            (MaceKeyword.mace_transpose_b_str, ('adj_y', 'transpose_b')),
        )
        for mace_flag_name, candidate_attrs in flag_sources:
            for attr_name in candidate_attrs:
                try:
                    flag_value = tf_op.get_attr(attr_name)
                except ValueError:
                    continue
                flag_arg = matmul_op.arg.add()
                flag_arg.name = mace_flag_name
                flag_arg.i = int(flag_value)
                break
854

855 856 857 858 859
    def convert_shape(self, tf_op):
        """Translate a TF Shape op; its output is declared as int32."""
        shape_op = self.convert_general_op(tf_op)
        shape_op.type = MaceOp.Shape.name
        shape_op.output_type.extend([mace_pb2.DT_INT32])

860 861 862 863
    def convert_reshape(self, tf_op):
        """Translate a TF Reshape op; inputs are kept as generated."""
        reshape_op = self.convert_general_op(tf_op)
        reshape_op.type = MaceOp.Reshape.name

864 865 866 867
    def convert_expand_dims(self, tf_op):
        """Translate a TF ExpandDims op into a MACE ExpandDims op.

        The dimension tensor (TF input #1) is evaluated to int32, stored as
        the ``mace_axis_str`` argument and removed from the op inputs.
        """
        expand_op = self.convert_general_op(tf_op)
        expand_op.type = MaceOp.ExpandDims.name

        new_axis = tf_op.inputs[1].eval().astype(np.int32)
        axis_arg = expand_op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = new_axis

        del expand_op.input[1]
873

874 875 876 877 878 879 880 881 882 883 884 885 886 887
    def convert_squeeze(self, tf_op):
        """Translate a TF Squeeze op into a MACE Squeeze op.

        The squeezed dimensions come from ``squeeze_dims`` or, failing that,
        from ``axis``. An empty list is used when neither attribute exists.
        """
        squeeze_op = self.convert_general_op(tf_op)
        squeeze_op.type = MaceOp.Squeeze.name

        axis_arg = squeeze_op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str

        dims_to_drop = []
        for attr_name in ('squeeze_dims', 'axis'):
            try:
                dims_to_drop = tf_op.get_attr(attr_name)
            except ValueError:
                continue
            break
        axis_arg.ints.extend(dims_to_drop)
888

李寅 已提交
889
    def convert_transpose(self, tf_op):
        """Translate a TF Transpose op into MACE Transpose or Identity.

        The permutation (TF input #1) is evaluated to int32.

        * If it is already in ascending order, the op becomes an Identity:
          the permutation input is dropped and added to
          ``self._skip_tensor``.
        * Otherwise the permutation is stored as the ``mace_dims_str``
          argument, and the input is left in place.
        """
        transpose_op = self.convert_general_op(tf_op)
        transpose_op.type = MaceOp.Transpose.name

        perm = tf_op.inputs[1].eval().astype(np.int32)
        if not np.array_equal(perm, np.sort(perm)):
            dims_arg = transpose_op.arg.add()
            dims_arg.name = MaceKeyword.mace_dims_str
            dims_arg.ints.extend(perm)
            return

        # An ascending permutation (assuming a valid 0..n-1 permutation)
        # leaves the tensor unchanged.
        transpose_op.type = MaceOp.Identity.name
        del transpose_op.input[1:]
        self._skip_tensor.add(tf_op.inputs[1].name)
李寅 已提交
904

905
    def convert_reduce(self, tf_op):
        """Translate a TF reduction op into a MACE Reduce op.

        The reduce type is looked up in ``self.reduce_math_type`` by TF op
        type. The reduction axes come from TF input #1 when present;
        otherwise they come from the ``axis`` or ``reduction_indices``
        attribute (empty if neither exists). ``keepdims`` / ``keep_dims``
        default to 0 when absent. Only input #0 remains an op input.
        """
        op = self.convert_general_op(tf_op)
        del op.input[1:]

        op.type = MaceOp.Reduce.name

        reduce_type_arg = op.arg.add()
        reduce_type_arg.name = MaceKeyword.mace_reduce_type_str
        reduce_type_arg.i = self.reduce_math_type[tf_op.type].value

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        if len(tf_op.inputs) > 1:
            reduce_dims = tf_op.inputs[1].eval()
            # The axes tensor is folded into an argument. Only skip it when
            # it actually exists: the single-input case below would
            # otherwise raise IndexError on tf_op.inputs[1].
            self._skip_tensor.add(tf_op.inputs[1].name)
        else:
            try:
                reduce_dims = tf_op.get_attr('axis')
            except ValueError:
                try:
                    reduce_dims = tf_op.get_attr('reduction_indices')
                except ValueError:
                    reduce_dims = []
        if isinstance(reduce_dims, (np.ndarray, list)):
            axis_arg.ints.extend(reduce_dims)
        else:
            # A scalar axis (e.g. a numpy scalar from a 0-d tensor).
            axis_arg.ints.append(reduce_dims)

        keep_dims_arg = op.arg.add()
        keep_dims_arg.name = MaceKeyword.mace_keepdims_str
        try:
            keep_dims = tf_op.get_attr('keepdims')
        except ValueError:
            try:
                keep_dims = tf_op.get_attr('keep_dims')
            except ValueError:
                keep_dims = 0
        keep_dims_arg.i = int(keep_dims)
943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982

    def convert_gather(self, tf_op):
        """Translate a TF gather op into a MACE Gather op.

        When the op has a third input, its evaluated value is stored as the
        ``mace_axis_str`` argument. That input is not removed from the op.
        """
        gather_op = self.convert_general_op(tf_op)
        gather_op.type = MaceOp.Gather.name

        if len(tf_op.inputs) < 3:
            return
        axis_arg = gather_op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = tf_op.inputs[2].eval()

    def convert_stridedslice(self, tf_op):
        """Translate a TF StridedSlice op into a MACE StridedSlice op.

        The five TF mask attributes are copied as int arguments under the
        same names, in the order begin, end, ellipsis, new-axis, shrink-axis.
        """
        slice_op = self.convert_general_op(tf_op)
        slice_op.type = MaceOp.StridedSlice.name

        mask_names = (
            MaceKeyword.mace_begin_mask_str,
            MaceKeyword.mace_end_mask_str,
            MaceKeyword.mace_ellipsis_mask_str,
            MaceKeyword.mace_new_axis_mask_str,
            MaceKeyword.mace_shrink_axis_mask_str,
        )
        for mask_name in mask_names:
            mask_arg = slice_op.arg.add()
            mask_arg.name = mask_name
            mask_arg.i = tf_op.get_attr(mask_name)

    def convert_slice(self, tf_op):
        """Translate a TF Slice op as a MACE StridedSlice.

        The MACE op is tagged with an int argument named ``slice`` set to 1.
        """
        slice_op = self.convert_general_op(tf_op)
        slice_op.type = MaceOp.StridedSlice.name

        slice_flag = slice_op.arg.add()
        slice_flag.name = 'slice'
        slice_flag.i = 1
986

987 988 989 990
    def convert_reverse(self, tf_op):
        """Translate a TF reverse op; inputs are kept as generated."""
        reverse_op = self.convert_general_op(tf_op)
        reverse_op.type = MaceOp.Reverse.name

991 992 993 994 995 996 997 998 999 1000
    def convert_stack(self, tf_op):
        """Translate a TF stack op into a MACE Stack op.

        The ``mace_axis_str`` attribute is forwarded as an int argument,
        defaulting to 0 when the attribute lookup raises ValueError.
        """
        stack_op = self.convert_general_op(tf_op)
        stack_op.type = MaceOp.Stack.name

        axis_arg = stack_op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        try:
            stack_axis = tf_op.get_attr(MaceKeyword.mace_axis_str)
        except ValueError:
            stack_axis = 0
        axis_arg.i = stack_axis
李寅 已提交
1001

Y
yejianwu 已提交
1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012
    def convert_unstack(self, tf_op):
        """Translate a TF unstack op into a MACE Unstack op.

        The ``mace_axis_str`` attribute is forwarded as an int argument,
        defaulting to 0 when the attribute lookup raises ValueError.
        """
        unstack_op = self.convert_general_op(tf_op)
        unstack_op.type = MaceOp.Unstack.name

        axis_arg = unstack_op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        try:
            unstack_axis = tf_op.get_attr(MaceKeyword.mace_axis_str)
        except ValueError:
            unstack_axis = 0
        axis_arg.i = unstack_axis

李寅 已提交
1013 1014 1015 1016 1017 1018 1019
    def convert_cast(self, tf_op):
        """Translate a TF Cast op into a MACE Cast op with its output dtype.

        ``DstT`` maps as follows:

        * int32 becomes DT_INT32;
        * float32 becomes DT_FLOAT;
        * any other destination type is rejected through ``mace_check``.

        Ops without a ``DstT`` attribute default to DT_FLOAT.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Cast.name

        # Keep only the attribute lookup inside the try. Previously the whole
        # dispatch was wrapped, so a ValueError from the unsupported-dtype
        # path would have been swallowed and turned into a float output.
        try:
            dtype = tf_op.get_attr('DstT')
        except ValueError:
            op.output_type.extend([mace_pb2.DT_FLOAT])
            return

        if dtype == tf.int32:
            op.output_type.extend([mace_pb2.DT_INT32])
        elif dtype == tf.float32:
            op.output_type.extend([mace_pb2.DT_FLOAT])
        else:
            mace_check(False, "data type %s not supported" % dtype)
李寅 已提交
1027 1028 1029 1030 1031

    def convert_argmax(self, tf_op):
        """Translate a TF ArgMax op; its output is declared as int32."""
        argmax_op = self.convert_general_op(tf_op)
        argmax_op.type = MaceOp.ArgMax.name
        argmax_op.output_type.extend([mace_pb2.DT_INT32])
L
liuqi 已提交
1032 1033 1034

    def convert_split(self, tf_op):
        """Translate a TF Split op into MACE Split, or Identity for one piece.

        TF input #0 holds the split axis. For more than one split, the axis
        is evaluated and a negative value is normalised against the rank of
        the first output shape. The axis and ``num_split`` are then stored as
        int arguments. In every case input #0 is removed from the op inputs
        and added to ``self._skip_tensor``.
        """
        split_op = self.convert_general_op(tf_op)
        split_count = tf_op.get_attr('num_split')

        if split_count == 1:
            # Splitting into a single piece is a pass-through.
            split_op.type = MaceOp.Identity.name
        else:
            split_op.type = MaceOp.Split.name
            split_axis = tf_op.inputs[0].eval().astype(np.int32)
            if split_axis < 0:
                split_axis += len(split_op.output_shape[0].dims)

            axis_arg = split_op.arg.add()
            axis_arg.name = MaceKeyword.mace_axis_str
            axis_arg.i = split_axis

            count_arg = split_op.arg.add()
            count_arg.name = MaceKeyword.mace_num_split_str
            count_arg.i = split_count

        del split_op.input[0]
        self._skip_tensor.add(tf_op.inputs[0].name)
李寅 已提交
1052 1053 1054 1055 1056 1057 1058

    def convert_fake_quantize(self, tf_op):
        """Attach fake-quantization parameters to the converted op.

        ``min`` / ``max`` come from one of two sources:

        * FakeQuantWithMinMaxVars: inputs #1 and #2, evaluated. Those two
          tensors are also added to ``self._skip_tensor``.
        * FakeQuantWithMinMaxArgs: the ``min`` / ``max`` attributes.

        ``narrow_range`` and ``num_bits`` are copied as int arguments. The
        op type is left as set by ``convert_general_op``.
        """
        fq_op = self.convert_general_op(tf_op)
        min_arg = fq_op.arg.add()
        min_arg.name = 'min'
        max_arg = fq_op.arg.add()
        max_arg.name = 'max'

        range_from_inputs = (
            tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name)
        if range_from_inputs:
            min_arg.f = tf_op.inputs[1].eval()
            max_arg.f = tf_op.inputs[2].eval()
        elif tf_op.type == TFOpType.FakeQuantWithMinMaxArgs.name:
            min_arg.f = float(tf_op.get_attr('min'))
            max_arg.f = float(tf_op.get_attr('max'))

        for int_attr in ('narrow_range', 'num_bits'):
            int_arg = fq_op.arg.add()
            int_arg.name = int_attr
            int_arg.i = int(tf_op.get_attr(int_attr))

        if range_from_inputs:
            for range_tensor in (tf_op.inputs[1], tf_op.inputs[2]):
                self._skip_tensor.add(range_tensor.name)
1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094

    def convert_cumsum(self, tf_op):
        """Translate a TF Cumsum op into a MACE Cumsum op.

        The axis tensor (TF input #1) is evaluated to int32, stored as the
        ``mace_axis_str`` argument and removed from the op's inputs. The
        ``exclusive`` and ``reverse`` attributes are copied as int
        arguments.
        """
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Cumsum.name

        axis = tf_op.inputs[1].eval().astype(np.int32)
        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = axis
        # The axis is now carried as an argument, so drop it from the inputs.
        del op.input[1]

        exclusive = tf_op.get_attr('exclusive')
        exclusive_arg = op.arg.add()
        exclusive_arg.name = MaceKeyword.mace_exclusive_str
        exclusive_arg.i = int(exclusive)

        reverse = tf_op.get_attr('reverse')
        reverse_arg = op.arg.add()
        reverse_arg.name = MaceKeyword.mace_reverse_str
        reverse_arg.i = int(reverse)