提交 20e93630 编写于 作者: M Megvii Engine Team

feat(mge): rename batch_norm2d -> batch_norm

GitOrigin-RevId: 253e8564eab59528c3c08170958e7c0b3fe3b1c3
上级 aa626726
......@@ -39,7 +39,7 @@ __all__ = [
"adaptive_avg_pool2d",
"adaptive_max_pool2d",
"avg_pool2d",
"batch_norm2d",
"batch_norm",
"conv2d",
"conv_transpose2d",
"dot",
......@@ -605,7 +605,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
return cached / down
def batch_norm2d(
def batch_norm(
inp: Tensor,
running_mean: Tensor = None,
running_var: Tensor = None,
......@@ -639,6 +639,8 @@ def batch_norm2d(
Default: True
:return: output tensor.
"""
if inp.ndim != 4:
raise NotImplementedError("batch_norm for ndim != 4")
def full_value(value):
C = inp.shape[1]
......
......@@ -11,7 +11,7 @@ from typing import Optional
import numpy as np
from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm2d, sync_batch_norm
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module
......@@ -96,7 +96,7 @@ class _BatchNorm(Module):
else:
exponential_average_factor = 0.0 # useless
output = batch_norm2d(
output = batch_norm(
inp,
self.running_mean if self.track_running_stats else None,
self.running_var if self.track_running_stats else None,
......
......@@ -327,14 +327,14 @@ def test_module_api_hooks():
assert pre_hook_num == 4
assert post_hook_num == 4
mean1 = Parameter(np.zeros(shape), dtype=np.float32)
bn1 = F.batch_norm2d(
bn1 = F.batch_norm(
x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
)
np.testing.assert_allclose(
net.i.bn.running_mean.numpy(), mean1.numpy(),
)
mean2 = Parameter(np.zeros(shape), dtype=np.float32)
bn2 = F.batch_norm2d(
bn2 = F.batch_norm(
bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
)
np.testing.assert_allclose(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册