提交 8796586b 编写于 作者: M Megvii Engine Team

refactor(functional): import all from metric in nn

GitOrigin-RevId: 41ab78d78d712f65cb75b6eb686e1de00a23f318
上级 fb15b301
...@@ -15,6 +15,10 @@ from .elemwise import abs, maximum, minimum ...@@ -15,6 +15,10 @@ from .elemwise import abs, maximum, minimum
from .math import topk as _topk from .math import topk as _topk
from .tensor import broadcast_to, transpose from .tensor import broadcast_to, transpose
__all__ = [
"topk_accuracy",
]
def topk_accuracy( def topk_accuracy(
logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
......
...@@ -1660,3 +1660,4 @@ warp_perspective = deprecated_func( ...@@ -1660,3 +1660,4 @@ warp_perspective = deprecated_func(
) )
from .quantized import conv_bias_activation # isort:skip from .quantized import conv_bias_activation # isort:skip
from .loss import * # isort:skip from .loss import * # isort:skip
from .metric import * # isort:skip
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册