# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


16 17
from enum import Enum

L
liyin 已提交
18 19 20 21
from py_proto import mace_pb2

from utils.config_parser import DataFormat
from utils.config_parser import DeviceType
22 23


L
liutuo 已提交
24 25
# SAME_LOWER: if the amount of padding to be added is odd,
# the extra padding is added to the right or bottom.
class PaddingMode(Enum):
    """Padding modes known to the converter.

    The integer values are part of the interface and must stay stable.
    """

    VALID = 0
    SAME = 1
    FULL = 2
    SAME_LOWER = 3
    NA = 4
32 33 34 35 36 37 38


class PoolingType(Enum):
    """Pooling kinds; values start at 1 (no member maps to 0)."""

    AVG = 1  # average pooling
    MAX = 2  # max pooling


L
liutuo 已提交
39 40 41 42 43
class RoundMode(Enum):
    """Rounding direction choices: round down (FLOOR) or round up (CEIL)."""

    FLOOR = 0
    CEIL = 1


44 45 46 47 48 49 50
class ActivationType(Enum):
    """Activation function identifiers.

    Values are consecutive integers starting at 0 (NOOP).
    """

    NOOP = 0
    RELU = 1
    RELUX = 2
    PRELU = 3
    TANH = 4
    SIGMOID = 5
    LEAKYRELU = 6
52 53 54 55 56 57 58 59 60 61 62 63 64


class EltwiseType(Enum):
    """Element-wise operation identifiers.

    Values are consecutive integers starting at 0 (SUM).
    """

    SUM = 0
    SUB = 1
    PROD = 2
    DIV = 3
    MIN = 4
    MAX = 5
    NEG = 6
    ABS = 7
    SQR_DIFF = 8
    POW = 9
    EQUAL = 10
    FLOOR_DIV = 11
    CLIP = 12
    SIGN = 13
69 70


L
liutuo 已提交
71 72 73 74 75
class ReduceType(Enum):
    """Reduction operation identifiers.

    Values are consecutive integers starting at 0 (MEAN).
    """

    MEAN = 0
    MIN = 1
    MAX = 2
    PROD = 3
    SUM = 4
L
liutuo 已提交
77 78


79 80 81 82 83 84
class PadType(Enum):
    """Identifiers for the supported pad types."""

    CONSTANT = 0
    REFLECT = 1
    SYMMETRIC = 2


L
liutuo 已提交
85 86 87
class FrameworkType(Enum):
    """Identifiers for the source frameworks listed by this module."""

    TENSORFLOW = 0
    CAFFE = 1
    ONNX = 2
L
liutuo 已提交
89 90


91 92 93
# Op type names supported by MACE. MaceOp is generated from this list, so the
# order here is the member definition order of that enum; append new names
# rather than reordering.
MaceSupportedOps = [
    'Activation',
    'AddN',
    'ArgMax',
    'BatchNorm',
    'BatchToSpaceND',
    'BiasAdd',
    'Cast',
    'ChannelShuffle',
    'Concat',
    'Conv2D',
    'Crop',
    'Deconv2D',
    'DepthToSpace',
    'DepthwiseConv2d',
    'DepthwiseDeconv2d',
    'Dequantize',
    'Eltwise',
    'ExpandDims',
    'ExtractPooling',
    'Fill',
    'FullyConnected',
    'Gather',
    'Identity',
    'IfDefined',
    'InferConv2dShape',
    'KaldiBatchNorm',
    'LocalResponseNorm',
    'LSTMCell',
    'LstmNonlinear',
    'DynamicLSTM',
    'MatMul',
    'OneHot',
    'Pad',
    'PadContext',
    'PNorm',
    'Pooling',
    'PriorBox',
    'Proposal',
    'Quantize',
    'Reduce',
    'ReplaceIndex',
    'Reshape',
    'ResizeBicubic',
    'ResizeBilinear',
    'ResizeNearestNeighbor',
    'Reverse',
    'ScalarMath',
    'Slice',
    'Splice',
    'Split',
    'Shape',
    'Squeeze',
    'Stack',
    'Unstack',
    'Unsqueeze',
    'StridedSlice',
    'Softmax',
    'SpaceToBatchND',
    'SpaceToDepth',
    'SqrDiffMean',
    'Subsample',
    'SumGroup',
    'TargetRMSNorm',
    'Transpose',
    'Cumsum',
    'Tile',
    'LpNorm',
    'MVNorm',
]

# String-valued enum with one member per entry of MaceSupportedOps; each
# member's name and value are both the op type string.
MaceOp = Enum(
    'MaceOp',
    [(op_name, op_name) for op_name in MaceSupportedOps],
    type=str,
)

164 165 166 167 168 169 170 171 172 173 174 175 176
# Ops grouped as "fixed data format".
# NOTE(review): presumably these ops require one specific tensor layout;
# confirm against the transformer code that consumes this list.
MaceFixedDataFormatOps = [MaceOp.BatchNorm,
                          MaceOp.BatchToSpaceND,
                          MaceOp.Conv2D,
                          MaceOp.Deconv2D,
                          MaceOp.DepthToSpace,
                          MaceOp.DepthwiseConv2d,
                          MaceOp.DepthwiseDeconv2d,
                          MaceOp.FullyConnected,
                          MaceOp.Pooling,
                          MaceOp.ResizeBicubic,
                          MaceOp.ResizeBilinear,
                          MaceOp.ResizeNearestNeighbor,
                          MaceOp.SpaceToBatchND,
                          MaceOp.SpaceToDepth,
                          MaceOp.LpNorm,
                          MaceOp.MVNorm]
180 181 182 183 184 185 186 187 188 189

# Ops grouped as "transposable data format".
# NOTE(review): presumably these ops can follow whichever layout their
# inputs use; confirm against the transformer code that consumes this list.
MaceTransposableDataFormatOps = [MaceOp.Activation,
                                 MaceOp.AddN,
                                 MaceOp.BiasAdd,
                                 MaceOp.ChannelShuffle,
                                 MaceOp.Concat,
                                 MaceOp.Crop,
                                 MaceOp.Eltwise,
                                 MaceOp.Pad,
                                 MaceOp.Reduce,
                                 MaceOp.Reshape,
                                 MaceOp.Softmax,
                                 MaceOp.Split,
                                 MaceOp.Squeeze,
                                 MaceOp.SqrDiffMean,
                                 MaceOp.Tile]
196

197 198 199 200 201 202 203 204

class MaceKeyword(object):
    """String constants for node names and op/net argument names.

    Every value is a runtime key, so the literals (including the
    'seperate_buffer' spelling) must not be altered.
    """

    # node related str
    mace_input_node_name = 'mace_input_node'
    mace_output_node_name = 'mace_output_node'
    mace_buffer_type = 'buffer_type'
    # arg related str
    # NOTE: mace_padding_str and mace_padding_type_str share the same value.
    mace_padding_str = 'padding'
    mace_padding_type_str = 'padding'
    mace_padding_values_str = 'padding_values'
    mace_strides_str = 'strides'
    mace_dilations_str = 'dilations'
    mace_pooling_type_str = 'pooling_type'
    mace_global_pooling_str = 'global_pooling'
    mace_kernel_str = 'kernels'
    mace_data_format_str = 'data_format'
    mace_has_data_format_str = 'has_data_format'
    mace_filter_format_str = 'filter_format'
    mace_element_type_str = 'type'
    mace_activation_type_str = 'activation'
    mace_activation_max_limit_str = 'max_limit'
    mace_activation_leakyrelu_coefficient_str = 'leakyrelu_coefficient'
    mace_resize_size_str = 'size'
    mace_batch_to_space_crops_str = 'crops'
    mace_paddings_str = 'paddings'
    mace_align_corners_str = 'align_corners'
    mace_space_batch_block_shape_str = 'block_shape'
    mace_space_depth_block_size_str = 'block_size'
    mace_constant_value_str = 'constant_value'
    mace_dim_str = 'dim'
    mace_dims_str = 'dims'
    mace_axis_str = 'axis'
    mace_end_axis_str = 'end_axis'
    mace_num_axes_str = 'num_axes'
    mace_num_split_str = 'num_split'
    mace_keepdims_str = 'keepdims'
    mace_shape_str = 'shape'
    mace_winograd_filter_transformed = 'is_filter_transformed'
    mace_device = 'device'
    mace_scalar_input_str = 'scalar_input'
    # NOTE: mace_wino_block_size and mace_wino_arg_str (below) share the
    # same value.
    mace_wino_block_size = 'wino_block_size'
    mace_output_shape_str = 'output_shape'
    mace_begin_mask_str = 'begin_mask'
    mace_end_mask_str = 'end_mask'
    mace_ellipsis_mask_str = 'ellipsis_mask'
    mace_new_axis_mask_str = 'new_axis_mask'
    mace_shrink_axis_mask_str = 'shrink_axis_mask'
    mace_transpose_a_str = 'transpose_a'
    mace_transpose_b_str = 'transpose_b'
    mace_op_data_type_str = 'T'
    mace_offset_str = 'offset'
    mace_opencl_max_image_size = "opencl_max_image_size"
    mace_seperate_buffer_str = 'seperate_buffer'
    mace_scalar_input_index_str = 'scalar_input_index'
    mace_opencl_mem_type = "opencl_mem_type"
    mace_framework_type_str = "framework_type"
    mace_group_str = "group"
    mace_wino_arg_str = "wino_block_size"
    mace_quantize_flag_arg_str = "quantize_flag"
    mace_epsilon_str = 'epsilon'
    mace_reduce_type_str = 'reduce_type'
    mace_argmin_str = 'argmin'
    mace_round_mode_str = 'round_mode'
    mace_min_size_str = 'min_size'
    mace_max_size_str = 'max_size'
    mace_aspect_ratio_str = 'aspect_ratio'
    mace_flip_str = 'flip'
    mace_clip_str = 'clip'
    mace_variance_str = 'variance'
    mace_step_h_str = 'step_h'
    mace_step_w_str = 'step_w'
    mace_find_range_every_time = 'find_range_every_time'
    mace_non_zero = 'non_zero'
    mace_pad_type_str = 'pad_type'
    mace_exclusive_str = 'exclusive'
    mace_reverse_str = 'reverse'
    mace_const_data_num_arg_str = 'const_data_num'
    mace_coeff_str = 'coeff'
    mace_input_indexes_str = 'input_indexes'
    mace_output_indexes_str = 'output_indexes'
    mace_p_str = 'p'
    mace_nor_var_str = 'normalize_variance'
    mace_across_ch_str = 'across_channels'
280 281 282


class TransformerRule(Enum):
    """Identifiers for the individual model transformation passes.

    Values are unique integers from 1 to 43. The numeric value is only an
    identifier; the order in which rules run is given by the list built in
    ConverterOption.build(), not by these numbers.
    """

    REMOVE_IDENTITY_OP = 1
    TRANSFORM_GLOBAL_POOLING = 2
    FOLD_RESHAPE = 3
    TRANSFORM_MATMUL_TO_FC = 4
    FOLD_BATCHNORM = 5
    FOLD_CONV_AND_BN = 6
    FOLD_DEPTHWISE_CONV_AND_BN = 7
    ADD_WINOGRAD_ARG = 8
    TRANSFORM_ADD_TO_BIASADD = 9
    FOLD_BIASADD = 10
    FLATTEN_ATROUS_CONV = 11
    FOLD_ACTIVATION = 12
    TRANSPOSE_FILTERS = 13
    RESHAPE_FC_WEIGHT = 14
    TRANSPOSE_DATA_FORMAT = 15
    TRANSFORM_GLOBAL_CONV_TO_FC = 16
    ADD_BUFFER_TRANSFORM = 17
    ADD_DEVICE = 18
    SORT_BY_EXECUTION = 19
    ADD_IN_OUT_TENSOR_INFO = 20
    ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
    UPDATE_FLOAT_OP_DATA_TYPE = 22
    QUANTIZE_NODES = 23
    ADD_QUANTIZE_TENSOR_RANGE = 24
    QUANTIZE_WEIGHTS = 25
    TRANSFORM_LSTMCELL_ZEROSTATE = 26
    TRANSFORM_BASIC_LSTMCELL = 27
    TRANSFORM_FAKE_QUANTIZE = 28
    CHECK_QUANTIZE_INFO = 29
    REARRANGE_BATCH_TO_SPACE = 30
    ADD_OPENCL_INFORMATIONS = 31
    FOLD_DECONV_AND_BN = 32
    FOLD_SQRDIFF_MEAN = 33
    TRANSPOSE_MATMUL_WEIGHT = 34
    FOLD_EMBEDDING_LOOKUP = 35
    TRANSPOSE_RESHAPE_AND_FLATTEN = 36
    FOLD_FC_RESHAPE = 37
    TRANSFORM_CHANNEL_SHUFFLE = 38
    UPDATE_DATA_FORMAT = 39
    QUANTIZE_SPECIFIC_OPS_ONLY = 40
    FP16_MATMUL_WEIGHT = 41
    FP16_GATHER_WEIGHT = 42
    QUANTIZE_LARGE_WEIGHTS = 43
326 327 328 329 330 331 332 333 334 335 336 337 338 339


class ConverterInterface(object):
    """Common base for converters that turn external models into mace models.

    The base implementation of run() only raises; subclasses supply the
    actual conversion.
    """

    def run(self):
        """Placeholder entry point; always raises NotImplementedError."""
        message = 'run'
        raise NotImplementedError(message)


class NodeInfo(object):
    """A class for describing node information.

    Stores a node's name, data type, shape, data format and value range
    behind plain read/write properties. A new instance starts with name
    None, data type mace_pb2.DT_FLOAT, an empty shape, DataFormat.NHWC and
    the range [-1.0, 1.0].
    """

    def __init__(self):
        self._name = None
        self._data_type = mace_pb2.DT_FLOAT
        self._shape = []
        self._data_format = DataFormat.NHWC
        self._range = [-1.0, 1.0]

    @property
    def name(self):
        """Node name (None until assigned)."""
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def data_type(self):
        """Data type of the node."""
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def shape(self):
        """Shape of the node; the stored object is returned, not a copy."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def data_format(self):
        """Data format of the node."""
        return self._data_format

    @data_format.setter
    def data_format(self, data_format):
        self._data_format = data_format

    @property
    def range(self):
        """Value range of the node; the stored object is returned as is."""
        return self._range

    @range.setter
    def range(self, range):
        self._range = range

    def __str__(self):
        """Return '<name> <shape>' for logging."""
        return '%s %s' % (self._name, str(self._shape))


class ConverterOption(object):
    """A class for specifying options passed to converter tool.

    Options are plain read/write properties. The three node collections
    (input, output, check) are dicts keyed by node name; assigning to them
    merges into the existing dict instead of replacing it.
    """

    def __init__(self):
        self._input_nodes = {}
        self._output_nodes = {}
        self._check_nodes = {}
        self._data_type = mace_pb2.DT_FLOAT
        self._device = DeviceType.CPU.value
        self._winograd = 0
        self._quantize = False
        self._quantize_large_weights = False
        self._quantize_range_file = ""
        self._change_concat_ranges = False
        # None until set explicitly or filled in by build().
        self._transformer_option = None
        self._cl_mem_type = "image"

    # ---- node collections -------------------------------------------------

    @property
    def input_nodes(self):
        """Dict of input nodes keyed by node name."""
        return self._input_nodes

    @input_nodes.setter
    def input_nodes(self, input_nodes):
        # Re-key by each node's own name; the incoming keys are ignored and
        # existing entries are kept (merge, not replace).
        for node in input_nodes.values():
            self._input_nodes[node.name] = node

    def add_input_node(self, input_node):
        """Register a single input node under its name."""
        self._input_nodes[input_node.name] = input_node

    @property
    def output_nodes(self):
        """Dict of output nodes keyed by node name."""
        return self._output_nodes

    @output_nodes.setter
    def output_nodes(self, output_nodes):
        # Same merge semantics as input_nodes.
        for node in output_nodes.values():
            self._output_nodes[node.name] = node

    def add_output_node(self, output_node):
        """Register a single output node under its name."""
        self._output_nodes[output_node.name] = output_node

    @property
    def check_nodes(self):
        """Dict of check nodes keyed by node name."""
        return self._check_nodes

    @check_nodes.setter
    def check_nodes(self, check_nodes):
        # Same merge semantics as input_nodes.
        for node in check_nodes.values():
            self._check_nodes[node.name] = node

    def add_check_node(self, check_node):
        """Register a single check node under its name."""
        self._check_nodes[check_node.name] = check_node

    # ---- scalar options ---------------------------------------------------

    @property
    def data_type(self):
        return self._data_type

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def winograd(self):
        return self._winograd

    @winograd.setter
    def winograd(self, winograd):
        self._winograd = winograd

    @property
    def quantize(self):
        return self._quantize

    @quantize.setter
    def quantize(self, quantize):
        self._quantize = quantize

    @property
    def quantize_large_weights(self):
        return self._quantize_large_weights

    @quantize_large_weights.setter
    def quantize_large_weights(self, quantize_large_weights):
        self._quantize_large_weights = quantize_large_weights

    @property
    def change_concat_ranges(self):
        return self._change_concat_ranges

    @change_concat_ranges.setter
    def change_concat_ranges(self, change_concat_ranges):
        self._change_concat_ranges = change_concat_ranges

    @property
    def quantize_range_file(self):
        return self._quantize_range_file

    @quantize_range_file.setter
    def quantize_range_file(self, quantize_range_file):
        self._quantize_range_file = quantize_range_file

    @property
    def transformer_option(self):
        return self._transformer_option

    @transformer_option.setter
    def transformer_option(self, transformer_option):
        self._transformer_option = transformer_option

    @property
    def cl_mem_type(self):
        return self._cl_mem_type

    @cl_mem_type.setter
    def cl_mem_type(self, cl_mem_type):
        self._cl_mem_type = cl_mem_type

    # ---- transformer rule list --------------------------------------------

    def disable_transpose_filters(self):
        """Remove TRANSPOSE_FILTERS from the rule list if it is present.

        The rule list must already be a list of TransformerRule members
        (for example after build()); with the initial None value the
        membership test raises TypeError.
        """
        if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
            self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)

    def enable_transpose_filters(self):
        """Append TRANSPOSE_FILTERS to the rule list if it is absent.

        Same precondition as disable_transpose_filters().
        """
        if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option:
            self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS)

    def build(self):
        """Resolve the transformer rule list.

        If a non-empty transformer option was supplied, each entry is looked
        up by name in TransformerRule. Otherwise the default pipeline below
        is installed, followed by the large-weight and quantization rules
        when the corresponding flags are set. The quantize flags only affect
        the default pipeline, not a user-supplied list.
        """
        if self._transformer_option:
            self._transformer_option = [
                TransformerRule[rule_name]
                for rule_name in self._transformer_option
            ]
            return

        rules = [
            # Model structure related transformation
            TransformerRule.REMOVE_IDENTITY_OP,
            TransformerRule.TRANSFORM_FAKE_QUANTIZE,
            TransformerRule.REMOVE_IDENTITY_OP,
            TransformerRule.TRANSFORM_GLOBAL_POOLING,
            TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
            TransformerRule.TRANSFORM_BASIC_LSTMCELL,
            TransformerRule.TRANSPOSE_RESHAPE_AND_FLATTEN,
            TransformerRule.FOLD_RESHAPE,
            TransformerRule.TRANSFORM_MATMUL_TO_FC,
            # For StoB -> conv -> BtoS -> BN pattern
            # Insert flatten_atrous_conv before fold_xxx_and_bn
            TransformerRule.FLATTEN_ATROUS_CONV,
            TransformerRule.FOLD_BATCHNORM,
            TransformerRule.FOLD_CONV_AND_BN,
            TransformerRule.FOLD_DECONV_AND_BN,
            TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
            TransformerRule.TRANSFORM_ADD_TO_BIASADD,
            TransformerRule.REARRANGE_BATCH_TO_SPACE,
            TransformerRule.FOLD_BIASADD,
            TransformerRule.FLATTEN_ATROUS_CONV,
            TransformerRule.FOLD_ACTIVATION,
            TransformerRule.FOLD_SQRDIFF_MEAN,
            TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
            TransformerRule.RESHAPE_FC_WEIGHT,
            TransformerRule.FOLD_FC_RESHAPE,
            TransformerRule.TRANSFORM_CHANNEL_SHUFFLE,
            # Model data format related transformation
            TransformerRule.TRANSPOSE_FILTERS,
            # Mace model structure related transformation
            TransformerRule.ADD_IN_OUT_TENSOR_INFO,
            TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
            # Add winograd argument
            TransformerRule.ADD_WINOGRAD_ARG,
            # Data type related transformation
            TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
            # Transform finalization
            TransformerRule.ADD_OPENCL_INFORMATIONS,
            # for quantization entropy calibration use
            TransformerRule.SORT_BY_EXECUTION,
            # update the data format of ops
            TransformerRule.UPDATE_DATA_FORMAT,
            TransformerRule.TRANSPOSE_DATA_FORMAT,
            # Need to be put after SORT_BY_EXECUTION
            TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
        ]
        if self.quantize_large_weights:
            rules.append(TransformerRule.QUANTIZE_LARGE_WEIGHTS)
        if self._quantize:
            rules.extend([
                # need to be put after ADD_QUANTIZE_TENSOR_RANGE
                TransformerRule.QUANTIZE_NODES,
                TransformerRule.QUANTIZE_WEIGHTS,
                TransformerRule.SORT_BY_EXECUTION,
                TransformerRule.CHECK_QUANTIZE_INFO,
            ])
        self._transformer_option = rules

587 588 589 590 591 592 593 594 595

class ConverterUtil(object):
    """Static helpers for reading and writing named args on ops and nets.

    `op` / `net` are expected to expose an `arg` collection whose items have
    a `name` attribute; the add_* helpers additionally call `arg.add()`.
    """

    @staticmethod
    def get_arg(op, arg_name):
        """Return the first arg of `op` named `arg_name`, or None."""
        for arg in op.arg:
            if arg.name == arg_name:
                return arg
        return None

    @staticmethod
    def del_arg(op, arg_name):
        """Delete the first arg of `op` named `arg_name`; no-op if absent."""
        for idx, arg in enumerate(op.arg):
            if arg.name == arg_name:
                # Return right away: only the first match is removed and the
                # collection is not iterated again after being mutated.
                del op.arg[idx]
                return

    @staticmethod
    def add_data_format_arg(op, data_format):
        """Append a data-format arg holding `data_format.value` to `op`."""
        data_format_arg = op.arg.add()
        data_format_arg.name = MaceKeyword.mace_data_format_str
        data_format_arg.i = data_format.value

    @staticmethod
    def add_data_type_arg(op, data_type):
        """Append a data-type arg holding `data_type` to `op`."""
        data_type_arg = op.arg.add()
        data_type_arg.name = MaceKeyword.mace_op_data_type_str
        data_type_arg.i = data_type

    @staticmethod
    def data_format(op):
        """Return the DataFormat stored on `op`.

        Returns None when the arg is missing or its value is not one of
        NHWC, NCHW or AUTO.
        """
        arg = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str)
        if arg is None:
            return None
        for candidate in (DataFormat.NHWC, DataFormat.NCHW, DataFormat.AUTO):
            if arg.i == candidate.value:
                return candidate
        return None

    @staticmethod
    def set_filter_format(net, filter_format):
        """Append a filter-format arg holding `filter_format.value` to `net`.

        Any existing filter-format arg is left in place; filter_format()
        reads the first one found.
        """
        arg = net.arg.add()
        arg.name = MaceKeyword.mace_filter_format_str
        arg.i = filter_format.value

    @staticmethod
    def filter_format(net):
        """Return the filter DataFormat stored on `net`.

        Returns None when the arg is missing or its value is not one of
        HWIO, HWOI or OIHW.
        """
        arg = ConverterUtil.get_arg(net, MaceKeyword.mace_filter_format_str)
        if arg is None:
            return None
        for candidate in (DataFormat.HWIO, DataFormat.HWOI, DataFormat.OIHW):
            if arg.i == candidate.value:
                return candidate
        return None