# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


16 17
from enum import Enum

L
liyin 已提交
18 19 20 21
from py_proto import mace_pb2

from utils.config_parser import DataFormat
from utils.config_parser import DeviceType
22 23


L
liutuo 已提交
24 25
# SAME_LOWER: if the amount of paddings to be added is odd,
# it will add the extra data to the right or bottom
class PaddingMode(Enum):
    """Enumeration of padding modes, each mapped to an integer code."""
    VALID = 0
    SAME = 1
    FULL = 2
    SAME_LOWER = 3
    NA = 4
32 33 34 35 36 37 38


class PoolingType(Enum):
    """Enumeration of pooling kinds; note the codes start at 1, not 0."""
    AVG = 1
    MAX = 2


L
liutuo 已提交
39 40 41 42 43
class RoundMode(Enum):
    """Enumeration of rounding modes (floor or ceil) as integer codes."""
    FLOOR = 0
    CEIL = 1


44 45 46 47 48 49 50
class ActivationType(Enum):
    """Enumeration of activation kinds, each mapped to an integer code."""
    NOOP = 0
    RELU = 1
    RELUX = 2
    PRELU = 3
    TANH = 4
    SIGMOID = 5
    LEAKYRELU = 6
52 53 54 55 56 57 58 59 60 61 62 63 64


class EltwiseType(Enum):
    """Enumeration of element-wise operation kinds as integer codes."""
    SUM = 0
    SUB = 1
    PROD = 2
    DIV = 3
    MIN = 4
    MAX = 5
    NEG = 6
    ABS = 7
    SQR_DIFF = 8
    POW = 9
    EQUAL = 10
    FLOOR_DIV = 11
    CLIP = 12
    SIGN = 13
69 70


L
liutuo 已提交
71 72 73 74 75
class ReduceType(Enum):
    """Enumeration of reduction kinds, each mapped to an integer code."""
    MEAN = 0
    MIN = 1
    MAX = 2
    PROD = 3
    SUM = 4
L
liutuo 已提交
77 78


79 80 81 82 83 84
class PadType(Enum):
    """Enumeration of pad kinds, each mapped to an integer code."""
    CONSTANT = 0
    REFLECT = 1
    SYMMETRIC = 2


L
liutuo 已提交
85 86 87
class FrameworkType(Enum):
    """Enumeration of source framework identifiers as integer codes."""
    TENSORFLOW = 0
    CAFFE = 1
    ONNX = 2
    MEGENGINE = 3
L
liutuo 已提交
90 91


92 93 94
# Names of the ops known to the converter. The order of this list is the
# member order of the MaceOp enum built from it below.
MaceSupportedOps = [
    'Activation',
    'AddN',
    'ArgMax',
    'BatchNorm',
    'BatchToSpaceND',
    'BiasAdd',
    'Cast',
    'ChannelShuffle',
    'Concat',
    'Conv2D',
    'Crop',
    'Cumsum',
    'Deconv2D',
    'DepthToSpace',
    'DepthwiseConv2d',
    'DepthwiseDeconv2d',
    'Dequantize',
    'Eltwise',
    'ExpandDims',
    'ExtractPooling',
    'Fill',
    'FullyConnected',
    'Gather',
    'GroupNorm',
    'Identity',
    'IfDefined',
    'InferConv2dShape',
    'KaldiBatchNorm',
    'LocalResponseNorm',
    'LpNorm',
    'LSTMCell',
    'LstmNonlinear',
    'DynamicLSTM',
    'MatMul',
    'MVNorm',
    'OneHot',
    'Pad',
    'PadContext',
    'PNorm',
    'Pooling',
    'PriorBox',
    'Proposal',
    'Quantize',
    'Reduce',
    'ReplaceIndex',
    'Reshape',
    'ResizeBicubic',
    'ResizeBilinear',
    'ResizeNearestNeighbor',
    'Reverse',
    'ScalarMath',
    'Select',
    'Slice',
    'Splice',
    'Split',
    'Shape',
    'Squeeze',
    'Stack',
    'Unstack',
    'Unsqueeze',
    'StridedSlice',
    'Softmax',
    'SpaceToBatchND',
    'SpaceToDepth',
    'SqrDiffMean',
    'Subsample',
    'SumGroup',
    'TargetRMSNorm',
    'Tile',
    'Transpose',
]

# String-valued enum: every member's name and value are both the op name,
# and because of type=str a member compares equal to that plain string.
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)

167 168 169 170 171 172 173 174 175 176 177 178 179
# Ops grouped under the "fixed data format" label.
# NOTE(review): the consumers of this list are outside this file; confirm the
# exact meaning of "fixed" against the transformer that reads it.
MaceFixedDataFormatOps = [MaceOp.BatchNorm,
                          MaceOp.BatchToSpaceND,
                          MaceOp.Conv2D,
                          MaceOp.Deconv2D,
                          MaceOp.DepthToSpace,
                          MaceOp.DepthwiseConv2d,
                          MaceOp.DepthwiseDeconv2d,
                          MaceOp.FullyConnected,
                          MaceOp.Pooling,
                          MaceOp.ResizeBicubic,
                          MaceOp.ResizeBilinear,
                          MaceOp.ResizeNearestNeighbor,
                          MaceOp.SpaceToBatchND,
                          MaceOp.SpaceToDepth,
                          MaceOp.LpNorm,
                          MaceOp.MVNorm,
                          MaceOp.GroupNorm]
184 185 186 187 188 189 190 191 192 193

# Ops grouped under the "transposable data format" label.
# NOTE(review): the consumers of this list are outside this file; confirm the
# exact meaning of "transposable" against the transformer that reads it.
MaceTransposableDataFormatOps = [MaceOp.Activation,
                                 MaceOp.AddN,
                                 MaceOp.BiasAdd,
                                 MaceOp.ChannelShuffle,
                                 MaceOp.Concat,
                                 MaceOp.Crop,
                                 MaceOp.Eltwise,
                                 MaceOp.Pad,
                                 MaceOp.Reduce,
                                 MaceOp.Reshape,
                                 MaceOp.Softmax,
                                 MaceOp.Split,
                                 MaceOp.Squeeze,
                                 MaceOp.SqrDiffMean,
                                 MaceOp.Tile]
200

201 202 203 204 205 206 207 208

class MaceKeyword(object):
    """Namespace of string constants used as node and argument names.

    The values are plain strings; several constants intentionally or
    historically share the same value (e.g. both padding constants are
    'padding', both winograd block-size constants are 'wino_block_size').
    """
    # node related str
    mace_input_node_name = 'mace_input_node'
    mace_output_node_name = 'mace_output_node'
    mace_buffer_type = 'buffer_type'
    # arg related str
    mace_padding_str = 'padding'
    mace_padding_type_str = 'padding'
    mace_padding_values_str = 'padding_values'
    mace_strides_str = 'strides'
    mace_dilations_str = 'dilations'
    mace_pooling_type_str = 'pooling_type'
    mace_global_pooling_str = 'global_pooling'
    mace_kernel_str = 'kernels'
    mace_data_format_str = 'data_format'
    mace_has_data_format_str = 'has_data_format'
    mace_filter_format_str = 'filter_format'
    mace_element_type_str = 'type'
    mace_activation_type_str = 'activation'
    mace_activation_max_limit_str = 'max_limit'
    mace_activation_leakyrelu_coefficient_str = 'leakyrelu_coefficient'
    mace_resize_size_str = 'size'
    mace_batch_to_space_crops_str = 'crops'
    mace_paddings_str = 'paddings'
    mace_align_corners_str = 'align_corners'
    mace_height_scale_str = 'height_scale'
    mace_width_scale_str = 'width_scale'
    mace_half_pixel_centers_str = 'half_pixel_centers'
    mace_space_batch_block_shape_str = 'block_shape'
    mace_space_depth_block_size_str = 'block_size'
    mace_constant_value_str = 'constant_value'
    mace_dim_str = 'dim'
    mace_dims_str = 'dims'
    mace_axis_str = 'axis'
    mace_end_axis_str = 'end_axis'
    mace_num_axes_str = 'num_axes'
    mace_num_split_str = 'num_split'
    mace_keepdims_str = 'keepdims'
    mace_shape_str = 'shape'
    mace_winograd_filter_transformed = 'is_filter_transformed'
    mace_device = 'device'
    mace_scalar_input_str = 'scalar_input'
    mace_wino_block_size = 'wino_block_size'
    mace_output_shape_str = 'output_shape'
    mace_begin_mask_str = 'begin_mask'
    mace_end_mask_str = 'end_mask'
    mace_ellipsis_mask_str = 'ellipsis_mask'
    mace_new_axis_mask_str = 'new_axis_mask'
    mace_shrink_axis_mask_str = 'shrink_axis_mask'
    mace_transpose_a_str = 'transpose_a'
    mace_transpose_b_str = 'transpose_b'
    mace_op_data_type_str = 'T'
    mace_offset_str = 'offset'
    mace_opencl_max_image_size = "opencl_max_image_size"
    # The misspelling below is the actual runtime key; do not "fix" it here
    mace_seperate_buffer_str = 'seperate_buffer'
    mace_scalar_input_index_str = 'scalar_input_index'
    mace_opencl_mem_type = "opencl_mem_type"
    mace_framework_type_str = "framework_type"
    mace_group_str = "group"
    mace_group_num_str = "group_num"
    mace_wino_arg_str = "wino_block_size"
    mace_quantize_flag_arg_str = "quantize_flag"
    mace_epsilon_str = 'epsilon'
    mace_reduce_type_str = 'reduce_type'
    mace_argmin_str = 'argmin'
    mace_out_val_str = 'out_val'
    mace_top_k_str = 'top_k'
    mace_round_mode_str = 'round_mode'
    mace_min_size_str = 'min_size'
    mace_max_size_str = 'max_size'
    mace_aspect_ratio_str = 'aspect_ratio'
    mace_flip_str = 'flip'
    mace_clip_str = 'clip'
    mace_variance_str = 'variance'
    mace_step_h_str = 'step_h'
    mace_step_w_str = 'step_w'
    mace_find_range_every_time = 'find_range_every_time'
    mace_non_zero = 'non_zero'
    mace_pad_type_str = 'pad_type'
    mace_exclusive_str = 'exclusive'
    mace_reverse_str = 'reverse'
    mace_const_data_num_arg_str = 'const_data_num'
    mace_coeff_str = 'coeff'
    mace_input_indexes_str = 'input_indexes'
    mace_output_indexes_str = 'output_indexes'
    mace_p_str = 'p'
    mace_nor_var_str = 'normalize_variance'
    mace_across_ch_str = 'across_channels'
290 291 292


class TransformerRule(Enum):
    """Enumeration of transformer rule identifiers.

    Members are looked up by name (``TransformerRule[name]``) when a
    user-supplied rule list is resolved, so member names are part of the
    external contract. Values are unique integers 1..44.
    """
    REMOVE_USELESS_OP = 1
    TRANSFORM_GLOBAL_POOLING = 2
    FOLD_RESHAPE = 3
    TRANSFORM_MATMUL_TO_FC = 4
    FOLD_BATCHNORM = 5
    FOLD_CONV_AND_BN = 6
    FOLD_DEPTHWISE_CONV_AND_BN = 7
    ADD_WINOGRAD_ARG = 8
    TRANSFORM_ADD_TO_BIASADD = 9
    FOLD_BIASADD = 10
    FLATTEN_ATROUS_CONV = 11
    FOLD_ACTIVATION = 12
    TRANSPOSE_FILTERS = 13
    RESHAPE_FC_WEIGHT = 14
    TRANSPOSE_DATA_FORMAT = 15
    TRANSFORM_GLOBAL_CONV_TO_FC = 16
    ADD_BUFFER_TRANSFORM = 17
    ADD_DEVICE = 18
    SORT_BY_EXECUTION = 19
    ADD_IN_OUT_TENSOR_INFO = 20
    ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
    UPDATE_FLOAT_OP_DATA_TYPE = 22
    QUANTIZE_NODES = 23
    ADD_QUANTIZE_TENSOR_RANGE = 24
    QUANTIZE_WEIGHTS = 25
    TRANSFORM_LSTMCELL_ZEROSTATE = 26
    TRANSFORM_BASIC_LSTMCELL = 27
    TRANSFORM_FAKE_QUANTIZE = 28
    CHECK_QUANTIZE_INFO = 29
    REARRANGE_BATCH_TO_SPACE = 30
    ADD_OPENCL_INFORMATIONS = 31
    FOLD_DECONV_AND_BN = 32
    FOLD_SQRDIFF_MEAN = 33
    TRANSPOSE_MATMUL_WEIGHT = 34
    FOLD_EMBEDDING_LOOKUP = 35
    TRANSPOSE_RESHAPE_AND_FLATTEN = 36
    FOLD_FC_RESHAPE = 37
    TRANSFORM_CHANNEL_SHUFFLE = 38
    UPDATE_DATA_FORMAT = 39
    QUANTIZE_SPECIFIC_OPS_ONLY = 40
    FP16_MATMUL_WEIGHT = 41
    FP16_GATHER_WEIGHT = 42
    QUANTIZE_LARGE_WEIGHTS = 43
    TRANSPOSE_SHAPE_TENSOR_TO_PARAM = 44
337 338 339 340 341 342 343 344 345 346 347 348 349 350


class ConverterInterface(object):
    """Base class for converting external models to mace models."""

    def run(self):
        """Unimplemented in this base class.

        Raises:
            NotImplementedError: always, with the message 'run'.
        """
        raise NotImplementedError('run')


class NodeInfo(object):
    """A class for describing node information.

    Holds a node's name, data type, shape, data format and value range,
    each exposed as a read/write property. A fresh instance has no name,
    DT_FLOAT data type, an empty shape, NHWC data format and the range
    [-1.0, 1.0].
    """

    def __init__(self):
        self._name = None
        self._data_type = mace_pb2.DT_FLOAT
        self._shape = []
        self._data_format = DataFormat.NHWC
        self._range = [-1.0, 1.0]

    @property
    def name(self):
        """Node name (None until assigned)."""
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def data_type(self):
        """Data type code of the node."""
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def shape(self):
        """Shape of the node; stored by reference, not copied."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def data_format(self):
        """Data format of the node."""
        return self._data_format

    @data_format.setter
    def data_format(self, data_format):
        self._data_format = data_format

    @property
    def range(self):
        """Value range of the node; stored by reference, not copied."""
        return self._range

    @range.setter
    def range(self, value_range):
        # Parameter renamed so it does not shadow the builtin ``range``;
        # property setters are only ever invoked through assignment.
        self._range = value_range

    def __str__(self):
        """Return '<name> <shape>' for logging/debugging."""
        return '%s %s' % (self._name, str(self._shape))


class ConverterOption(object):
    """A class for specifying options passed to converter tool.

    Each option lives in a private attribute and is exposed through a
    read/write property. ``transformer_option`` starts as ``None``; it
    becomes a list of ``TransformerRule`` members once ``build()`` runs.
    ``enable_transpose_filters()`` / ``disable_transpose_filters()`` do a
    membership test on that list, so they raise ``TypeError`` while it is
    still ``None``.
    """

    def __init__(self):
        self._input_nodes = {}
        self._output_nodes = {}
        self._check_nodes = {}
        self._data_type = mace_pb2.DT_FLOAT
        self._device = DeviceType.CPU.value
        self._winograd = 0
        self._quantize = False
        self._quantize_large_weights = False
        self._quantize_range_file = ""
        self._change_concat_ranges = False
        self._transformer_option = None
        self._cl_mem_type = "image"
        self._quantize_stat = False

    # ---- node collections -------------------------------------------------

    @property
    def input_nodes(self):
        """Mapping of node name to input node."""
        return self._input_nodes

    @input_nodes.setter
    def input_nodes(self, input_nodes):
        # Merges into the existing mapping (does not replace it). Entries
        # are re-keyed by each node's own ``name``; incoming keys are unused.
        for node in input_nodes.values():
            self._input_nodes[node.name] = node

    def add_input_node(self, input_node):
        """Register a single input node under its ``name``."""
        self._input_nodes[input_node.name] = input_node

    @property
    def output_nodes(self):
        """Mapping of node name to output node."""
        return self._output_nodes

    @output_nodes.setter
    def output_nodes(self, output_nodes):
        # Same merge-by-name semantics as the input_nodes setter
        for node in output_nodes.values():
            self._output_nodes[node.name] = node

    def add_output_node(self, output_node):
        """Register a single output node under its ``name``."""
        self._output_nodes[output_node.name] = output_node

    @property
    def check_nodes(self):
        """Mapping of node name to check node."""
        return self._check_nodes

    @check_nodes.setter
    def check_nodes(self, check_nodes):
        # Same merge-by-name semantics as the input_nodes setter
        for node in check_nodes.values():
            self._check_nodes[node.name] = node

    def add_check_node(self, check_node):
        """Register a single check node under its ``name``."""
        self._check_nodes[check_node.name] = check_node

    # ---- scalar options ---------------------------------------------------

    @property
    def data_type(self):
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def winograd(self):
        return self._winograd

    @winograd.setter
    def winograd(self, winograd):
        self._winograd = winograd

    @property
    def quantize(self):
        return self._quantize

    @quantize.setter
    def quantize(self, quantize):
        self._quantize = quantize

    @property
    def quantize_large_weights(self):
        return self._quantize_large_weights

    @quantize_large_weights.setter
    def quantize_large_weights(self, quantize_large_weights):
        self._quantize_large_weights = quantize_large_weights

    @property
    def change_concat_ranges(self):
        return self._change_concat_ranges

    @change_concat_ranges.setter
    def change_concat_ranges(self, change_concat_ranges):
        self._change_concat_ranges = change_concat_ranges

    @property
    def quantize_range_file(self):
        return self._quantize_range_file

    @quantize_range_file.setter
    def quantize_range_file(self, quantize_range_file):
        self._quantize_range_file = quantize_range_file

    @property
    def transformer_option(self):
        return self._transformer_option

    @transformer_option.setter
    def transformer_option(self, transformer_option):
        self._transformer_option = transformer_option

    @property
    def cl_mem_type(self):
        return self._cl_mem_type

    @cl_mem_type.setter
    def cl_mem_type(self, cl_mem_type):
        self._cl_mem_type = cl_mem_type

    @property
    def quantize_stat(self):
        return self._quantize_stat

    @quantize_stat.setter
    def quantize_stat(self, quantize_stat):
        self._quantize_stat = quantize_stat

    # ---- transformer rule list --------------------------------------------

    def disable_transpose_filters(self):
        """Remove TRANSPOSE_FILTERS from the rule list if it is present.

        Raises:
            TypeError: if the rule list is still ``None``.
        """
        if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
            self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)

    def enable_transpose_filters(self):
        """Append TRANSPOSE_FILTERS to the rule list unless already present.

        Raises:
            TypeError: if the rule list is still ``None``.
        """
        if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option:
            self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS)

    def build(self):
        """Resolve the transformer rule list.

        If a non-empty ``transformer_option`` was supplied, each entry is
        treated as a rule *name* and mapped to its ``TransformerRule``
        member; the quantize flags are not consulted in that case.
        Otherwise the default ordered pipeline is installed, extended with
        the quantization rules selected by ``quantize_large_weights`` and
        ``quantize``.

        Raises:
            KeyError: if a supplied name is not a ``TransformerRule`` member.
        """
        if self._transformer_option:
            self._transformer_option = [TransformerRule[transformer]
                                        for transformer in self._transformer_option]  # noqa
        else:
            rules = [
                # Model structure related transformation
                TransformerRule.REMOVE_USELESS_OP,
                TransformerRule.TRANSFORM_FAKE_QUANTIZE,
                TransformerRule.TRANSFORM_GLOBAL_POOLING,
                TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
                TransformerRule.TRANSFORM_BASIC_LSTMCELL,
                TransformerRule.TRANSPOSE_RESHAPE_AND_FLATTEN,
                TransformerRule.TRANSPOSE_SHAPE_TENSOR_TO_PARAM,
                TransformerRule.FOLD_RESHAPE,
                TransformerRule.TRANSFORM_MATMUL_TO_FC,
                # For StoB -> conv -> BtoS -> BN pattern
                # Insert flatten_atrous_conv before fold_xxx_and_bn
                TransformerRule.FLATTEN_ATROUS_CONV,
                TransformerRule.FOLD_BATCHNORM,
                TransformerRule.FOLD_CONV_AND_BN,
                TransformerRule.FOLD_DECONV_AND_BN,
                TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
                TransformerRule.TRANSFORM_ADD_TO_BIASADD,
                TransformerRule.REARRANGE_BATCH_TO_SPACE,
                TransformerRule.FOLD_BIASADD,
                # FLATTEN_ATROUS_CONV deliberately appears a second time
                TransformerRule.FLATTEN_ATROUS_CONV,
                TransformerRule.FOLD_ACTIVATION,
                TransformerRule.FOLD_SQRDIFF_MEAN,
                TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
                TransformerRule.RESHAPE_FC_WEIGHT,
                TransformerRule.FOLD_FC_RESHAPE,
                TransformerRule.TRANSFORM_CHANNEL_SHUFFLE,
                # Model data format related transformation
                TransformerRule.TRANSPOSE_FILTERS,
                # Mace model structure related transformation
                TransformerRule.ADD_IN_OUT_TENSOR_INFO,
                TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
                # Add winograd argument
                TransformerRule.ADD_WINOGRAD_ARG,
                # Data type related transformation
                TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
                # Transform finalization
                TransformerRule.ADD_OPENCL_INFORMATIONS,
                # for quantization entropy calibration use
                TransformerRule.SORT_BY_EXECUTION,
                # update the data format of ops
                TransformerRule.UPDATE_DATA_FORMAT,
                TransformerRule.TRANSPOSE_DATA_FORMAT,
                # Need to be put after SORT_BY_EXECUTION
                TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
            ]
            if self._quantize_large_weights:
                rules.append(TransformerRule.QUANTIZE_LARGE_WEIGHTS)
            if self._quantize:
                rules.extend([
                    # need to be put after ADD_QUANTIZE_TENSOR_RANGE
                    TransformerRule.QUANTIZE_NODES,
                    TransformerRule.QUANTIZE_WEIGHTS,
                    TransformerRule.SORT_BY_EXECUTION,
                    TransformerRule.CHECK_QUANTIZE_INFO,
                ])
            self._transformer_option = rules

607 608 609 610 611 612 613 614 615

class ConverterUtil(object):
    """Static helpers for reading and writing named args on ops and nets.

    The ``op`` / ``net`` objects are expected to expose an ``arg``
    collection whose items have a ``name`` attribute.
    NOTE(review): ``arg.add()`` is used by the add/set helpers, which looks
    like a protobuf repeated field - confirm against mace_pb2.
    """

    @staticmethod
    def get_arg(op, arg_name):
        """Return the first arg of ``op`` named ``arg_name``, or None."""
        for arg in op.arg:
            if arg.name == arg_name:
                return arg
        return None

    @staticmethod
    def del_arg(op, arg_name):
        """Delete the first arg of ``op`` named ``arg_name``, if any."""
        for idx, arg in enumerate(op.arg):
            if arg.name == arg_name:
                # Stop right after deleting: only the first match is
                # removed and the collection is not iterated further.
                del op.arg[idx]
                break

    @staticmethod
    def add_data_format_arg(op, data_format):
        """Add a data-format arg to ``op`` holding ``data_format.value``."""
        data_format_arg = op.arg.add()
        data_format_arg.name = MaceKeyword.mace_data_format_str
        data_format_arg.i = data_format.value

    @staticmethod
    def add_data_type_arg(op, data_type):
        """Add a data-type arg to ``op`` holding ``data_type``."""
        data_type_arg = op.arg.add()
        data_type_arg.name = MaceKeyword.mace_op_data_type_str
        data_type_arg.i = data_type

    @staticmethod
    def data_format(op):
        """Return the DataFormat stored in ``op``'s data-format arg.

        Returns None when the arg is missing or its integer value is not
        one of NHWC, NCHW or AUTO.
        """
        arg = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str)
        if arg is None:
            return None
        for fmt in (DataFormat.NHWC, DataFormat.NCHW, DataFormat.AUTO):
            if arg.i == fmt.value:
                return fmt
        return None

    @staticmethod
    def set_filter_format(net, filter_format):
        """Add a filter-format arg to ``net`` holding ``filter_format.value``."""
        arg = net.arg.add()
        arg.name = MaceKeyword.mace_filter_format_str
        arg.i = filter_format.value

    @staticmethod
    def filter_format(net):
        """Return the DataFormat stored in ``net``'s filter-format arg.

        Returns None when the arg is missing or its integer value is not
        one of HWIO, HWOI or OIHW.
        """
        arg = ConverterUtil.get_arg(net, MaceKeyword.mace_filter_format_str)
        if arg is None:
            return None
        for fmt in (DataFormat.HWIO, DataFormat.HWOI, DataFormat.OIHW):
            if arg.i == fmt.value:
                return fmt
        return None