提交 5655636c 编写于 作者: M Megvii Engine Team

chore(mge/metric): funtional.topk_acc (add nn) use functional.metric

GitOrigin-RevId: 77fd432cb24678c2351a7656ddd44339618f9cb6
上级 d5688c3d
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from . import metric, utils, vision
from .elemwise import * from .elemwise import *
from .math import * from .math import *
from .nn import * from .nn import *
from .tensor import * from .tensor import *
from .utils import *
from . import distributed # isort:skip from . import utils, vision, distributed # isort:skip
# delete namespace # delete namespace
# pylint: disable=undefined-variable # pylint: disable=undefined-variable
......
...@@ -43,6 +43,7 @@ from .debug_param import get_execution_strategy ...@@ -43,6 +43,7 @@ from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum from .distributed import all_reduce_sum
from .elemwise import _elwise, exp, log, log1p, maximum, minimum from .elemwise import _elwise, exp, log, log1p, maximum, minimum
from .math import max, sum from .math import max, sum
from .metric import topk_accuracy
from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
__all__ = [ __all__ = [
...@@ -85,6 +86,7 @@ __all__ = [ ...@@ -85,6 +86,7 @@ __all__ = [
"softmax", "softmax",
"softplus", "softplus",
"sync_batch_norm", "sync_batch_norm",
"topk_accuracy",
"warp_affine", "warp_affine",
"warp_perspective", "warp_perspective",
"pixel_shuffle", "pixel_shuffle",
...@@ -1960,5 +1962,4 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor: ...@@ -1960,5 +1962,4 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
from .quantized import conv_bias_activation # isort:skip from .quantized import conv_bias_activation # isort:skip
from .loss import * # isort:skip from .loss import * # isort:skip
from .metric import * # isort:skip
from .vision import * # isort:skip from .vision import * # isort:skip
...@@ -7,8 +7,6 @@ from ..utils.deprecation import deprecated_func ...@@ -7,8 +7,6 @@ from ..utils.deprecation import deprecated_func
from .elemwise import abs, maximum, minimum from .elemwise import abs, maximum, minimum
from .tensor import ones, zeros from .tensor import ones, zeros
__all__ = ["topk_accuracy"]
def _assert_equal( def _assert_equal(
expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册