# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from enum import Enum

from mace.proto import mace_pb2


class DeviceType(Enum):
    """Target runtime device for a converted model.

    ConverterOption stores the selected device as the plain integer
    ``DeviceType.X.value``, so these numbers must stay stable. No member
    uses the value 1.
    """
    CPU = 0
    GPU = 2
    HEXAGON = 3
    HTA = 4
    APU = 5


class DataFormat(Enum):
    """Layout tags for tensors and filters.

    Activation layouts and filter (weight) layouts share this one enum. The
    ``.value`` integers are what ConverterUtil writes into op/net args, so
    they must stay stable.
    """
    NONE = 0
    # Activation (tensor) layouts.
    NHWC = 1
    NCHW = 2
    # Filter (weight) layouts.
    HWIO = 100
    OIHW = 101
    HWOI = 102
    OHWI = 103
    AUTO = 1000


# SAME_LOWER: if the amount of padding to be added is odd, the original
# authors state that the extra element is added to the right or bottom.
# NOTE(review): the ONNX spec places the odd element at the *beginning*
# (left/top) for SAME_LOWER and at the end for SAME_UPPER -- confirm which
# side MACE's kernels actually pad before relying on this note.
class PaddingMode(Enum):
    """Padding strategy for windowed ops."""
    VALID = 0
    SAME = 1
    FULL = 2
    SAME_LOWER = 3
    NA = 4


class PoolingType(Enum):
    """Reduction used by a pooling window."""
    AVG = 1  # average of the window
    MAX = 2  # maximum of the window


class RoundMode(Enum):
    """Rounding direction (floor or ceil)."""
    FLOOR = 0
    CEIL = 1


class ActivationType(Enum):
    """Activation function kinds.

    NOTE(review): RELUX presumably pairs with
    ``MaceKeyword.mace_activation_max_limit_str`` and LEAKYRELU with
    ``mace_activation_leakyrelu_coefficient_str`` -- confirm in the ops.
    """
    NOOP = 0
    RELU = 1
    RELUX = 2
    PRELU = 3
    TANH = 4
    SIGMOID = 5
    LEAKYRELU = 6


class EltwiseType(Enum):
    """Kinds of element-wise operation."""
    SUM = 0
    SUB = 1
    PROD = 2
    DIV = 3
    MIN = 4
    MAX = 5
    NEG = 6
    ABS = 7
    SQR_DIFF = 8
    POW = 9
    EQUAL = 10
    FLOOR_DIV = 11


class ReduceType(Enum):
    """Reduction kinds."""
    MEAN = 0
    MIN = 1
    MAX = 2
    PROD = 3


class PadType(Enum):
    """How padded border values are filled."""
    CONSTANT = 0
    REFLECT = 1
    SYMMETRIC = 2


class FrameworkType(Enum):
    """Framework the source model comes from."""
    TENSORFLOW = 0
    CAFFE = 1
    ONNX = 2


# Op type names known to the converter. The order is significant: it becomes
# the member order of the MaceOp enum built below.
MaceSupportedOps = [
    'Activation',
    'AddN',
    'ArgMax',
    'BatchNorm',
    'BatchToSpaceND',
    'BiasAdd',
    'Cast',
    'ChannelShuffle',
    'Concat',
    'Conv2D',
    'Crop',
    'Deconv2D',
    'Delay',
    'DepthToSpace',
    'DepthwiseConv2d',
    'DepthwiseDeconv2d',
    'Dequantize',
    'Eltwise',
    'ExpandDims',
    'ExtractPooling',
    'Fill',
    'FullyConnected',
    'Gather',
    'Identity',
    'InferConv2dShape',
    'KaldiBatchNorm',
    'LocalResponseNorm',
    'LSTMCell',
    'LstmNonlinear',
    'DynamicLSTM',
    'MatMul',
    'OneHot',
    'Pad',
    'PadContext',
    'PNorm',
    'Pooling',
    'PriorBox',
    'Proposal',
    'Quantize',
    'Reduce',
    'Reshape',
    'ResizeBicubic',
    'ResizeBilinear',
    'ResizeNearestNeighbor',
    'Reverse',
    'ScalarMath',
    'Slice',
    'Splice',
    'Split',
    'Shape',
    'Squeeze',
    'Stack',
    'Unstack',
    'StridedSlice',
    'Softmax',
    'SpaceToBatchND',
    'SpaceToDepth',
    'SqrDiffMean',
    'SumGroup',
    'TargetRMSNorm',
    'Transpose',
    'Cumsum',
]

# String-valued enum: each member's name and value are the op type string,
# and mixing in `str` makes every member a str instance.
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)

# Ops grouped by how they relate to the tensor data format.
# NOTE(review): judging by the names, the first list holds ops bound to one
# specific layout and the second ops that can follow whatever layout their
# inputs use -- confirm against the transformer that consumes these lists.
MaceFixedDataFormatOps = [MaceOp.BatchNorm,
                          MaceOp.BatchToSpaceND,
                          MaceOp.Conv2D,
                          MaceOp.Deconv2D,
                          MaceOp.DepthToSpace,
                          MaceOp.DepthwiseConv2d,
                          MaceOp.DepthwiseDeconv2d,
                          MaceOp.FullyConnected,
                          MaceOp.Pooling,
                          MaceOp.ResizeBicubic,
                          MaceOp.ResizeBilinear,
                          MaceOp.ResizeNearestNeighbor,
                          MaceOp.SpaceToBatchND,
                          MaceOp.SpaceToDepth]

MaceTransposableDataFormatOps = [MaceOp.Activation,
                                 MaceOp.AddN,
                                 MaceOp.BiasAdd,
                                 MaceOp.ChannelShuffle,
                                 MaceOp.Concat,
                                 MaceOp.Crop,
                                 MaceOp.Eltwise,
                                 MaceOp.Pad,
                                 MaceOp.Reduce,
                                 MaceOp.Softmax,
                                 MaceOp.Split,
                                 MaceOp.SqrDiffMean]

class MaceKeyword(object):
    """String keys used for MACE node names and op/net argument names.

    The values are persisted into the model, so they must not change.
    Some attributes intentionally share a value (``mace_padding_str`` /
    ``mace_padding_type_str`` are both 'padding'; ``mace_wino_block_size`` /
    ``mace_wino_arg_str`` are both 'wino_block_size'). The misspelling in
    'seperate_buffer' is likewise part of the stored key.
    """
    # node related str
    mace_input_node_name = 'mace_input_node'
    mace_output_node_name = 'mace_output_node'
    mace_buffer_type = 'buffer_type'
    # arg related str
    mace_padding_str = 'padding'
    mace_padding_type_str = 'padding'
    mace_padding_values_str = 'padding_values'
    mace_strides_str = 'strides'
    mace_dilations_str = 'dilations'
    mace_pooling_type_str = 'pooling_type'
    mace_global_pooling_str = 'global_pooling'
    mace_kernel_str = 'kernels'
    mace_data_format_str = 'data_format'
    mace_has_data_format_str = 'has_data_format'
    mace_filter_format_str = 'filter_format'
    mace_element_type_str = 'type'
    mace_activation_type_str = 'activation'
    mace_activation_max_limit_str = 'max_limit'
    mace_activation_leakyrelu_coefficient_str = 'leakyrelu_coefficient'
    mace_resize_size_str = 'size'
    mace_batch_to_space_crops_str = 'crops'
    mace_paddings_str = 'paddings'
    mace_align_corners_str = 'align_corners'
    mace_space_batch_block_shape_str = 'block_shape'
    mace_space_depth_block_size_str = 'block_size'
    mace_constant_value_str = 'constant_value'
    mace_dim_str = 'dim'
    mace_dims_str = 'dims'
    mace_axis_str = 'axis'
    mace_end_axis_str = 'end_axis'
    mace_num_axes_str = 'num_axes'
    mace_num_split_str = 'num_split'
    mace_keepdims_str = 'keepdims'
    mace_shape_str = 'shape'
    mace_winograd_filter_transformed = 'is_filter_transformed'
    mace_device = 'device'
    mace_scalar_input_str = 'scalar_input'
    mace_wino_block_size = 'wino_block_size'
    mace_output_shape_str = 'output_shape'
    mace_begin_mask_str = 'begin_mask'
    mace_end_mask_str = 'end_mask'
    mace_ellipsis_mask_str = 'ellipsis_mask'
    mace_new_axis_mask_str = 'new_axis_mask'
    mace_shrink_axis_mask_str = 'shrink_axis_mask'
    mace_transpose_a_str = 'transpose_a'
    mace_transpose_b_str = 'transpose_b'
    mace_op_data_type_str = 'T'
    mace_offset_str = 'offset'
    mace_opencl_max_image_size = "opencl_max_image_size"
    mace_seperate_buffer_str = 'seperate_buffer'
    mace_scalar_input_index_str = 'scalar_input_index'
    mace_opencl_mem_type = "opencl_mem_type"
    mace_framework_type_str = "framework_type"
    mace_group_str = "group"
    mace_wino_arg_str = "wino_block_size"
    mace_quantize_flag_arg_str = "quantize_flag"
    mace_epsilon_str = 'epsilon'
    mace_reduce_type_str = 'reduce_type'
    mace_argmin_str = 'argmin'
    mace_round_mode_str = 'round_mode'
    mace_min_size_str = 'min_size'
    mace_max_size_str = 'max_size'
    mace_aspect_ratio_str = 'aspect_ratio'
    mace_flip_str = 'flip'
    mace_clip_str = 'clip'
    mace_variance_str = 'variance'
    mace_step_h_str = 'step_h'
    mace_step_w_str = 'step_w'
    mace_find_range_every_time = 'find_range_every_time'
    mace_non_zero = 'non_zero'
    mace_pad_type_str = 'pad_type'
    mace_exclusive_str = 'exclusive'
    mace_reverse_str = 'reverse'
    mace_const_data_num_arg_str = 'const_data_num'


class TransformerRule(Enum):
    """Identifiers of the graph-transformation passes.

    ``ConverterOption.build()`` resolves user-supplied pass names with
    ``TransformerRule[name]``, so member names act as a public vocabulary.
    """
    REMOVE_IDENTITY_OP = 1
    TRANSFORM_GLOBAL_POOLING = 2
    FOLD_RESHAPE = 3
    TRANSFORM_MATMUL_TO_FC = 4
    FOLD_BATCHNORM = 5
    FOLD_CONV_AND_BN = 6
    FOLD_DEPTHWISE_CONV_AND_BN = 7
    ADD_WINOGRAD_ARG = 8
    TRANSFORM_ADD_TO_BIASADD = 9
    FOLD_BIASADD = 10
    FLATTEN_ATROUS_CONV = 11
    FOLD_ACTIVATION = 12
    TRANSPOSE_FILTERS = 13
    RESHAPE_FC_WEIGHT = 14
    TRANSPOSE_DATA_FORMAT = 15
    TRANSFORM_GLOBAL_CONV_TO_FC = 16
    ADD_BUFFER_TRANSFORM = 17
    ADD_DEVICE = 18
    SORT_BY_EXECUTION = 19
    ADD_IN_OUT_TENSOR_INFO = 20
    ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
    UPDATE_FLOAT_OP_DATA_TYPE = 22
    QUANTIZE_NODES = 23
    ADD_QUANTIZE_TENSOR_RANGE = 24
    QUANTIZE_WEIGHTS = 25
    TRANSFORM_LSTMCELL_ZEROSTATE = 26
    TRANSFORM_BASIC_LSTMCELL = 27
    TRANSFORM_FAKE_QUANTIZE = 28
    CHECK_QUANTIZE_INFO = 29
    REARRANGE_BATCH_TO_SPACE = 30
    ADD_OPENCL_INFORMATIONS = 31
    FOLD_DECONV_AND_BN = 32
    FOLD_SQRDIFF_MEAN = 33
    TRANSPOSE_MATMUL_WEIGHT = 34
    FOLD_EMBEDDING_LOOKUP = 35
    TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36
    FOLD_FC_RESHAPE = 37
    TRANSFORM_CHANNEL_SHUFFLE = 38
    UPDATE_DATA_FORMAT = 39
    QUANTIZE_SPECIFIC_OPS_ONLY = 40
    FP16_MATMUL_WEIGHT = 41
    FP16_GATHER_WEIGHT = 42


class ConverterInterface(object):
    """Interface implemented by every external-model-to-MACE converter.

    Concrete converters override ``run``; the base version only reports
    that the override is missing.
    """

    def run(self):
        """Perform the conversion; subclasses must provide this."""
        raise NotImplementedError('run')


class NodeInfo(object):
    """A class for describing node information.

    Holds a node's name, data type, shape, data format and a [min, max]
    value range, each exposed as a read/write property.
    """

    def __init__(self):
        self._name = None
        self._data_type = mace_pb2.DT_FLOAT
        self._shape = []
        self._data_format = DataFormat.NHWC
        # Default [min, max]; presumably consumed by quantization -- confirm.
        self._range = [-1.0, 1.0]

    @property
    def name(self):
        return self._name

    @property
    def data_type(self):
        return self._data_type

    @property
    def shape(self):
        return self._shape

    @property
    def data_format(self):
        return self._data_format

    @property
    def range(self):
        return self._range

    @name.setter
    def name(self, name):
        self._name = name

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @data_format.setter
    def data_format(self, data_format):
        self._data_format = data_format

    @range.setter
    def range(self, value_range):
        # Parameter renamed from `range` so it no longer shadows the builtin.
        self._range = value_range

    def __str__(self):
        # Format: "<name> <shape>", e.g. "input [1, 224, 224, 3]".
        return '%s %s' % (self._name, str(self._shape))


class ConverterOption(object):
    """A class for specifying options passed to converter tool.

    Options are exposed as properties. ``build()`` turns
    ``transformer_option`` into the list of ``TransformerRule`` passes to run.
    """

    def __init__(self):
        # Node maps are keyed by node.name (see the setters below).
        self._input_nodes = {}
        self._output_nodes = {}
        self._check_nodes = {}
        self._data_type = mace_pb2.DT_FLOAT
        # Stored as the integer value of a DeviceType, not the enum member.
        self._device = DeviceType.CPU.value
        self._winograd = 0
        self._quantize = False
        self._quantize_range_file = ""
        self._change_concat_ranges = False
        # None until set by the caller or resolved by build().
        self._transformer_option = None
        self._cl_mem_type = ""

    @property
    def input_nodes(self):
        return self._input_nodes

    @property
    def output_nodes(self):
        return self._output_nodes

    @property
    def check_nodes(self):
        return self._check_nodes

    @property
    def data_type(self):
        return self._data_type

    @property
    def device(self):
        return self._device

    @property
    def winograd(self):
        return self._winograd

    @property
    def quantize(self):
        return self._quantize

    @property
    def change_concat_ranges(self):
        return self._change_concat_ranges

    @property
    def quantize_range_file(self):
        return self._quantize_range_file

    @property
    def transformer_option(self):
        return self._transformer_option

    @property
    def cl_mem_type(self):
        return self._cl_mem_type

    @input_nodes.setter
    def input_nodes(self, input_nodes):
        # Merges into the existing map, re-keyed by node.name: the keys of
        # the passed-in dict are ignored and existing entries are kept.
        for node in input_nodes.values():
            self._input_nodes[node.name] = node

    def add_input_node(self, input_node):
        self._input_nodes[input_node.name] = input_node

    @output_nodes.setter
    def output_nodes(self, output_nodes):
        # Same merge semantics as the input_nodes setter.
        for node in output_nodes.values():
            self._output_nodes[node.name] = node

    def add_output_node(self, output_node):
        self._output_nodes[output_node.name] = output_node

    @check_nodes.setter
    def check_nodes(self, check_nodes):
        # Same merge semantics as the input_nodes setter.
        for node in check_nodes.values():
            self._check_nodes[node.name] = node

    def add_check_node(self, check_node):
        self._check_nodes[check_node.name] = check_node

    @data_type.setter
    def data_type(self, data_type):
        self._data_type = data_type

    @device.setter
    def device(self, device):
        self._device = device

    @winograd.setter
    def winograd(self, winograd):
        self._winograd = winograd

    @quantize.setter
    def quantize(self, quantize):
        self._quantize = quantize

    @quantize_range_file.setter
    def quantize_range_file(self, quantize_range_file):
        self._quantize_range_file = quantize_range_file

    @change_concat_ranges.setter
    def change_concat_ranges(self, change_concat_ranges):
        self._change_concat_ranges = change_concat_ranges

    @transformer_option.setter
    def transformer_option(self, transformer_option):
        self._transformer_option = transformer_option

    @cl_mem_type.setter
    def cl_mem_type(self, cl_mem_type):
        self._cl_mem_type = cl_mem_type

    def disable_transpose_filters(self):
        """Remove TRANSPOSE_FILTERS from the pass list if it is present.

        ``transformer_option`` must already be a list (e.g. after build());
        while it is still None the membership test raises TypeError.
        """
        if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
            self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)

    def enable_transpose_filters(self):
        """Append TRANSPOSE_FILTERS to the pass list if it is missing.

        ``transformer_option`` must already be a list (e.g. after build());
        while it is still None the membership test raises TypeError.
        """
        if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option:
            self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS)

    def build(self):
        """Resolve ``transformer_option`` into a list of TransformerRule.

        A caller-supplied list may contain pass names (resolved with
        ``TransformerRule[name]``) and/or TransformerRule members (kept
        as-is), so calling build() more than once is safe. Without a
        caller-supplied list the default pipeline below is used, extended
        with the quantization passes when ``quantize`` is set.

        Raises:
            KeyError: if a supplied name is not a TransformerRule member.
        """
        if self._transformer_option:
            self._transformer_option = [
                transformer if isinstance(transformer, TransformerRule)
                else TransformerRule[transformer]
                for transformer in self._transformer_option]
        else:
            self._transformer_option = [
                # Model structure related transformation
                TransformerRule.TRANSFORM_FAKE_QUANTIZE,
                TransformerRule.REMOVE_IDENTITY_OP,
                TransformerRule.TRANSFORM_GLOBAL_POOLING,
                TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
                TransformerRule.TRANSFORM_BASIC_LSTMCELL,
                TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN,
                TransformerRule.FOLD_RESHAPE,
                TransformerRule.TRANSFORM_MATMUL_TO_FC,
                # For StoB -> conv -> BtoS -> BN pattern
                # Insert flatten_atrous_conv before fold_xxx_and_bn
                TransformerRule.FLATTEN_ATROUS_CONV,
                TransformerRule.FOLD_BATCHNORM,
                TransformerRule.FOLD_CONV_AND_BN,
                TransformerRule.FOLD_DECONV_AND_BN,
                TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
                TransformerRule.TRANSFORM_ADD_TO_BIASADD,
                TransformerRule.REARRANGE_BATCH_TO_SPACE,
                TransformerRule.FOLD_BIASADD,
                # NOTE(review): second FLATTEN_ATROUS_CONV run (see above);
                # presumably a no-op unless earlier folds expose new
                # patterns -- confirm before removing either entry.
                TransformerRule.FLATTEN_ATROUS_CONV,
                TransformerRule.FOLD_ACTIVATION,
                TransformerRule.FOLD_SQRDIFF_MEAN,
                TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
                TransformerRule.RESHAPE_FC_WEIGHT,
                TransformerRule.FOLD_FC_RESHAPE,
                TransformerRule.TRANSFORM_CHANNEL_SHUFFLE,
                # Model data format related transformation
                TransformerRule.TRANSPOSE_FILTERS,
                # Mace model structure related transformation
                TransformerRule.ADD_IN_OUT_TENSOR_INFO,
                TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
                # Add winograd argument
                TransformerRule.ADD_WINOGRAD_ARG,
                # Data type related transformation
                TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
                # Transform finalization
                TransformerRule.ADD_OPENCL_INFORMATIONS,
                # for quantization entropy calibration use
                TransformerRule.SORT_BY_EXECUTION,
                # update the data format of ops
                TransformerRule.UPDATE_DATA_FORMAT,
                TransformerRule.TRANSPOSE_DATA_FORMAT,
                # Need to be put after SORT_BY_EXECUTION
                TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
            ]
            if self._quantize:
                self._transformer_option = self._transformer_option + [
                    # need to be put after ADD_QUANTIZE_TENSOR_RANGE
                    TransformerRule.QUANTIZE_NODES,
                    TransformerRule.QUANTIZE_WEIGHTS,
                    TransformerRule.SORT_BY_EXECUTION,
                    TransformerRule.CHECK_QUANTIZE_INFO,
                ]

class ConverterUtil(object):
    """Static helpers for reading and writing MACE op/net arguments."""

    @staticmethod
    def get_arg(op, arg_name):
        """Return the first entry of ``op.arg`` named ``arg_name``, or None."""
        for arg in op.arg:
            if arg.name == arg_name:
                return arg
        return None

    @staticmethod
    def del_arg(op, arg_name):
        """Delete the first entry of ``op.arg`` named ``arg_name``, if any."""
        for idx, arg in enumerate(op.arg):
            if arg.name == arg_name:
                # Return right away, so deleting while iterating is safe.
                del op.arg[idx]
                return

    @staticmethod
    def add_data_format_arg(op, data_format):
        """Append a data-format arg holding ``data_format.value``."""
        data_format_arg = op.arg.add()
        data_format_arg.name = MaceKeyword.mace_data_format_str
        data_format_arg.i = data_format.value

    @staticmethod
    def add_data_type_arg(op, data_type):
        """Append an op data-type arg ('T') holding the integer ``data_type``."""
        data_type_arg = op.arg.add()
        data_type_arg.name = MaceKeyword.mace_op_data_type_str
        data_type_arg.i = data_type

    @staticmethod
    def _format_from_arg(arg, candidates):
        """Map ``arg.i`` to the member of ``candidates`` with that value.

        Returns None when ``arg`` is None or no candidate matches.
        """
        if arg is None:
            return None
        for fmt in candidates:
            if arg.i == fmt.value:
                return fmt
        return None

    @staticmethod
    def data_format(op):
        """Return the op's data format: NHWC, NCHW, AUTO, or None.

        None is returned when the arg is absent or holds any other value
        (including DataFormat.NONE).
        """
        arg = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str)
        return ConverterUtil._format_from_arg(
            arg, (DataFormat.NHWC, DataFormat.NCHW, DataFormat.AUTO))

    @staticmethod
    def set_filter_format(net, filter_format):
        """Append a filter-format arg holding ``filter_format.value`` to net."""
        arg = net.arg.add()
        arg.name = MaceKeyword.mace_filter_format_str
        arg.i = filter_format.value

    @staticmethod
    def filter_format(net):
        """Return the net's filter layout (HWIO/HWOI/OIHW/OHWI) or None.

        OHWI is included so every filter layout defined in DataFormat that
        set_filter_format can write is also read back.
        """
        arg = ConverterUtil.get_arg(net, MaceKeyword.mace_filter_format_str)
        return ConverterUtil._format_from_arg(
            arg, (DataFormat.HWIO, DataFormat.HWOI,
                  DataFormat.OIHW, DataFormat.OHWI))