Commit 3892aa0b authored by Megvii Engine Team

fix(imperative/amp): fix bn params for nhwc amp

GitOrigin-RevId: 57a3b9d4181ba8110ad736a80b54f43c3d239962
Parent 6f0b5820
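For context: "NHWC AMP" refers to mixed-precision training on tensors tagged with the NHWC memory format, which MegEngine propagates through ops via format-transformation rules. A minimal sketch of the usage this commit targets (shapes and names are illustrative, not taken from the commit):

    import numpy as np
    import megengine as mge
    from megengine import amp

    # An NHWC-tagged tensor; `format="nhwc"` makes the format-transformation
    # rules keep layouts and parameter shapes in NHWC (e.g. (1, 1, 1, C)).
    x = mge.tensor(np.ones((2, 4, 4, 8), dtype="float32"), format="nhwc")

    with amp.autocast():  # mixed-precision context
        y = x * 2.0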
@@ -51,14 +51,7 @@ class _Hashable:
         return self.value == o.value


-def _matmul(
-    inp1,
-    inp2,
-    transpose_a=False,
-    transpose_b=False,
-    compute_mode="default",
-    format="default",
-):
+def _matmul(inp1, inp2, transpose_a=False, transpose_b=False, compute_mode="default"):
     dim1, dim2 = inp1.ndim, inp2.ndim
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
......
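(The `format` keyword is dropped from the private `_matmul` helper; format handling presumably now lives entirely in the format-transformation layer rather than being threaded through as an explicit argument.)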
@@ -1206,6 +1206,7 @@ def batch_norm(
         if x is None:
             x = Const(value, inp.dtype, inp.device)
+            x.format = inp.format
         shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
         (result,) = apply(builtin.Broadcast(), x, shape)
         return result
@@ -1227,14 +1228,14 @@ def batch_norm(
     if not training:
         op = builtin.BatchNorm(
-            fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
         )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret
     else:
         op = builtin.BatchNorm(
-            avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps
+            avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
         )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
......
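The two hunks above are the core of the fix: when BN parameters or statistics are omitted, `batch_norm` fills them via `make_full_if_none`, and the `Const` filler previously lost the input's format, so the subsequent `Broadcast` skipped the NHWC FormatTransformation rule. The `param_dim="dim_1c11"` keyword (the BatchNorm op's parameter layout, i.e. parameters treated as shape (1, C, 1, 1)) is only reordered here, not changed. A hedged sketch of a call that exercises the fixed path, modeled on the test code removed further below (exact shapes are illustrative):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.ones((2, 4, 4, 8), dtype="float32"), format="nhwc")
    rm = mge.tensor(np.zeros((1, 1, 1, 8), dtype="float32"), format="nhwc")
    rv = mge.tensor(np.ones((1, 1, 1, 8), dtype="float32"), format="nhwc")

    # weight/bias omitted: batch_norm fills them with broadcast Const fillers,
    # which after this commit inherit x's NHWC format instead of the default.
    out = F.batch_norm(x, running_mean=rm, running_var=rv, training=True)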
@@ -272,6 +272,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
     x = Const(value, inp.dtype, inp.device)
     if inp.ndim == 0:
         return x
+    # set x's format to use FormatTransformation rule for Broadcast.
+    x.format = inp.format
     return broadcast_to(x, inp.shape)
......
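The same one-line fix applies to the public `full_like`: the scalar produced by `Const` now inherits the input's format before `broadcast_to`, as the added comment states. A quick illustration (the output claim is the expected behavior after this commit):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.zeros((1, 4, 4, 2), dtype="float32"), format="nhwc")
    y = F.full_like(x, 1.0)
    # broadcast_to now goes through the NHWC FormatTransformation rule,
    # so the result keeps the input's format tag.
    print(y.format)  # expected: "nhwc"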
@@ -91,13 +91,14 @@ class Optimizer(metaclass=ABCMeta):
         else:
             param_group["params"] = list(param_group["params"])

-        for param in param_group["params"]:
-            if not isinstance(param, Parameter):
-                raise TypeError(
-                    "optimizer can only optimize Parameters, but one of the params is "
-                    + str(type(param))
-                )
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+        with _config._override(auto_format_convert=False):
+            for param in param_group["params"]:
+                if not isinstance(param, Parameter):
+                    raise TypeError(
+                        "optimizer can only optimize Parameters, but one of the params is "
+                        + str(type(param))
+                    )
+                param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))

         for name, default in self._defaults.items():
             if default is required and name not in param_group:
......
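Wrapping the parameter reset in `_config._override(auto_format_convert=False)` presumably keeps the `param.numpy()` round-trip from auto-converting NHWC parameters back to the default layout before `_reset` re-tags them. `_config._override` is a private helper; the sketch below only illustrates its scope (the import path is an assumption, the diff shows only `_config`):

    import numpy as np
    import megengine as mge
    from megengine import _config  # assumed import path; the diff shows only `_config`

    p = mge.Parameter(np.ones((1, 1, 1, 8), dtype="float32"), format="nhwc")  # illustrative
    with _config._override(auto_format_convert=False):
        # no automatic NCHW<->NHWC conversion inside this block, so the numpy
        # round-trip used by Optimizer preserves the raw buffer and format tag
        raw = p.numpy()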
@@ -58,7 +58,6 @@ def run_around_tests():
         "benchmark_kernel": config.benchmark_kernel,
         "deterministic_kernel": config.deterministic_kernel,
         "compute_mode": config._compute_mode,
-        "conv_format": config._conv_format,
         "amp_enabled": amp.enabled,
         "convert_inputs": _get_convert_inputs(),
         "amp_dtype_autocast": _get_amp_dtype_autocast(),
@@ -82,7 +81,6 @@ def run_around_tests():
         "benchmark_kernel": config.benchmark_kernel,
         "deterministic_kernel": config.deterministic_kernel,
         "compute_mode": config._compute_mode,
-        "conv_format": config._conv_format,
         "amp_enabled": amp.enabled,
         "convert_inputs": _get_convert_inputs(),
         "amp_dtype_autocast": _get_amp_dtype_autocast(),
......
@@ -386,13 +386,6 @@ def test_backward_conv2d_dimshuffle(is_symbolic):
             return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2)

     inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
-    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    # w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
-    # b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
-    # grads = [
-    #     np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
-    #     np.array([12, 12, 12]).reshape((1, 1, 1, 3)),
-    # ]
     _compare_backward([inp], Net(), is_symbolic)
@@ -403,37 +396,10 @@ def test_backward_groupconv2d_bn(is_symbolic):
             super().__init__()
             self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
             self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
-            # self.bn = M.BatchNorm2d(2048)
+            self.bn = M.BatchNorm2d(2048)

         def forward(self, inp):
-            # test manually convert to NHWC, usually used in detection head
-            return self.conv1(self.conv0(inp))
+            return self.bn(self.conv1(self.conv0(inp)))

     inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
     _compare_backward([inp], Net(), is_symbolic)
-
-    # def func(x, w, b, bn_w, bn_b):
-    #     x = F.conv2d(x, w, b, groups=2)
-    #     x = F.batch_norm(
-    #         x,
-    #         running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
-    #         running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
-    #         weight=bn_w,
-    #         bias=bn_b,
-    #         training=True,
-    #         inplace=True,
-    #     )
-    #     return x
-
-    # data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    # w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc")
-    # b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
-    # grads = [
-    #     np.array([66, 210]).reshape((2, 1, 1, 1, 1)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
-    # ]
-    # _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic)
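With `self.bn` re-enabled, this test now covers BatchNorm directly after grouped convolutions, the path the `make_full_if_none` fix above targets. For reference, the shapes flowing through the restored network:

    # inp: (32, 32, 56, 56)
    # conv0 (32 -> 256, k=3, groups=32, stride=2): (32, 256, 27, 27)
    # conv1 (256 -> 2048, k=3, groups=32, stride=2): (32, 2048, 13, 13)
    # bn: BatchNorm2d(2048) applied to (32, 2048, 13, 13)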