# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from enum import Enum

from py_proto import mace_pb2

from utils.config_parser import DataFormat
from utils.config_parser import DeviceType


# SAME_LOWER: if the amount of paddings to be added is odd,
# it will add the extra data to the right or bottom
# NOTE(review): the ONNX spec defines auto_pad=SAME_LOWER as placing the extra
# padding at the *beginning* (left/top). Confirm which side MACE's kernels use
# and correct the comment above if it is inverted.
class PaddingMode(Enum):
    """Padding scheme applied to spatial inputs."""

    VALID = 0
    SAME = 1
    FULL = 2
    SAME_LOWER = 3
    NA = 4


class PoolingType(Enum):
    """Kind of pooling: average or maximum."""

    AVG = 1  # average pooling
    MAX = 2  # max pooling


class RoundMode(Enum):
    """Rounding mode: round down (FLOOR) or round up (CEIL)."""

    FLOOR = 0
    CEIL = 1


class ActivationType(Enum):
    """Activation function kinds."""

    NOOP = 0
    RELU = 1
    RELUX = 2
    PRELU = 3
    TANH = 4
    SIGMOID = 5
    LEAKYRELU = 6
    RELU6 = 7


class EltwiseType(Enum):
    """Element-wise operation kinds."""

    SUM = 0
    SUB = 1
    PROD = 2
    DIV = 3
    MIN = 4
    MAX = 5
    NEG = 6
    ABS = 7
    SQR_DIFF = 8
    POW = 9
    EQUAL = 10
    FLOOR_DIV = 11
    CLIP = 12
    SIGN = 13


class ReduceType(Enum):
    """Reduction kinds."""

    MEAN = 0
    MIN = 1
    MAX = 2
    PROD = 3
    SUM = 4


class PadType(Enum):
    """How padded regions are filled."""

    CONSTANT = 0
    REFLECT = 1
    SYMMETRIC = 2


class FrameworkType(Enum):
    """Source framework of the model being converted."""

    TENSORFLOW = 0
    CAFFE = 1
    ONNX = 2


# Every op type the converter may emit. MaceOp below is generated from this
# list, so each entry is both a MaceOp member name and its (str) value.
MaceSupportedOps = [
    'Activation',
    'AddN',
    'ArgMax',
    'BatchNorm',
    'BatchToSpaceND',
    'BiasAdd',
    'Cast',
    'ChannelShuffle',
    'Concat',
    'Conv2D',
    'Crop',
    'Deconv2D',
    'DepthToSpace',
    'DepthwiseConv2d',
    'DepthwiseDeconv2d',
    'Dequantize',
    'Eltwise',
    'ExpandDims',
    'ExtractPooling',
    'Fill',
    'FullyConnected',
    'Gather',
    'Identity',
    'IfDefined',
    'InferConv2dShape',
    'KaldiBatchNorm',
    'LocalResponseNorm',
    'LSTMCell',
    'LstmNonlinear',
    'DynamicLSTM',
    'MatMul',
    'OneHot',
    'Pad',
    'PadContext',
    'PNorm',
    'Pooling',
    'PriorBox',
    'Proposal',
    'Quantize',
    'Reduce',
    'ReplaceIndex',
    'Reshape',
    'ResizeBicubic',
    'ResizeBilinear',
    'ResizeNearestNeighbor',
    'Reverse',
    'ScalarMath',
    'Select',
    'Slice',
    'Splice',
    'Split',
    'Shape',
    'Squeeze',
    'Stack',
    'Unstack',
    'Unsqueeze',
    'StridedSlice',
    'Softmax',
    'SpaceToBatchND',
    'SpaceToDepth',
    'SqrDiffMean',
    'Subsample',
    'SumGroup',
    'TargetRMSNorm',
    'Transpose',
    'Cumsum',
    'Tile',
    'LpNorm',
    'MVNorm',
]

# str-mixin Enum: each member compares equal to its op-type string, e.g.
# MaceOp.Conv2D == 'Conv2D'.
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)

# NOTE(review): presumably ops whose data format is fixed by the kernel, as
# opposed to MaceTransposableDataFormatOps below, which can follow the format
# of their inputs. Confirm against the transformer's data-format handling.
MaceFixedDataFormatOps = [MaceOp.BatchNorm,
                          MaceOp.BatchToSpaceND,
                          MaceOp.Conv2D,
                          MaceOp.Deconv2D,
                          MaceOp.DepthToSpace,
                          MaceOp.DepthwiseConv2d,
                          MaceOp.DepthwiseDeconv2d,
                          MaceOp.FullyConnected,
                          MaceOp.Pooling,
                          MaceOp.ResizeBicubic,
                          MaceOp.ResizeBilinear,
                          MaceOp.ResizeNearestNeighbor,
                          MaceOp.SpaceToBatchND,
                          MaceOp.SpaceToDepth,
                          MaceOp.LpNorm,
                          MaceOp.MVNorm]

MaceTransposableDataFormatOps = [MaceOp.Activation,
                                 MaceOp.AddN,
                                 MaceOp.BiasAdd,
                                 MaceOp.ChannelShuffle,
                                 MaceOp.Concat,
                                 MaceOp.Crop,
                                 MaceOp.Eltwise,
                                 MaceOp.Pad,
                                 MaceOp.Reduce,
                                 MaceOp.Reshape,
                                 MaceOp.Softmax,
                                 MaceOp.Split,
                                 MaceOp.Squeeze,
                                 MaceOp.SqrDiffMean,
                                 MaceOp.Tile]

class MaceKeyword(object):
    """String constants used as node names and op/net argument names.

    The values are runtime strings (argument names written into the model),
    so they must not be "cleaned up". Some attributes intentionally share a
    value: mace_padding_str / mace_padding_type_str are both 'padding', and
    mace_wino_block_size / mace_wino_arg_str are both 'wino_block_size'.
    """

    # node related str
    mace_input_node_name = 'mace_input_node'
    mace_output_node_name = 'mace_output_node'
    mace_buffer_type = 'buffer_type'
    # arg related str
    mace_padding_str = 'padding'
    mace_padding_type_str = 'padding'
    mace_padding_values_str = 'padding_values'
    mace_strides_str = 'strides'
    mace_dilations_str = 'dilations'
    mace_pooling_type_str = 'pooling_type'
    mace_global_pooling_str = 'global_pooling'
    mace_kernel_str = 'kernels'
    mace_data_format_str = 'data_format'
    mace_has_data_format_str = 'has_data_format'
    mace_filter_format_str = 'filter_format'
    mace_element_type_str = 'type'
    mace_activation_type_str = 'activation'
    mace_activation_max_limit_str = 'max_limit'
    mace_activation_leakyrelu_coefficient_str = 'leakyrelu_coefficient'
    mace_resize_size_str = 'size'
    mace_batch_to_space_crops_str = 'crops'
    mace_paddings_str = 'paddings'
    mace_align_corners_str = 'align_corners'
    mace_height_scale_str = 'height_scale'
    mace_width_scale_str = 'width_scale'
    mace_space_batch_block_shape_str = 'block_shape'
    mace_space_depth_block_size_str = 'block_size'
    mace_constant_value_str = 'constant_value'
    mace_dim_str = 'dim'
    mace_dims_str = 'dims'
    mace_axis_str = 'axis'
    mace_end_axis_str = 'end_axis'
    mace_num_axes_str = 'num_axes'
    mace_num_split_str = 'num_split'
    mace_keepdims_str = 'keepdims'
    mace_shape_str = 'shape'
    mace_winograd_filter_transformed = 'is_filter_transformed'
    mace_device = 'device'
    mace_scalar_input_str = 'scalar_input'
    mace_wino_block_size = 'wino_block_size'
    mace_output_shape_str = 'output_shape'
    mace_begin_mask_str = 'begin_mask'
    mace_end_mask_str = 'end_mask'
    mace_ellipsis_mask_str = 'ellipsis_mask'
    mace_new_axis_mask_str = 'new_axis_mask'
    mace_shrink_axis_mask_str = 'shrink_axis_mask'
    mace_transpose_a_str = 'transpose_a'
    mace_transpose_b_str = 'transpose_b'
    mace_op_data_type_str = 'T'
    mace_offset_str = 'offset'
    mace_opencl_max_image_size = "opencl_max_image_size"
    # NOTE(review): 'seperate' is misspelled, but it is the serialized key;
    # presumably the runtime looks up this exact string, so keep it unless
    # the runtime is updated too.
    mace_seperate_buffer_str = 'seperate_buffer'
    mace_scalar_input_index_str = 'scalar_input_index'
    mace_opencl_mem_type = "opencl_mem_type"
    mace_framework_type_str = "framework_type"
    mace_group_str = "group"
    mace_wino_arg_str = "wino_block_size"
    mace_quantize_flag_arg_str = "quantize_flag"
    mace_epsilon_str = 'epsilon'
    mace_reduce_type_str = 'reduce_type'
    mace_argmin_str = 'argmin'
    mace_out_val_str = 'out_val'
    mace_top_k_str = 'top_k'
    mace_round_mode_str = 'round_mode'
    mace_min_size_str = 'min_size'
    mace_max_size_str = 'max_size'
    mace_aspect_ratio_str = 'aspect_ratio'
    mace_flip_str = 'flip'
    mace_clip_str = 'clip'
    mace_variance_str = 'variance'
    mace_step_h_str = 'step_h'
    mace_step_w_str = 'step_w'
    mace_find_range_every_time = 'find_range_every_time'
    mace_non_zero = 'non_zero'
    mace_pad_type_str = 'pad_type'
    mace_exclusive_str = 'exclusive'
    mace_reverse_str = 'reverse'
    mace_const_data_num_arg_str = 'const_data_num'
    mace_coeff_str = 'coeff'
    mace_input_indexes_str = 'input_indexes'
    mace_output_indexes_str = 'output_indexes'
    mace_p_str = 'p'
    mace_nor_var_str = 'normalize_variance'
    mace_across_ch_str = 'across_channels'


class TransformerRule(Enum):
    """Model transformation passes.

    Member names double as the keys accepted in a user-supplied
    ConverterOption.transformer_option list (looked up via
    ``TransformerRule[name]``).
    """

    REMOVE_USELESS_OP = 1
    TRANSFORM_GLOBAL_POOLING = 2
    FOLD_RESHAPE = 3
    TRANSFORM_MATMUL_TO_FC = 4
    FOLD_BATCHNORM = 5
    FOLD_CONV_AND_BN = 6
    FOLD_DEPTHWISE_CONV_AND_BN = 7
    ADD_WINOGRAD_ARG = 8
    TRANSFORM_ADD_TO_BIASADD = 9
    FOLD_BIASADD = 10
    FLATTEN_ATROUS_CONV = 11
    FOLD_ACTIVATION = 12
    TRANSPOSE_FILTERS = 13
    RESHAPE_FC_WEIGHT = 14
    TRANSPOSE_DATA_FORMAT = 15
    TRANSFORM_GLOBAL_CONV_TO_FC = 16
    ADD_BUFFER_TRANSFORM = 17
    ADD_DEVICE = 18
    SORT_BY_EXECUTION = 19
    ADD_IN_OUT_TENSOR_INFO = 20
    ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
    UPDATE_FLOAT_OP_DATA_TYPE = 22
    QUANTIZE_NODES = 23
    ADD_QUANTIZE_TENSOR_RANGE = 24
    QUANTIZE_WEIGHTS = 25
    TRANSFORM_LSTMCELL_ZEROSTATE = 26
    TRANSFORM_BASIC_LSTMCELL = 27
    TRANSFORM_FAKE_QUANTIZE = 28
    CHECK_QUANTIZE_INFO = 29
    REARRANGE_BATCH_TO_SPACE = 30
    ADD_OPENCL_INFORMATIONS = 31
    FOLD_DECONV_AND_BN = 32
    FOLD_SQRDIFF_MEAN = 33
    TRANSPOSE_MATMUL_WEIGHT = 34
    FOLD_EMBEDDING_LOOKUP = 35
    TRANSPOSE_RESHAPE_AND_FLATTEN = 36
    FOLD_FC_RESHAPE = 37
    TRANSFORM_CHANNEL_SHUFFLE = 38
    UPDATE_DATA_FORMAT = 39
    QUANTIZE_SPECIFIC_OPS_ONLY = 40
    FP16_MATMUL_WEIGHT = 41
    FP16_GATHER_WEIGHT = 42
    QUANTIZE_LARGE_WEIGHTS = 43
    TRANSPOSE_SHAPE_TENSOR_TO_PARAM = 44


class ConverterInterface(object):
    """Common interface implemented by every external-model converter.

    Subclasses must override :meth:`run`; the base version only raises.
    """

    def run(self):
        """Perform the conversion. Not implemented in the base class."""
        raise NotImplementedError('run')


class NodeInfo(object):
    """A class for describing node information.

    Holds a node's name, data type, shape, data format and value range.
    NOTE(review): presumably these are the objects stored in
    ConverterOption.input_nodes / output_nodes / check_nodes (keyed by
    ``name``). Confirm against the callers.
    """

    def __init__(self):
        self._name = None
        self._data_type = mace_pb2.DT_FLOAT
        self._shape = []
        self._data_format = DataFormat.NHWC
        # Default [min, max] value range of the node.
        self._range = [-1.0, 1.0]

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def data_type(self):
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def shape(self):
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def data_format(self):
        return self._data_format

    @data_format.setter
    def data_format(self, data_format):
        self._data_format = data_format

    @property
    def range(self):
        return self._range

    @range.setter
    def range(self, value):
        # Parameter is not named ``range`` so it does not shadow the builtin.
        self._range = value

    def __str__(self):
        return '%s %s' % (self._name, str(self._shape))


class ConverterOption(object):
    """A class for specifying options passed to converter tool.

    Options are set through the properties below. ``build()`` then turns
    ``transformer_option`` into the ordered list of TransformerRule members
    to apply, using a default pipeline when none was given.
    """

    def __init__(self):
        self._input_nodes = {}
        self._output_nodes = {}
        self._check_nodes = {}
        self._data_type = mace_pb2.DT_FLOAT
        self._device = DeviceType.CPU.value
        self._winograd = 0
        self._quantize = False
        self._quantize_large_weights = False
        self._quantize_range_file = ""
        self._change_concat_ranges = False
        # None until build(); afterwards a list of TransformerRule members.
        self._transformer_option = None
        self._cl_mem_type = "image"
        self._quantize_stat = False

    # Node maps. Assigning a dict MERGES its values into the existing map,
    # re-keyed by node.name. Earlier entries (e.g. from add_*_node) are kept.

    @property
    def input_nodes(self):
        return self._input_nodes

    @input_nodes.setter
    def input_nodes(self, input_nodes):
        for node in input_nodes.values():
            self._input_nodes[node.name] = node

    def add_input_node(self, input_node):
        self._input_nodes[input_node.name] = input_node

    @property
    def output_nodes(self):
        return self._output_nodes

    @output_nodes.setter
    def output_nodes(self, output_nodes):
        for node in output_nodes.values():
            self._output_nodes[node.name] = node

    def add_output_node(self, output_node):
        self._output_nodes[output_node.name] = output_node

    @property
    def check_nodes(self):
        return self._check_nodes

    @check_nodes.setter
    def check_nodes(self, check_nodes):
        for node in check_nodes.values():
            self._check_nodes[node.name] = node

    def add_check_node(self, check_node):
        self._check_nodes[check_node.name] = check_node

    # Plain value options.

    @property
    def data_type(self):
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def winograd(self):
        return self._winograd

    @winograd.setter
    def winograd(self, winograd):
        self._winograd = winograd

    @property
    def quantize(self):
        return self._quantize

    @quantize.setter
    def quantize(self, quantize):
        self._quantize = quantize

    @property
    def quantize_large_weights(self):
        return self._quantize_large_weights

    @quantize_large_weights.setter
    def quantize_large_weights(self, quantize_large_weights):
        self._quantize_large_weights = quantize_large_weights

    @property
    def change_concat_ranges(self):
        return self._change_concat_ranges

    @change_concat_ranges.setter
    def change_concat_ranges(self, change_concat_ranges):
        self._change_concat_ranges = change_concat_ranges

    @property
    def quantize_range_file(self):
        return self._quantize_range_file

    @quantize_range_file.setter
    def quantize_range_file(self, quantize_range_file):
        self._quantize_range_file = quantize_range_file

    @property
    def transformer_option(self):
        return self._transformer_option

    @transformer_option.setter
    def transformer_option(self, transformer_option):
        self._transformer_option = transformer_option

    @property
    def cl_mem_type(self):
        return self._cl_mem_type

    @cl_mem_type.setter
    def cl_mem_type(self, cl_mem_type):
        self._cl_mem_type = cl_mem_type

    @property
    def quantize_stat(self):
        return self._quantize_stat

    @quantize_stat.setter
    def quantize_stat(self, quantize_stat):
        self._quantize_stat = quantize_stat

    # Transformer rule list helpers. These must be called after build():
    # before that, _transformer_option is None and the membership test
    # raises TypeError.

    def disable_transpose_filters(self):
        """Remove TRANSPOSE_FILTERS from the rule list if present."""
        if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
            self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)

    def enable_transpose_filters(self):
        """Append TRANSPOSE_FILTERS to the end of the rule list if absent."""
        if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option:
            self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS)

    def build(self):
        """Resolve ``transformer_option`` into a list of TransformerRule.

        A non-empty user list may contain rule names (str) or TransformerRule
        members; it is used as-is, in order. Accepting members makes a
        repeated build() call harmless instead of raising KeyError.
        Otherwise the default pipeline below is used. The
        QUANTIZE_LARGE_WEIGHTS and quantization rules are appended to it
        when those options are set.
        """
        if self._transformer_option:
            self._transformer_option = [
                rule if isinstance(rule, TransformerRule)
                else TransformerRule[rule]
                for rule in self._transformer_option]
        else:
            self._transformer_option = [
                # Model structure related transformation
                TransformerRule.REMOVE_USELESS_OP,
                TransformerRule.TRANSFORM_FAKE_QUANTIZE,
                TransformerRule.REMOVE_USELESS_OP,
                TransformerRule.TRANSFORM_GLOBAL_POOLING,
                TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
                TransformerRule.TRANSFORM_BASIC_LSTMCELL,
                TransformerRule.TRANSPOSE_RESHAPE_AND_FLATTEN,
                TransformerRule.TRANSPOSE_SHAPE_TENSOR_TO_PARAM,
                TransformerRule.FOLD_RESHAPE,
                TransformerRule.TRANSFORM_MATMUL_TO_FC,
                # For StoB -> conv -> BtoS -> BN pattern
                # Insert flatten_atrous_conv before fold_xxx_and_bn
                TransformerRule.FLATTEN_ATROUS_CONV,
                TransformerRule.FOLD_BATCHNORM,
                TransformerRule.FOLD_CONV_AND_BN,
                TransformerRule.FOLD_DECONV_AND_BN,
                TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
                TransformerRule.TRANSFORM_ADD_TO_BIASADD,
                TransformerRule.REARRANGE_BATCH_TO_SPACE,
                TransformerRule.FOLD_BIASADD,
                TransformerRule.FLATTEN_ATROUS_CONV,
                TransformerRule.FOLD_ACTIVATION,
                TransformerRule.FOLD_SQRDIFF_MEAN,
                TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
                TransformerRule.RESHAPE_FC_WEIGHT,
                TransformerRule.FOLD_FC_RESHAPE,
                TransformerRule.TRANSFORM_CHANNEL_SHUFFLE,
                # Model data format related transformation
                TransformerRule.TRANSPOSE_FILTERS,
                # Mace model structure related transformation
                TransformerRule.ADD_IN_OUT_TENSOR_INFO,
                TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
                # Add winograd argument
                TransformerRule.ADD_WINOGRAD_ARG,
                # Data type related transformation
                TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
                # Transform finalization
                TransformerRule.ADD_OPENCL_INFORMATIONS,
                # for quantization entropy calibration use
                TransformerRule.SORT_BY_EXECUTION,
                # update the data format of ops
                TransformerRule.UPDATE_DATA_FORMAT,
                TransformerRule.TRANSPOSE_DATA_FORMAT,
                # Need to be put after SORT_BY_EXECUTION
                TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
            ]
            if self._quantize_large_weights:
                self._transformer_option = self._transformer_option + [
                    TransformerRule.QUANTIZE_LARGE_WEIGHTS
                ]
            if self._quantize:
                self._transformer_option = self._transformer_option + [
                    # need to be put after ADD_QUANTIZE_TENSOR_RANGE
                    TransformerRule.QUANTIZE_NODES,
                    TransformerRule.QUANTIZE_WEIGHTS,
                    TransformerRule.SORT_BY_EXECUTION,
                    TransformerRule.CHECK_QUANTIZE_INFO,
                ]

class ConverterUtil(object):
    """Static helpers for objects exposing a repeated ``arg`` field."""

    @staticmethod
    def get_arg(op, arg_name):
        """Return the first entry of ``op.arg`` named ``arg_name``, or None."""
        return next((arg for arg in op.arg if arg.name == arg_name), None)

    @staticmethod
    def del_arg(op, arg_name):
        """Delete the first entry of ``op.arg`` named ``arg_name``, if any."""
        for idx, arg in enumerate(op.arg):
            if arg.name == arg_name:
                del op.arg[idx]
                # Only the first match is removed; return immediately since
                # the container being iterated was just mutated.
                return

    @staticmethod
    def add_data_format_arg(op, data_format):
        """Append a data-format arg holding ``data_format.value`` to ``op``."""
        data_format_arg = op.arg.add()
        data_format_arg.name = MaceKeyword.mace_data_format_str
        data_format_arg.i = data_format.value

    @staticmethod
    def add_data_type_arg(op, data_type):
        """Append an op data-type arg holding the integer ``data_type``."""
        data_type_arg = op.arg.add()
        data_type_arg.name = MaceKeyword.mace_op_data_type_str
        data_type_arg.i = data_type

    @staticmethod
    def _format_from_arg(arg, candidates):
        """Map ``arg.i`` to the first format in ``candidates`` with that value.

        Returns None when ``arg`` is None or no candidate matches.
        """
        if arg is None:
            return None
        for fmt in candidates:
            if arg.i == fmt.value:
                return fmt
        return None

    @staticmethod
    def data_format(op):
        """Decode the op's data-format arg (NHWC, NCHW or AUTO), else None."""
        arg = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str)
        return ConverterUtil._format_from_arg(
            arg, (DataFormat.NHWC, DataFormat.NCHW, DataFormat.AUTO))

    @staticmethod
    def set_filter_format(net, filter_format):
        """Append a filter-format arg to ``net``.

        An existing filter-format arg is not removed first.
        """
        arg = net.arg.add()
        arg.name = MaceKeyword.mace_filter_format_str
        arg.i = filter_format.value

    @staticmethod
    def filter_format(net):
        """Decode the net's filter-format arg (HWIO, HWOI or OIHW), else None."""
        arg = ConverterUtil.get_arg(net, MaceKeyword.mace_filter_format_str)
        return ConverterUtil._format_from_arg(
            arg, (DataFormat.HWIO, DataFormat.HWOI, DataFormat.OIHW))