# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from enum import Enum

from py_proto import mace_pb2

from utils.config_parser import DataFormat
from utils.config_parser import DeviceType


# SAME_LOWER: if the amount of paddings to be added is odd,
# it will add the extra data to the right or bottom
# NOTE(review): ONNX defines auto_pad=SAME_LOWER as placing the odd extra
# padding at the *beginning* (left/top); confirm which side the MACE kernels
# actually use before relying on the sentence above.
class PaddingMode(Enum):
    """Padding modes understood by the converter."""

    VALID = 0
    SAME = 1
    FULL = 2
    SAME_LOWER = 3
    # presumably "not applicable" (explicit paddings given instead) — confirm
    NA = 4


class PoolingType(Enum):
    """Pooling reduction kinds (average or max)."""

    # Note: values start at 1, not 0.
    AVG = 1
    MAX = 2


class RoundMode(Enum):
    """Rounding direction (floor or ceil)."""

    # Round toward negative infinity.
    FLOOR = 0
    # Round toward positive infinity.
    CEIL = 1


class ActivationType(Enum):
    """Activation functions that can be attached to an op."""

    NOOP = 0
    RELU = 1
    # presumably ReLU clipped at the 'max_limit' argument — confirm
    RELUX = 2
    PRELU = 3
    TANH = 4
    SIGMOID = 5
    # presumably parameterized by the 'leakyrelu_coefficient' argument
    LEAKYRELU = 6


class EltwiseType(Enum):
    """Element-wise operation kinds of the Eltwise op.

    Values are contiguous from 0 and must stay stable.
    """

    SUM = 0
    SUB = 1
    PROD = 2
    DIV = 3
    MIN = 4
    MAX = 5
    NEG = 6
    ABS = 7
    SQR_DIFF = 8
    POW = 9
    EQUAL = 10
    FLOOR_DIV = 11
    CLIP = 12


class ReduceType(Enum):
    """Reduction kinds of the Reduce op."""

    MEAN = 0
    MIN = 1
    MAX = 2
    PROD = 3
    SUM = 4


class PadType(Enum):
    """How a Pad op fills the padded border."""

    # Fill with a constant value.
    CONSTANT = 0
    # Mirror the input across the border.
    REFLECT = 1
    # Mirror the input including the edge element.
    SYMMETRIC = 2


class FrameworkType(Enum):
    """Source framework of the model being converted (presumably stored in
    the 'framework_type' argument — confirm with the converters)."""

    TENSORFLOW = 0
    CAFFE = 1
    ONNX = 2


# Names of all op types MACE supports. The order is significant: it fixes
# the member order of the MaceOp enum built from it below.
MaceSupportedOps = [
    'Activation',
    'AddN',
    'ArgMax',
    'BatchNorm',
    'BatchToSpaceND',
    'BiasAdd',
    'Cast',
    'ChannelShuffle',
    'Concat',
    'Conv2D',
    'Crop',
    'Deconv2D',
    'Delay',
    'DepthToSpace',
    'DepthwiseConv2d',
    'DepthwiseDeconv2d',
    'Dequantize',
    'Eltwise',
    'ExpandDims',
    'ExtractPooling',
    'Fill',
    'FullyConnected',
    'Gather',
    'Identity',
    'InferConv2dShape',
    'KaldiBatchNorm',
    'LocalResponseNorm',
    'LSTMCell',
    'LstmNonlinear',
    'DynamicLSTM',
    'MatMul',
    'OneHot',
    'Pad',
    'PadContext',
    'PNorm',
    'Pooling',
    'PriorBox',
    'Proposal',
    'Quantize',
    'Reduce',
    'Reshape',
    'ResizeBicubic',
    'ResizeBilinear',
    'ResizeNearestNeighbor',
    'Reverse',
    'ScalarMath',
    'Slice',
    'Splice',
    'Split',
    'Shape',
    'Squeeze',
    'Stack',
    'Unstack',
    'Unsqueeze',
    'StridedSlice',
    'Softmax',
    'SpaceToBatchND',
    'SpaceToDepth',
    'SqrDiffMean',
    'SumGroup',
    'TargetRMSNorm',
    'Transpose',
    'Cumsum',
    'Tile',
]

# str-mixin enum whose member names and values are both the op name, so
# MaceOp.Conv2D == 'Conv2D' and MaceOp('Conv2D') is MaceOp.Conv2D.
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)

# Ops classified by data-format handling; presumably consumed by the
# data-format transformation passes (confirm in the transformer).
MaceFixedDataFormatOps = [MaceOp.BatchNorm,
                          MaceOp.BatchToSpaceND,
                          MaceOp.Conv2D,
                          MaceOp.Deconv2D,
                          MaceOp.DepthToSpace,
                          MaceOp.DepthwiseConv2d,
                          MaceOp.DepthwiseDeconv2d,
                          MaceOp.FullyConnected,
                          MaceOp.Pooling,
                          MaceOp.ResizeBicubic,
                          MaceOp.ResizeBilinear,
                          MaceOp.ResizeNearestNeighbor,
                          MaceOp.SpaceToBatchND,
                          MaceOp.SpaceToDepth]

MaceTransposableDataFormatOps = [MaceOp.Activation,
                                 MaceOp.AddN,
                                 MaceOp.BiasAdd,
                                 MaceOp.ChannelShuffle,
                                 MaceOp.Concat,
                                 MaceOp.Crop,
                                 MaceOp.Eltwise,
                                 MaceOp.Pad,
                                 MaceOp.Reduce,
                                 MaceOp.Reshape,
                                 MaceOp.Softmax,
                                 MaceOp.Split,
                                 MaceOp.Squeeze,
                                 MaceOp.SqrDiffMean,
                                 MaceOp.Tile]


class MaceKeyword(object):
    """String keys used by the converter.

    Covers special node names and the ``name`` of arguments attached to ops
    and nets (e.g. ``op.arg[i].name``). The string values are persisted in
    converted models, so they must not change. That includes the
    misspelling in 'seperate_buffer'.
    """

    # node related str
    mace_input_node_name = 'mace_input_node'
    mace_output_node_name = 'mace_output_node'
    mace_buffer_type = 'buffer_type'
    # arg related str
    mace_padding_str = 'padding'
    # NOTE: same key as mace_padding_str.
    mace_padding_type_str = 'padding'
    mace_padding_values_str = 'padding_values'
    mace_strides_str = 'strides'
    mace_dilations_str = 'dilations'
    mace_pooling_type_str = 'pooling_type'
    mace_global_pooling_str = 'global_pooling'
    mace_kernel_str = 'kernels'
    mace_data_format_str = 'data_format'
    mace_has_data_format_str = 'has_data_format'
    mace_filter_format_str = 'filter_format'
    mace_element_type_str = 'type'
    mace_activation_type_str = 'activation'
    mace_activation_max_limit_str = 'max_limit'
    mace_activation_leakyrelu_coefficient_str = 'leakyrelu_coefficient'
    mace_resize_size_str = 'size'
    mace_batch_to_space_crops_str = 'crops'
    mace_paddings_str = 'paddings'
    mace_align_corners_str = 'align_corners'
    mace_space_batch_block_shape_str = 'block_shape'
    mace_space_depth_block_size_str = 'block_size'
    mace_constant_value_str = 'constant_value'
    mace_dim_str = 'dim'
    mace_dims_str = 'dims'
    mace_axis_str = 'axis'
    mace_end_axis_str = 'end_axis'
    mace_num_axes_str = 'num_axes'
    mace_num_split_str = 'num_split'
    mace_keepdims_str = 'keepdims'
    mace_shape_str = 'shape'
    mace_winograd_filter_transformed = 'is_filter_transformed'
    mace_device = 'device'
    mace_scalar_input_str = 'scalar_input'
    mace_wino_block_size = 'wino_block_size'
    mace_output_shape_str = 'output_shape'
    mace_begin_mask_str = 'begin_mask'
    mace_end_mask_str = 'end_mask'
    mace_ellipsis_mask_str = 'ellipsis_mask'
    mace_new_axis_mask_str = 'new_axis_mask'
    mace_shrink_axis_mask_str = 'shrink_axis_mask'
    mace_transpose_a_str = 'transpose_a'
    mace_transpose_b_str = 'transpose_b'
    mace_op_data_type_str = 'T'
    mace_offset_str = 'offset'
    mace_opencl_max_image_size = "opencl_max_image_size"
    # Misspelled, but the value is a persisted key; keep as is.
    mace_seperate_buffer_str = 'seperate_buffer'
    mace_scalar_input_index_str = 'scalar_input_index'
    mace_opencl_mem_type = "opencl_mem_type"
    mace_framework_type_str = "framework_type"
    mace_group_str = "group"
    # NOTE: same key as mace_wino_block_size.
    mace_wino_arg_str = "wino_block_size"
    mace_quantize_flag_arg_str = "quantize_flag"
    mace_epsilon_str = 'epsilon'
    mace_reduce_type_str = 'reduce_type'
    mace_argmin_str = 'argmin'
    mace_round_mode_str = 'round_mode'
    mace_min_size_str = 'min_size'
    mace_max_size_str = 'max_size'
    mace_aspect_ratio_str = 'aspect_ratio'
    mace_flip_str = 'flip'
    mace_clip_str = 'clip'
    mace_variance_str = 'variance'
    mace_step_h_str = 'step_h'
    mace_step_w_str = 'step_w'
    mace_find_range_every_time = 'find_range_every_time'
    mace_non_zero = 'non_zero'
    mace_pad_type_str = 'pad_type'
    mace_exclusive_str = 'exclusive'
    mace_reverse_str = 'reverse'
    mace_const_data_num_arg_str = 'const_data_num'
    mace_coeff_str = 'coeff'


class TransformerRule(Enum):
    """Graph transformation passes the converter can apply.

    ConverterOption.build() resolves user-supplied rule names with
    ``TransformerRule[name]``, so member names are part of the public
    interface. The numeric values are identifiers only; the execution order
    is given by the list built in ConverterOption.build(), not by the values.
    """

    REMOVE_IDENTITY_OP = 1
    TRANSFORM_GLOBAL_POOLING = 2
    FOLD_RESHAPE = 3
    TRANSFORM_MATMUL_TO_FC = 4
    FOLD_BATCHNORM = 5
    FOLD_CONV_AND_BN = 6
    FOLD_DEPTHWISE_CONV_AND_BN = 7
    ADD_WINOGRAD_ARG = 8
    TRANSFORM_ADD_TO_BIASADD = 9
    FOLD_BIASADD = 10
    FLATTEN_ATROUS_CONV = 11
    FOLD_ACTIVATION = 12
    TRANSPOSE_FILTERS = 13
    RESHAPE_FC_WEIGHT = 14
    TRANSPOSE_DATA_FORMAT = 15
    TRANSFORM_GLOBAL_CONV_TO_FC = 16
    ADD_BUFFER_TRANSFORM = 17
    ADD_DEVICE = 18
    SORT_BY_EXECUTION = 19
    ADD_IN_OUT_TENSOR_INFO = 20
    ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
    UPDATE_FLOAT_OP_DATA_TYPE = 22
    QUANTIZE_NODES = 23
    ADD_QUANTIZE_TENSOR_RANGE = 24
    QUANTIZE_WEIGHTS = 25
    TRANSFORM_LSTMCELL_ZEROSTATE = 26
    TRANSFORM_BASIC_LSTMCELL = 27
    TRANSFORM_FAKE_QUANTIZE = 28
    CHECK_QUANTIZE_INFO = 29
    REARRANGE_BATCH_TO_SPACE = 30
    ADD_OPENCL_INFORMATIONS = 31
    FOLD_DECONV_AND_BN = 32
    FOLD_SQRDIFF_MEAN = 33
    TRANSPOSE_MATMUL_WEIGHT = 34
    FOLD_EMBEDDING_LOOKUP = 35
    TRANSPOSE_RESHAPE_AND_FLATTEN = 36
    FOLD_FC_RESHAPE = 37
    TRANSFORM_CHANNEL_SHUFFLE = 38
    UPDATE_DATA_FORMAT = 39
    QUANTIZE_SPECIFIC_OPS_ONLY = 40
    FP16_MATMUL_WEIGHT = 41
    FP16_GATHER_WEIGHT = 42
    QUANTIZE_LARGE_WEIGHTS = 43


class ConverterInterface(object):
    """Base interface for converters that turn external models into MACE
    models. Concrete converters must override :meth:`run`."""

    def run(self):
        """Perform the conversion (abstract; always raises here)."""
        method_name = 'run'
        raise NotImplementedError(method_name)


class NodeInfo(object):
    """A class for describing node information.

    Holds a node's name, data type, shape, data format and a value range.
    Defaults: name None, data type DT_FLOAT, empty shape, NHWC format,
    range [-1.0, 1.0]. The range is presumably the float range used for
    quantization; confirm with the callers.
    """

    def __init__(self):
        self._name = None
        self._data_type = mace_pb2.DT_FLOAT
        self._shape = []
        self._data_format = DataFormat.NHWC
        self._range = [-1.0, 1.0]

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def data_type(self):
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def shape(self):
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def data_format(self):
        return self._data_format

    @data_format.setter
    def data_format(self, data_format):
        self._data_format = data_format

    @property
    def range(self):
        return self._range

    @range.setter
    def range(self, value_range):
        # Parameter named value_range so it does not shadow the builtin.
        self._range = value_range

    def __str__(self):
        return '%s %s' % (self._name, str(self._shape))


class ConverterOption(object):
    """A class for specifying options passed to converter tool.

    Collects:
    - the input, output and check node descriptions;
    - the target data type and device;
    - the Winograd and quantization switches;
    - the OpenCL memory type;
    - the graph transformation rules to run.

    Call :meth:`build` to turn ``transformer_option`` into a list of
    ``TransformerRule`` members.
    """

    def __init__(self):
        # Node maps are keyed by each node's ``name`` attribute.
        self._input_nodes = {}
        self._output_nodes = {}
        self._check_nodes = {}
        self._data_type = mace_pb2.DT_FLOAT
        self._device = DeviceType.CPU.value
        self._winograd = 0
        self._quantize = False
        self._quantize_large_weights = False
        self._quantize_range_file = ""
        self._change_concat_ranges = False
        # Rule names from the caller, or None for the default sequence.
        # build() replaces this with a list of TransformerRule members.
        self._transformer_option = None
        self._cl_mem_type = "image"

    # -- node maps ----------------------------------------------------------
    # The setters MERGE the given nodes into the existing map, re-keying each
    # node by its ``name``. The keys of the dict passed in are ignored, and
    # previously registered nodes are kept.

    @property
    def input_nodes(self):
        return self._input_nodes

    @input_nodes.setter
    def input_nodes(self, input_nodes):
        for node in input_nodes.values():
            self._input_nodes[node.name] = node

    def add_input_node(self, input_node):
        self._input_nodes[input_node.name] = input_node

    @property
    def output_nodes(self):
        return self._output_nodes

    @output_nodes.setter
    def output_nodes(self, output_nodes):
        # Write to the private map directly, consistent with input_nodes.
        for node in output_nodes.values():
            self._output_nodes[node.name] = node

    def add_output_node(self, output_node):
        self._output_nodes[output_node.name] = output_node

    @property
    def check_nodes(self):
        return self._check_nodes

    @check_nodes.setter
    def check_nodes(self, check_nodes):
        # Write to the private map directly, consistent with input_nodes.
        for node in check_nodes.values():
            self._check_nodes[node.name] = node

    def add_check_node(self, check_node):
        self._check_nodes[check_node.name] = check_node

    # -- scalar options -----------------------------------------------------

    @property
    def data_type(self):
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def winograd(self):
        return self._winograd

    @winograd.setter
    def winograd(self, winograd):
        self._winograd = winograd

    @property
    def quantize(self):
        return self._quantize

    @quantize.setter
    def quantize(self, quantize):
        self._quantize = quantize

    @property
    def quantize_large_weights(self):
        return self._quantize_large_weights

    @quantize_large_weights.setter
    def quantize_large_weights(self, quantize_large_weights):
        self._quantize_large_weights = quantize_large_weights

    @property
    def quantize_range_file(self):
        return self._quantize_range_file

    @quantize_range_file.setter
    def quantize_range_file(self, quantize_range_file):
        self._quantize_range_file = quantize_range_file

    @property
    def change_concat_ranges(self):
        return self._change_concat_ranges

    @change_concat_ranges.setter
    def change_concat_ranges(self, change_concat_ranges):
        self._change_concat_ranges = change_concat_ranges

    @property
    def transformer_option(self):
        return self._transformer_option

    @transformer_option.setter
    def transformer_option(self, transformer_option):
        self._transformer_option = transformer_option

    @property
    def cl_mem_type(self):
        return self._cl_mem_type

    @cl_mem_type.setter
    def cl_mem_type(self, cl_mem_type):
        self._cl_mem_type = cl_mem_type

    # -- transformation rules -----------------------------------------------

    def disable_transpose_filters(self):
        """Remove TRANSPOSE_FILTERS from the rule list if present.

        NOTE(review): membership is tested against a TransformerRule member,
        so this is meaningful only after build() has resolved the list. With
        the initial ``None`` it raises TypeError.
        """
        if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
            self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)

    def enable_transpose_filters(self):
        """Append TRANSPOSE_FILTERS to the rule list if it is missing.

        NOTE(review): like disable_transpose_filters, this is meant to be
        called after build(). With the initial ``None`` it raises TypeError.
        """
        if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option:
            self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS)

    def build(self):
        """Resolve ``transformer_option`` into TransformerRule members.

        If rule names were supplied, each is looked up with
        ``TransformerRule[name]`` (an unknown name raises KeyError) and used
        in the given order.

        Otherwise the default rule sequence below is installed, followed by
        QUANTIZE_LARGE_WEIGHTS and/or the quantization rules when
        ``quantize_large_weights`` / ``quantize`` are enabled. Those two
        flags are not consulted when an explicit rule list was given.
        """
        if self._transformer_option:
            self._transformer_option = [
                TransformerRule[transformer]
                for transformer in self._transformer_option]
            return

        rules = [
            # Model structure related transformation
            TransformerRule.REMOVE_IDENTITY_OP,
            TransformerRule.TRANSFORM_FAKE_QUANTIZE,
            TransformerRule.REMOVE_IDENTITY_OP,
            TransformerRule.TRANSFORM_GLOBAL_POOLING,
            TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
            TransformerRule.TRANSFORM_BASIC_LSTMCELL,
            TransformerRule.TRANSPOSE_RESHAPE_AND_FLATTEN,
            TransformerRule.FOLD_RESHAPE,
            TransformerRule.TRANSFORM_MATMUL_TO_FC,
            # For StoB -> conv -> BtoS -> BN pattern
            # Insert flatten_atrous_conv before fold_xxx_and_bn
            TransformerRule.FLATTEN_ATROUS_CONV,
            TransformerRule.FOLD_BATCHNORM,
            TransformerRule.FOLD_CONV_AND_BN,
            TransformerRule.FOLD_DECONV_AND_BN,
            TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
            TransformerRule.TRANSFORM_ADD_TO_BIASADD,
            TransformerRule.REARRANGE_BATCH_TO_SPACE,
            TransformerRule.FOLD_BIASADD,
            TransformerRule.FLATTEN_ATROUS_CONV,
            TransformerRule.FOLD_ACTIVATION,
            TransformerRule.FOLD_SQRDIFF_MEAN,
            TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
            TransformerRule.RESHAPE_FC_WEIGHT,
            TransformerRule.FOLD_FC_RESHAPE,
            TransformerRule.TRANSFORM_CHANNEL_SHUFFLE,
            # Model data format related transformation
            TransformerRule.TRANSPOSE_FILTERS,
            # Mace model structure related transformation
            TransformerRule.ADD_IN_OUT_TENSOR_INFO,
            TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
            # Add winograd argument
            TransformerRule.ADD_WINOGRAD_ARG,
            # Data type related transformation
            TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
            # Transform finalization
            TransformerRule.ADD_OPENCL_INFORMATIONS,
            # for quantization entropy calibration use
            TransformerRule.SORT_BY_EXECUTION,
            # update the data format of ops
            TransformerRule.UPDATE_DATA_FORMAT,
            TransformerRule.TRANSPOSE_DATA_FORMAT,
            # Need to be put after SORT_BY_EXECUTION
            TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
        ]
        if self._quantize_large_weights:
            rules.append(TransformerRule.QUANTIZE_LARGE_WEIGHTS)
        if self._quantize:
            rules.extend([
                # need to be put after ADD_QUANTIZE_TENSOR_RANGE
                TransformerRule.QUANTIZE_NODES,
                TransformerRule.QUANTIZE_WEIGHTS,
                TransformerRule.SORT_BY_EXECUTION,
                TransformerRule.CHECK_QUANTIZE_INFO,
            ])
        self._transformer_option = rules


class ConverterUtil(object):
    """Static helpers for named arguments on MACE protobuf messages.

    They read and write entries of the repeated ``arg`` field on ops and
    nets.
    """

    @staticmethod
    def get_arg(op, arg_name):
        """Return the first entry of ``op.arg`` named ``arg_name``, or None."""
        for arg in op.arg:
            if arg.name == arg_name:
                return arg
        return None

    @staticmethod
    def del_arg(op, arg_name):
        """Delete the first entry of ``op.arg`` named ``arg_name``, if any."""
        for idx, arg in enumerate(op.arg):
            if arg.name == arg_name:
                del op.arg[idx]
                # Stop immediately: only the first match is removed, and the
                # container has just been mutated.
                break

    @staticmethod
    def add_data_format_arg(op, data_format):
        """Append a 'data_format' arg holding ``data_format.value``."""
        data_format_arg = op.arg.add()
        data_format_arg.name = MaceKeyword.mace_data_format_str
        data_format_arg.i = data_format.value

    @staticmethod
    def add_data_type_arg(op, data_type):
        """Append a 'T' (op data type) arg holding ``data_type``."""
        data_type_arg = op.arg.add()
        data_type_arg.name = MaceKeyword.mace_op_data_type_str
        data_type_arg.i = data_type

    @staticmethod
    def _format_from_arg(arg, candidates):
        """Return the first of ``candidates`` whose ``value`` equals ``arg.i``.

        Returns None when ``arg`` is None or nothing matches.
        """
        if arg is None:
            return None
        for fmt in candidates:
            if arg.i == fmt.value:
                return fmt
        return None

    @staticmethod
    def data_format(op):
        """Return the op's data format (NHWC, NCHW or AUTO), else None."""
        arg = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str)
        return ConverterUtil._format_from_arg(
            arg, (DataFormat.NHWC, DataFormat.NCHW, DataFormat.AUTO))

    @staticmethod
    def set_filter_format(net, filter_format):
        """Append a 'filter_format' arg to ``net``."""
        arg = net.arg.add()
        arg.name = MaceKeyword.mace_filter_format_str
        arg.i = filter_format.value

    @staticmethod
    def filter_format(net):
        """Return the net's filter format (HWIO, HWOI or OIHW), else None."""
        arg = ConverterUtil.get_arg(net, MaceKeyword.mace_filter_format_str)
        return ConverterUtil._format_from_arg(
            arg, (DataFormat.HWIO, DataFormat.HWOI, DataFormat.OIHW))