""" Layer/Module Helpers Hacked together by / Copyright 2020 Ross Wightman """ from itertools import repeat import collections.abc # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) to_ntuple = _ntuple def make_divisible(v, divisor=8, min_value=None, round_limit=.9): min_value = min_value or divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. if new_v < round_limit * v: new_v += divisor return new_v