Commit 3892aa0b authored by Megvii Engine Team

fix(imperative/amp): fix bn params for nhwc amp

GitOrigin-RevId: 57a3b9d4181ba8110ad736a80b54f43c3d239962
Parent 6f0b5820
@@ -51,14 +51,7 @@ class _Hashable:
         return self.value == o.value
 
 
-def _matmul(
-    inp1,
-    inp2,
-    transpose_a=False,
-    transpose_b=False,
-    compute_mode="default",
-    format="default",
-):
+def _matmul(inp1, inp2, transpose_a=False, transpose_b=False, compute_mode="default"):
     dim1, dim2 = inp1.ndim, inp2.ndim
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
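Note: with the private helper's `format` argument removed, call sites rely on the input tensors' own format tag instead of passing one explicitly. A minimal sketch of the unchanged public API (shapes illustrative, not from this commit):

import numpy as np
import megengine as mge
import megengine.functional as F

a = mge.tensor(np.random.rand(2, 3).astype("float32"))
b = mge.tensor(np.random.rand(3, 4).astype("float32"))
out = F.matmul(a, b)  # format is now inferred from a and b, not passed as an argument
print(out.shape)  # (2, 4)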
@@ -1206,6 +1206,7 @@ def batch_norm(
         if x is None:
             x = Const(value, inp.dtype, inp.device)
+            x.format = inp.format
         shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
         (result,) = apply(builtin.Broadcast(), x, shape)
         return result
@@ -1227,14 +1228,14 @@ def batch_norm(
     if not training:
         op = builtin.BatchNorm(
-            fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
         )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret
     else:
         op = builtin.BatchNorm(
-            avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps
+            avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
         )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
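Note: the kwarg reordering above is cosmetic; both branches still pass `param_dim="dim_1c11"`. A minimal NHWC call sketch, mirroring the commented-out test removed further down (values illustrative):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.rand(1, 3, 4, 2).astype("float32"), format="nhwc")
mean = mge.tensor(np.zeros((1, 1, 1, 2), dtype="float32"), format="nhwc")
var = mge.tensor(np.ones((1, 1, 1, 2), dtype="float32"), format="nhwc")
w = mge.tensor(np.ones((1, 1, 1, 2), dtype="float32"), format="nhwc")
b = mge.tensor(np.zeros((1, 1, 1, 2), dtype="float32"), format="nhwc")
# BN statistics for NHWC tensors are laid out as (1, 1, 1, C)
y = F.batch_norm(x, mean, var, w, b, training=True)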
@@ -272,6 +272,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
     x = Const(value, inp.dtype, inp.device)
     if inp.ndim == 0:
         return x
+    # set x's format to use FormatTransformation rule for Broadcast.
+    x.format = inp.format
     return broadcast_to(x, inp.shape)
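Note: `full_like` gets the same treatment as the `batch_norm` helper above: tag the constant with the source tensor's format before broadcasting. A small sketch of the intended behavior (assuming an NHWC-tagged input):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.ones((1, 2, 2, 3), dtype="float32"), format="nhwc")
y = F.full_like(x, 0.5)
# with x.format propagated before broadcast_to, the FormatTransformation
# rule for Broadcast applies and y keeps the NHWC format tag
print(y.shape, y.format)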
@@ -91,13 +91,14 @@ class Optimizer(metaclass=ABCMeta):
         else:
             param_group["params"] = list(param_group["params"])
 
-        for param in param_group["params"]:
-            if not isinstance(param, Parameter):
-                raise TypeError(
-                    "optimizer can only optimize Parameters, but one of the params is "
-                    + str(type(param))
-                )
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+        with _config._override(auto_format_convert=False):
+            for param in param_group["params"]:
+                if not isinstance(param, Parameter):
+                    raise TypeError(
+                        "optimizer can only optimize Parameters, but one of the params is "
+                        + str(type(param))
+                    )
+                param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
 
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
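Note: the optimizer now resets parameters inside an override that disables automatic format conversion, so NHWC parameters round-trip through numpy() unchanged. A sketch of the pattern (assuming `megengine.core._config` is the module the optimizer imports as `_config`, and that `Parameter` forwards the `format` kwarg like `Tensor` does):

import numpy as np
import megengine as mge
from megengine.core import _config

p = mge.Parameter(np.ones((1, 1, 1, 4), dtype="float32"), format="nhwc")
with _config._override(auto_format_convert=False):
    raw = p.numpy()  # raw buffer, no implicit NHWC -> NCHW conversion
    p._reset(mge.Tensor(raw, no_cache=True, format=p.format))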
@@ -58,7 +58,6 @@ def run_around_tests():
         "benchmark_kernel": config.benchmark_kernel,
         "deterministic_kernel": config.deterministic_kernel,
         "compute_mode": config._compute_mode,
-        "conv_format": config._conv_format,
         "amp_enabled": amp.enabled,
         "convert_inputs": _get_convert_inputs(),
         "amp_dtype_autocast": _get_amp_dtype_autocast(),
@@ -82,7 +81,6 @@ def run_around_tests():
         "benchmark_kernel": config.benchmark_kernel,
         "deterministic_kernel": config.deterministic_kernel,
         "compute_mode": config._compute_mode,
-        "conv_format": config._conv_format,
         "amp_enabled": amp.enabled,
         "convert_inputs": _get_convert_inputs(),
         "amp_dtype_autocast": _get_amp_dtype_autocast(),
@@ -386,13 +386,6 @@ def test_backward_conv2d_dimshuffle(is_symbolic):
             return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2)
 
     inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
-    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    # w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
-    # b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
-    # grads = [
-    #     np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
-    #     np.array([12, 12, 12]).reshape((1, 1, 1, 3)),
-    # ]
     _compare_backward([inp], Net(), is_symbolic)
@@ -403,37 +396,10 @@ def test_backward_groupconv2d_bn(is_symbolic):
             super().__init__()
             self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
             self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
-            # self.bn = M.BatchNorm2d(2048)
+            self.bn = M.BatchNorm2d(2048)
 
         def forward(self, inp):
             # test manually convert to NHWC, usually used in detection head
-            return self.conv1(self.conv0(inp))
+            return self.bn(self.conv1(self.conv0(inp)))
 
     inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
     _compare_backward([inp], Net(), is_symbolic)
 
-    # def func(x, w, b, bn_w, bn_b):
-    #     x = F.conv2d(x, w, b, groups=2)
-    #     x = F.batch_norm(
-    #         x,
-    #         running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
-    #         running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
-    #         weight=bn_w,
-    #         bias=bn_b,
-    #         training=True,
-    #         inplace=True,
-    #     )
-    #     return x
-
-    # data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    # w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc")
-    # b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # grads = [
-    #     np.array([66, 210]).reshape((2, 1, 1, 1, 1)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    # ]
-    # _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic)
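Note: the rewritten test now checks gradients through BatchNorm after grouped convolutions instead of the commented-out functional variant. A standalone, runnable condensation of the forward pass it exercises:

import numpy as np
import megengine as mge
import megengine.module as M

class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
        self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
        self.bn = M.BatchNorm2d(2048)

    def forward(self, inp):
        return self.bn(self.conv1(self.conv0(inp)))

out = Net()(mge.tensor(np.ones((32, 32, 56, 56), dtype="float32")))
print(out.shape)  # (32, 2048, 13, 13)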