# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from enum import Enum

from py_proto import mace_pb2

from utils.config_parser import DataFormat
from utils.config_parser import DeviceType


# SAME_LOWER: if the amount of padding to be added is odd,
# the extra element is added to the right or bottom.
class PaddingMode(Enum):
    """Enumeration of padding modes.

    NOTE(review): NA presumably means "not applicable" -- confirm with the
    converters that emit it.
    """
    VALID = 0
    SAME = 1
    FULL = 2
    SAME_LOWER = 3
    NA = 4


class PoolingType(Enum):
    """Enumeration of pooling types."""

    AVG = 1  # average pooling
    MAX = 2  # max pooling


class RoundMode(Enum):
    """Enumeration of rounding modes (round down or round up)."""
    FLOOR = 0
    CEIL = 1


class ActivationType(Enum):
    """Enumeration of activation types.

    NOTE(review): the integer values are presumably written into op args
    and read by the runtime, so existing values should not be renumbered.
    """
    NOOP = 0
    RELU = 1
    RELUX = 2
    PRELU = 3
    TANH = 4
    SIGMOID = 5
    LEAKYRELU = 6


class EltwiseType(Enum):
    """Enumeration of element-wise operation types.

    NOTE(review): the integer values are presumably serialized into op args,
    so existing values should not be renumbered.
    """
    SUM = 0
    SUB = 1
    PROD = 2
    DIV = 3
    MIN = 4
    MAX = 5
    NEG = 6
    ABS = 7
    SQR_DIFF = 8
    POW = 9
    EQUAL = 10
    FLOOR_DIV = 11
    CLIP = 12
    SIGN = 13


class ReduceType(Enum):
    """Enumeration of reduction types."""
    MEAN = 0
    MIN = 1
    MAX = 2
    PROD = 3
    SUM = 4


class PadType(Enum):
    """Enumeration of padding fill types."""
    CONSTANT = 0
    REFLECT = 1
    SYMMETRIC = 2


class FrameworkType(Enum):
    """Source frameworks whose models can be converted."""
    TENSORFLOW = 0
    CAFFE = 1
    ONNX = 2
    MEGENGINE = 3


# Names of all ops supported by MACE. MaceOp below is built from this list,
# so MaceOp's member order follows the order here.
MaceSupportedOps = [
    'Activation',
    'AddN',
    'ArgMax',
    'BatchNorm',
    'BatchToSpaceND',
    'BiasAdd',
    'Cast',
    'ChannelShuffle',
    'Concat',
    'Conv2D',
    'Crop',
    'Cumsum',
    'Deconv2D',
    'DepthToSpace',
    'DepthwiseConv2d',
    'DepthwiseDeconv2d',
    'Dequantize',
    'Eltwise',
    'ExpandDims',
    'ExtractImagePatches',
    'ExtractPooling',
    'Fill',
    'FullyConnected',
    'Gather',
    'GroupNorm',
    'Identity',
    'IfDefined',
    'InferConv2dShape',
    'KaldiBatchNorm',
    'LocalResponseNorm',
    'LpNorm',
    'LSTMCell',
    'LstmNonlinear',
    'DynamicLSTM',
    'MatMul',
    'MVNorm',
    'OneHot',
    'Pad',
    'PadContext',
    'PNorm',
    'Pooling',
    'PriorBox',
    'Proposal',
    'Quantize',
    'Reduce',
    'ReplaceIndex',
    'Reshape',
    'ResizeBicubic',
    'ResizeBilinear',
    'ResizeNearestNeighbor',
    'Reverse',
    'ScalarMath',
    'Select',
    'Slice',
    'Splice',
    'Split',
    'Shape',
    'Squeeze',
    'Stack',
    'Unstack',
    'Unsqueeze',
    'StridedSlice',
    'Softmax',
    'SpaceToBatchND',
    'SpaceToDepth',
    'SqrDiffMean',
    'Subsample',
    'SumGroup',
    'TargetRMSNorm',
    'Tile',
    'Transpose',
]

# String-valued enum with one member per supported op. Each member's name
# and value are the op name. Because the enum mixes in str, a member
# compares equal to its plain op-name string.
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)

# Ops that have a fixed data format.
# NOTE(review): consumed by the data-format transformer passes -- confirm
# the exact semantics there before adding ops.
MaceFixedDataFormatOps = [MaceOp.BatchNorm,
                          MaceOp.BatchToSpaceND,
                          MaceOp.Conv2D,
                          MaceOp.Deconv2D,
                          MaceOp.DepthToSpace,
                          MaceOp.DepthwiseConv2d,
                          MaceOp.DepthwiseDeconv2d,
                          MaceOp.FullyConnected,
                          MaceOp.Pooling,
                          MaceOp.ExtractImagePatches,
                          MaceOp.ResizeBicubic,
                          MaceOp.ResizeBilinear,
                          MaceOp.ResizeNearestNeighbor,
                          MaceOp.SpaceToBatchND,
                          MaceOp.SpaceToDepth,
                          MaceOp.LpNorm,
                          MaceOp.MVNorm,
                          MaceOp.GroupNorm]

# Ops whose data format can be transposed (disjoint from the fixed list).
MaceTransposableDataFormatOps = [MaceOp.Activation,
                                 MaceOp.AddN,
                                 MaceOp.BiasAdd,
                                 MaceOp.ChannelShuffle,
                                 MaceOp.Concat,
                                 MaceOp.Crop,
                                 MaceOp.Eltwise,
                                 MaceOp.Pad,
                                 MaceOp.Reduce,
                                 MaceOp.Reshape,
                                 MaceOp.Softmax,
                                 MaceOp.Split,
                                 MaceOp.Squeeze,
                                 MaceOp.SqrDiffMean,
                                 MaceOp.Tile]

class MaceKeyword(object):
    """String constants naming special nodes and op/net arguments.

    Several attributes intentionally share a value: mace_padding_str and
    mace_padding_type_str are both 'padding'; mace_wino_block_size and
    mace_wino_arg_str are both 'wino_block_size'. The spelling of
    'seperate_buffer' is part of the stored key and must not be "fixed".
    """
    # node related str
    mace_input_node_name = 'mace_input_node'
    mace_output_node_name = 'mace_output_node'
    mace_buffer_type = 'buffer_type'
    # arg related str
    mace_padding_str = 'padding'
    mace_padding_type_str = 'padding'
    mace_padding_values_str = 'padding_values'
    mace_strides_str = 'strides'
    mace_dilations_str = 'dilations'
    mace_pooling_type_str = 'pooling_type'
    mace_global_pooling_str = 'global_pooling'
    mace_kernel_str = 'kernels'
    mace_data_format_str = 'data_format'
    mace_has_data_format_str = 'has_data_format'
    mace_filter_format_str = 'filter_format'
    mace_element_type_str = 'type'
    mace_activation_type_str = 'activation'
    mace_activation_max_limit_str = 'max_limit'
    mace_activation_leakyrelu_coefficient_str = 'leakyrelu_coefficient'
    mace_resize_size_str = 'size'
    mace_batch_to_space_crops_str = 'crops'
    mace_paddings_str = 'paddings'
    mace_align_corners_str = 'align_corners'
    mace_height_scale_str = 'height_scale'
    mace_width_scale_str = 'width_scale'
    mace_half_pixel_centers_str = 'half_pixel_centers'
    mace_space_batch_block_shape_str = 'block_shape'
    mace_space_depth_block_size_str = 'block_size'
    mace_constant_value_str = 'constant_value'
    mace_dim_str = 'dim'
    mace_dims_str = 'dims'
    mace_axis_str = 'axis'
    mace_end_axis_str = 'end_axis'
    mace_num_axes_str = 'num_axes'
    mace_num_split_str = 'num_split'
    mace_size_splits_str = 'size_splits'
    mace_keepdims_str = 'keepdims'
    mace_shape_str = 'shape'
    mace_winograd_filter_transformed = 'is_filter_transformed'
    mace_device = 'device'
    mace_scalar_input_str = 'scalar_input'
    mace_wino_block_size = 'wino_block_size'
    mace_output_shape_str = 'output_shape'
    mace_begin_mask_str = 'begin_mask'
    mace_end_mask_str = 'end_mask'
    mace_ellipsis_mask_str = 'ellipsis_mask'
    mace_new_axis_mask_str = 'new_axis_mask'
    mace_shrink_axis_mask_str = 'shrink_axis_mask'
    mace_transpose_a_str = 'transpose_a'
    mace_transpose_b_str = 'transpose_b'
    mace_op_data_type_str = 'T'
    mace_offset_str = 'offset'
    mace_opencl_max_image_size = "opencl_max_image_size"
    mace_seperate_buffer_str = 'seperate_buffer'
    mace_scalar_input_index_str = 'scalar_input_index'
    mace_opencl_mem_type = "opencl_mem_type"
    mace_framework_type_str = "framework_type"
    mace_group_str = "group"
    mace_group_num_str = "group_num"
    mace_wino_arg_str = "wino_block_size"
    mace_quantize_flag_arg_str = "quantize_flag"
    mace_epsilon_str = 'epsilon'
    mace_reduce_type_str = 'reduce_type'
    mace_argmin_str = 'argmin'
    mace_out_val_str = 'out_val'
    mace_top_k_str = 'top_k'
    mace_round_mode_str = 'round_mode'
    # PriorBox related str
    mace_min_size_str = 'min_size'
    mace_max_size_str = 'max_size'
    mace_aspect_ratio_str = 'aspect_ratio'
    mace_flip_str = 'flip'
    mace_clip_str = 'clip'
    mace_variance_str = 'variance'
    mace_step_h_str = 'step_h'
    mace_step_w_str = 'step_w'
    mace_find_range_every_time = 'find_range_every_time'
    mace_non_zero = 'non_zero'
    mace_pad_type_str = 'pad_type'
    mace_exclusive_str = 'exclusive'
    mace_reverse_str = 'reverse'
    mace_const_data_num_arg_str = 'const_data_num'
    mace_coeff_str = 'coeff'
    mace_input_indexes_str = 'input_indexes'
    mace_output_indexes_str = 'output_indexes'
    mace_p_str = 'p'
    mace_nor_var_str = 'normalize_variance'
    mace_across_ch_str = 'across_channels'
    mace_apu_16bit_per_tensor = 'mace_apu_16bit_per_tensor'
    mace_apu_data_type_arg_str = 'apu_data_type'


class TransformerRule(Enum):
    """Names of the graph transformation passes a converter can run.

    ConverterOption.build() resolves user-supplied rule names through
    TransformerRule[name], so member names are part of the external
    interface.
    """
    REMOVE_USELESS_OP = 1
    TRANSFORM_GLOBAL_POOLING = 2
    FOLD_RESHAPE = 3
    TRANSFORM_MATMUL_TO_FC = 4
    FOLD_BATCHNORM = 5
    FOLD_CONV_AND_BN = 6
    FOLD_DEPTHWISE_CONV_AND_BN = 7
    ADD_WINOGRAD_ARG = 8
    TRANSFORM_ADD_TO_BIASADD = 9
    FOLD_BIASADD = 10
    FLATTEN_ATROUS_CONV = 11
    FOLD_ACTIVATION = 12
    TRANSPOSE_FILTERS = 13
    RESHAPE_FC_WEIGHT = 14
    TRANSPOSE_DATA_FORMAT = 15
    TRANSFORM_GLOBAL_CONV_TO_FC = 16
    ADD_BUFFER_TRANSFORM = 17
    ADD_DEVICE = 18
    SORT_BY_EXECUTION = 19
    ADD_IN_OUT_TENSOR_INFO = 20
    ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
    UPDATE_FLOAT_OP_DATA_TYPE = 22
    QUANTIZE_NODES = 23
    ADD_QUANTIZE_TENSOR_RANGE = 24
    QUANTIZE_WEIGHTS = 25
    TRANSFORM_LSTMCELL_ZEROSTATE = 26
    TRANSFORM_BASIC_LSTMCELL = 27
    TRANSFORM_FAKE_QUANTIZE = 28
    CHECK_QUANTIZE_INFO = 29
    REARRANGE_BATCH_TO_SPACE = 30
    ADD_OPENCL_INFORMATIONS = 31
    FOLD_DECONV_AND_BN = 32
    FOLD_SQRDIFF_MEAN = 33
    TRANSPOSE_MATMUL_WEIGHT = 34
    FOLD_EMBEDDING_LOOKUP = 35
    TRANSPOSE_RESHAPE_AND_FLATTEN = 36
    FOLD_FC_RESHAPE = 37
    TRANSFORM_CHANNEL_SHUFFLE = 38
    UPDATE_DATA_FORMAT = 39
    QUANTIZE_SPECIFIC_OPS_ONLY = 40
    FP16_MATMUL_WEIGHT = 41
    FP16_GATHER_WEIGHT = 42
    QUANTIZE_LARGE_WEIGHTS = 43
    TRANSPOSE_SHAPE_TENSOR_TO_PARAM = 44
    TRANSFORM_SINGLE_BN_TO_DEPTHWISE_CONV = 45
    TRANSFORM_MUL_MAX_TO_PRELU = 46
    TRANSFORM_EXPAND_DIMS_TO_RESHAPE = 47


class ConverterInterface(object):
    """Abstract base for tools that turn an external model into a MACE model.

    Concrete converters must override ``run``.
    """

    def run(self):
        """Perform the conversion; the base version always raises."""
        message = 'run'
        raise NotImplementedError(message)


class NodeInfo(object):
    """A class for describing node information.

    Holds a node's name, data type, shape, data format and value range,
    each exposed as a read/write property.
    """

    def __init__(self):
        self._name = None
        self._data_type = mace_pb2.DT_FLOAT
        self._shape = []
        self._data_format = DataFormat.NHWC
        # Default [low, high] value range.
        # NOTE(review): confirm how consumers interpret it (likely quantization).
        self._range = [-1.0, 1.0]

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def data_type(self):
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def shape(self):
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def data_format(self):
        return self._data_format

    @data_format.setter
    def data_format(self, data_format):
        self._data_format = data_format

    @property
    def range(self):
        return self._range

    @range.setter
    def range(self, value):
        # Parameter renamed from ``range`` to avoid shadowing the builtin;
        # setters are only invoked via attribute assignment, so callers are
        # unaffected.
        self._range = value

    def __str__(self):
        """Return '<name> <shape>', e.g. 'input [1, 224, 224, 3]'."""
        return '%s %s' % (self._name, str(self._shape))


class ConverterOption(object):
    """A class for specifying options passed to converter tool.

    Holds the input/output/check node descriptions, target data type and
    device, winograd and quantization switches, OpenCL memory type, and the
    transformer option. Call build() to turn the transformer option (or the
    default pipeline) into a list of TransformerRule members.
    """

    def __init__(self):
        # Node dictionaries map node.name -> node object.
        self._input_nodes = {}
        self._output_nodes = {}
        self._check_nodes = {}
        self._data_type = mace_pb2.DT_FLOAT
        self._device = DeviceType.CPU.value
        self._winograd = 0
        self._quantize = False
        self._quantize_schema = ""
        self._quantize_large_weights = False
        self._quantize_range_file = ""
        self._change_concat_ranges = False
        # None/empty -> build() uses the default pipeline; otherwise a list
        # of TransformerRule names that build() resolves to members.
        self._transformer_option = None
        self._cl_mem_type = "image"
        self._quantize_stat = False

    # ---- node collections -------------------------------------------------
    # Note: the setters MERGE the given nodes (keyed by node.name) into the
    # existing dict; they do not replace it.

    @property
    def input_nodes(self):
        return self._input_nodes

    @input_nodes.setter
    def input_nodes(self, input_nodes):
        for node in input_nodes.values():
            self._input_nodes[node.name] = node

    def add_input_node(self, input_node):
        self._input_nodes[input_node.name] = input_node

    @property
    def output_nodes(self):
        return self._output_nodes

    @output_nodes.setter
    def output_nodes(self, output_nodes):
        for node in output_nodes.values():
            self._output_nodes[node.name] = node

    def add_output_node(self, output_node):
        self._output_nodes[output_node.name] = output_node

    @property
    def check_nodes(self):
        return self._check_nodes

    @check_nodes.setter
    def check_nodes(self, check_nodes):
        for node in check_nodes.values():
            self._check_nodes[node.name] = node

    def add_check_node(self, check_node):
        self._check_nodes[check_node.name] = check_node

    # ---- scalar options -----------------------------------------------------

    @property
    def data_type(self):
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def winograd(self):
        return self._winograd

    @winograd.setter
    def winograd(self, winograd):
        self._winograd = winograd

    @property
    def quantize(self):
        return self._quantize

    @quantize.setter
    def quantize(self, quantize):
        self._quantize = quantize

    @property
    def quantize_schema(self):
        return self._quantize_schema

    @quantize_schema.setter
    def quantize_schema(self, quantize_schema):
        self._quantize_schema = quantize_schema

    @property
    def quantize_large_weights(self):
        return self._quantize_large_weights

    @quantize_large_weights.setter
    def quantize_large_weights(self, quantize_large_weights):
        self._quantize_large_weights = quantize_large_weights

    @property
    def change_concat_ranges(self):
        return self._change_concat_ranges

    @change_concat_ranges.setter
    def change_concat_ranges(self, change_concat_ranges):
        self._change_concat_ranges = change_concat_ranges

    @property
    def quantize_range_file(self):
        return self._quantize_range_file

    @quantize_range_file.setter
    def quantize_range_file(self, quantize_range_file):
        self._quantize_range_file = quantize_range_file

    @property
    def transformer_option(self):
        return self._transformer_option

    @transformer_option.setter
    def transformer_option(self, transformer_option):
        self._transformer_option = transformer_option

    @property
    def cl_mem_type(self):
        return self._cl_mem_type

    @cl_mem_type.setter
    def cl_mem_type(self, cl_mem_type):
        self._cl_mem_type = cl_mem_type

    @property
    def quantize_stat(self):
        return self._quantize_stat

    @quantize_stat.setter
    def quantize_stat(self, quantize_stat):
        self._quantize_stat = quantize_stat

    # ---- transformer pipeline -------------------------------------------------

    def disable_transpose_filters(self):
        """Remove TRANSPOSE_FILTERS from the transformer list if present.

        Expects the transformer option to already be a list (e.g. after
        build()); raises TypeError while it is still None.
        """
        if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
            self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)

    def enable_transpose_filters(self):
        """Append TRANSPOSE_FILTERS to the transformer list if missing.

        Expects the transformer option to already be a list (e.g. after
        build()); raises TypeError while it is still None.
        """
        if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option:
            self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS)

    def build(self):
        """Resolve the transformer option into TransformerRule members.

        If a non-empty transformer option was set, each entry is looked up
        by name via TransformerRule[name] (unknown names raise KeyError).
        Otherwise the default pipeline below is used, extended with the
        APU-specific, large-weight and quantization passes as configured.
        """
        if self._transformer_option:
            self._transformer_option = [
                TransformerRule[name] for name in self._transformer_option]
            return

        rules = [
            # Model structure related transformation
            TransformerRule.REMOVE_USELESS_OP,
            TransformerRule.TRANSFORM_FAKE_QUANTIZE,
            # REMOVE_USELESS_OP runs a second time here, after
            # TRANSFORM_FAKE_QUANTIZE.
            TransformerRule.REMOVE_USELESS_OP,
            TransformerRule.TRANSFORM_GLOBAL_POOLING,
            TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
            TransformerRule.TRANSFORM_BASIC_LSTMCELL,
            TransformerRule.TRANSPOSE_RESHAPE_AND_FLATTEN,
            TransformerRule.TRANSPOSE_SHAPE_TENSOR_TO_PARAM,
            TransformerRule.FOLD_RESHAPE,
            TransformerRule.TRANSFORM_MATMUL_TO_FC,
            # For StoB -> conv -> BtoS -> BN pattern
            # Insert flatten_atrous_conv before fold_xxx_and_bn
            TransformerRule.FLATTEN_ATROUS_CONV,
            TransformerRule.FOLD_BATCHNORM,
            TransformerRule.FOLD_CONV_AND_BN,
            TransformerRule.FOLD_DECONV_AND_BN,
            TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
            TransformerRule.TRANSFORM_ADD_TO_BIASADD,
            TransformerRule.REARRANGE_BATCH_TO_SPACE,
            TransformerRule.FOLD_BIASADD,
            TransformerRule.FLATTEN_ATROUS_CONV,
            TransformerRule.FOLD_ACTIVATION,
            TransformerRule.FOLD_SQRDIFF_MEAN,
            TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
            TransformerRule.RESHAPE_FC_WEIGHT,
            TransformerRule.FOLD_FC_RESHAPE,
            TransformerRule.TRANSFORM_CHANNEL_SHUFFLE,
            # Model data format related transformation
            TransformerRule.TRANSPOSE_FILTERS,
            # Mace model structure related transformation
            TransformerRule.ADD_IN_OUT_TENSOR_INFO,
            TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
            # Add winograd argument
            TransformerRule.ADD_WINOGRAD_ARG,
            # Data type related transformation
            TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
            # Transform finalization
            TransformerRule.ADD_OPENCL_INFORMATIONS,
            # for quantization entropy calibration use
            TransformerRule.SORT_BY_EXECUTION,
            # update the data format of ops
            TransformerRule.UPDATE_DATA_FORMAT,
            TransformerRule.TRANSPOSE_DATA_FORMAT,
            # Need to be put after SORT_BY_EXECUTION
            TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
        ]
        if self._device == DeviceType.APU.value:
            rules.extend([
                TransformerRule.TRANSFORM_SINGLE_BN_TO_DEPTHWISE_CONV,
                TransformerRule.TRANSFORM_MUL_MAX_TO_PRELU,
                TransformerRule.TRANSFORM_EXPAND_DIMS_TO_RESHAPE,
            ])
        if self._quantize_large_weights:
            rules.append(TransformerRule.QUANTIZE_LARGE_WEIGHTS)
        if self._quantize:
            rules.extend([
                # need to be put after ADD_QUANTIZE_TENSOR_RANGE
                TransformerRule.QUANTIZE_NODES,
                TransformerRule.QUANTIZE_WEIGHTS,
                TransformerRule.SORT_BY_EXECUTION,
                TransformerRule.CHECK_QUANTIZE_INFO,
            ])
        self._transformer_option = rules

class ConverterUtil(object):
    """Static helpers for reading and writing named args on ops and nets."""

    @staticmethod
    def get_arg(op, arg_name):
        """Return the first entry of ``op.arg`` named ``arg_name``, or None."""
        for arg in op.arg:
            if arg.name == arg_name:
                return arg
        return None

    @staticmethod
    def del_arg(op, arg_name):
        """Delete the first entry of ``op.arg`` named ``arg_name``, if any."""
        for idx, arg in enumerate(op.arg):
            if arg.name == arg_name:
                # Safe despite mutating during iteration: we stop right away.
                del op.arg[idx]
                break

    @staticmethod
    def add_data_format_arg(op, data_format):
        """Append a data-format arg storing ``data_format.value``."""
        data_format_arg = op.arg.add()
        data_format_arg.name = MaceKeyword.mace_data_format_str
        data_format_arg.i = data_format.value

    @staticmethod
    def add_data_type_arg(op, data_type):
        """Append the op data-type arg ('T') storing ``data_type``."""
        data_type_arg = op.arg.add()
        data_type_arg.name = MaceKeyword.mace_op_data_type_str
        data_type_arg.i = data_type

    @staticmethod
    def data_format(op):
        """Return the op's DataFormat (NHWC, NCHW or AUTO).

        Returns None when the arg is absent or holds any other value.
        """
        arg = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str)
        if arg is None:
            return None
        for fmt in (DataFormat.NHWC, DataFormat.NCHW, DataFormat.AUTO):
            if arg.i == fmt.value:
                return fmt
        return None

    @staticmethod
    def set_filter_format(net, filter_format):
        """Append a filter-format arg to ``net`` storing its enum value."""
        arg = net.arg.add()
        arg.name = MaceKeyword.mace_filter_format_str
        arg.i = filter_format.value

    @staticmethod
    def filter_format(net):
        """Return the net's filter DataFormat (HWIO, HWOI or OIHW).

        Returns None when the arg is absent or holds any other value.
        """
        arg = ConverterUtil.get_arg(net, MaceKeyword.mace_filter_format_str)
        if arg is None:
            return None
        for fmt in (DataFormat.HWIO, DataFormat.HWOI, DataFormat.OIHW):
            if arg.i == fmt.value:
                return fmt
        return None