提交 ca1d74ad 编写于 作者: M Megvii Engine Team

fix(mge/utils): fix types import error

GitOrigin-RevId: d2d4bd272563bf710000d2f6892b7cbd1985009d
上级 87ff58f7
......@@ -21,6 +21,7 @@ from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing
from ..random import uniform
from ..tensor import Tensor
from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
......@@ -35,7 +36,6 @@ from .tensor import (
squeeze,
zeros,
)
from .types import _pair, _pair_nonzero
__all__ = [
"adaptive_avg_pool2d",
......
......@@ -11,8 +11,8 @@ from typing import Tuple, Union
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..tensor import Tensor
from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
from .types import _pair, _pair_nonzero
def conv_bias_activation(
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import functools
def get_ndtuple(value, *, n, allow_zero: bool = True):
r"""
Converts possibly 1D tuple to n-dim tuple.
:param value: value will be filled in generated tuple.
:param n: how many elements will the tuple have.
:param allow_zero: whether to allow zero tuple value.
:return: a tuple.
"""
if not isinstance(value, collections.Iterable):
value = int(value)
value = tuple([value for i in range(n)])
else:
assert len(value) == n, "tuple len is not equal to n: {}".format(value)
spatial_axis = map(int, value)
value = tuple(spatial_axis)
if allow_zero:
minv = 0
else:
minv = 1
assert min(value) >= minv, "invalid value: {}".format(value)
return value
_single = functools.partial(get_ndtuple, n=1, allow_zero=True)
_pair = functools.partial(get_ndtuple, n=2, allow_zero=True)
_pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False)
_triple = functools.partial(get_ndtuple, n=3, allow_zero=True)
_quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True)
......@@ -11,8 +11,8 @@ from typing import Tuple, Union
import numpy as np
from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu
from ..functional.types import _pair, _pair_nonzero
from ..tensor import Parameter
from ..utils.tuple_function import _pair, _pair_nonzero
from . import init
from .module import Module
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册