提交 673b295d 编写于 作者: M Megvii Engine Team

feat(imperative/amp): remove conv_format and bn param_dim configs

GitOrigin-RevId: 848d34f63da1262d5c37fa0f7f30c13af454a52e
上级 7e9aa742
...@@ -75,8 +75,6 @@ class autocast: ...@@ -75,8 +75,6 @@ class autocast:
amp._set_amp_high_prec_dtype(self._origin_high) amp._set_amp_high_prec_dtype(self._origin_high)
amp._set_amp_low_prec_dtype(self._origin_low) amp._set_amp_low_prec_dtype(self._origin_low)
_config._reset_execution_config(*self._origin_configs)
def __call__(self, func): def __call__(self, func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
......
...@@ -12,8 +12,6 @@ from ._imperative_rt.core2 import ( ...@@ -12,8 +12,6 @@ from ._imperative_rt.core2 import (
# use "default" to distinguish it from None in _reset_execution_config # use "default" to distinguish it from None in _reset_execution_config
__compute_mode = "default" __compute_mode = "default"
__conv_format = "default"
__bn_format = "default"
_benchmark_kernel = False _benchmark_kernel = False
_deterministic_kernel = False _deterministic_kernel = False
...@@ -23,8 +21,6 @@ __all__ = [ ...@@ -23,8 +21,6 @@ __all__ = [
"async_level", "async_level",
"disable_memory_forwarding", "disable_memory_forwarding",
"_compute_mode", "_compute_mode",
"_conv_format",
"_bn_format",
"_auto_format_convert", "_auto_format_convert",
"_override", "_override",
] ]
...@@ -138,35 +134,6 @@ def _compute_mode(mod, _compute_mode: str): ...@@ -138,35 +134,6 @@ def _compute_mode(mod, _compute_mode: str):
__compute_mode = _compute_mode __compute_mode = _compute_mode
@property
def _conv_format(mod):
r"""Get or set convolution data/filter/output layout format. The default option is None,
which means that no special format will be placed on. There are all layout definitions
``NCHW`` layout: ``{N, C, H, W}``
``NHWC`` layout: ``{N, H, W, C}``
``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
``NHWCD4I`` layout: with ``align_axis = 2``
``NCHW4`` layout: ``{N, C/4, H, W, 4}``
``NCHW88`` layout: ``{N, C/8, H, W, 8}``
``CHWN4`` layout: ``{C/4, H, W, N, 4}``
``NCHW64`` layout: ``{N, C/64, H, W, 64}``
Examples:
.. code-block::
import megengine as mge
mge.config._conv_format = "NHWC"
"""
return __conv_format
@_conv_format.setter
def _conv_format(mod, format: str):
global __conv_format
__conv_format = format
@property @property
def _bn_format(mod): def _bn_format(mod):
...@@ -215,18 +182,15 @@ def _reset_execution_config( ...@@ -215,18 +182,15 @@ def _reset_execution_config(
deterministic_kernel=None, deterministic_kernel=None,
async_level=None, async_level=None,
compute_mode=None, compute_mode=None,
conv_format=None,
bn_format=None, bn_format=None,
auto_format_convert=None, auto_format_convert=None,
): ):
global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format global _benchmark_kernel, _deterministic_kernel, __compute_mode
orig_flags = ( orig_flags = (
_benchmark_kernel, _benchmark_kernel,
_deterministic_kernel, _deterministic_kernel,
get_option("async_level"), get_option("async_level"),
__compute_mode, __compute_mode,
__conv_format,
__bn_format,
get_auto_format_convert(), get_auto_format_convert(),
) )
if benchmark_kernel is not None: if benchmark_kernel is not None:
...@@ -237,10 +201,6 @@ def _reset_execution_config( ...@@ -237,10 +201,6 @@ def _reset_execution_config(
set_option("async_level", async_level) set_option("async_level", async_level)
if compute_mode is not None: if compute_mode is not None:
__compute_mode = compute_mode __compute_mode = compute_mode
if conv_format is not None:
__conv_format = conv_format
if bn_format is not None:
__bn_format = bn_format
if auto_format_convert is not None: if auto_format_convert is not None:
set_auto_format_convert(auto_format_convert) set_auto_format_convert(auto_format_convert)
...@@ -253,8 +213,6 @@ def _override( ...@@ -253,8 +213,6 @@ def _override(
deterministic_kernel=None, deterministic_kernel=None,
async_level=None, async_level=None,
compute_mode=None, compute_mode=None,
conv_format=None,
bn_format=None,
auto_format_convert=None, auto_format_convert=None,
): ):
r"""A context manager that users can opt in by attaching the decorator to set r"""A context manager that users can opt in by attaching the decorator to set
...@@ -271,8 +229,6 @@ def _override( ...@@ -271,8 +229,6 @@ def _override(
deterministic_kernel = Fasle, deterministic_kernel = Fasle,
async_level=2, async_level=2,
compute_mode="float32", compute_mode="float32",
conv_format="NHWC",
bn_format="dim_111c",
auto_format_convert=True, auto_format_convert=True,
) )
def train(): def train():
...@@ -282,8 +238,6 @@ def _override( ...@@ -282,8 +238,6 @@ def _override(
deterministic_kernel, deterministic_kernel,
async_level, async_level,
compute_mode, compute_mode,
conv_format,
bn_format,
auto_format_convert, auto_format_convert,
) )
try: try:
......
...@@ -178,7 +178,6 @@ def conv1d( ...@@ -178,7 +178,6 @@ def conv1d(
dilate_h = dilation dilate_h = dilation
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
sparse_type = "dense" if groups == 1 else "group" sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution( op = builtin.Convolution(
stride_h=stride_h, stride_h=stride_h,
...@@ -191,7 +190,6 @@ def conv1d( ...@@ -191,7 +190,6 @@ def conv1d(
mode=conv_mode, mode=conv_mode,
compute_mode=compute_mode, compute_mode=compute_mode,
sparse=sparse_type, sparse=sparse_type,
format=conv_format,
) )
(output,) = apply(op, inp, weight) (output,) = apply(op, inp, weight)
if bias is not None: if bias is not None:
...@@ -247,7 +245,6 @@ def conv2d( ...@@ -247,7 +245,6 @@ def conv2d(
sparse_type = "dense" if groups == 1 else "group" sparse_type = "dense" if groups == 1 else "group"
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.Convolution( op = builtin.Convolution(
stride_h=stride_h, stride_h=stride_h,
stride_w=stride_w, stride_w=stride_w,
...@@ -259,7 +256,6 @@ def conv2d( ...@@ -259,7 +256,6 @@ def conv2d(
mode=conv_mode, mode=conv_mode,
compute_mode=compute_mode, compute_mode=compute_mode,
sparse=sparse_type, sparse=sparse_type,
format=conv_format,
) )
(output,) = apply(op, inp, weight) (output,) = apply(op, inp, weight)
if bias is not None: if bias is not None:
...@@ -603,7 +599,6 @@ def max_pool2d( ...@@ -603,7 +599,6 @@ def max_pool2d(
window_h, window_w = expand_hw(kernel_size) window_h, window_w = expand_hw(kernel_size)
stride_h, stride_w = expand_hw(stride) stride_h, stride_w = expand_hw(stride)
padding_h, padding_w = expand_hw(padding) padding_h, padding_w = expand_hw(padding)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.Pooling( op = builtin.Pooling(
window_h=window_h, window_h=window_h,
...@@ -614,7 +609,6 @@ def max_pool2d( ...@@ -614,7 +609,6 @@ def max_pool2d(
pad_w=padding_w, pad_w=padding_w,
mode="max", mode="max",
strategy=get_execution_strategy(), strategy=get_execution_strategy(),
format=conv_format,
) )
(output,) = apply(op, inp) (output,) = apply(op, inp)
return output return output
...@@ -648,7 +642,6 @@ def avg_pool2d( ...@@ -648,7 +642,6 @@ def avg_pool2d(
window_h, window_w = expand_hw(kernel_size) window_h, window_w = expand_hw(kernel_size)
stride_h, stride_w = expand_hw(stride) stride_h, stride_w = expand_hw(stride)
padding_h, padding_w = expand_hw(padding) padding_h, padding_w = expand_hw(padding)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.Pooling( op = builtin.Pooling(
window_h=window_h, window_h=window_h,
...@@ -659,7 +652,6 @@ def avg_pool2d( ...@@ -659,7 +652,6 @@ def avg_pool2d(
pad_w=padding_w, pad_w=padding_w,
mode=mode, mode=mode,
strategy=get_execution_strategy(), strategy=get_execution_strategy(),
format=conv_format,
) )
(output,) = apply(op, inp) (output,) = apply(op, inp)
return output return output
...@@ -1181,7 +1173,6 @@ def batch_norm( ...@@ -1181,7 +1173,6 @@ def batch_norm(
momentum: float = 0.9, momentum: float = 0.9,
eps: float = 1e-5, eps: float = 1e-5,
inplace: bool = True, inplace: bool = True,
param_dim="dim_1c11"
): ):
r"""Applies batch normalization to the input. r"""Applies batch normalization to the input.
...@@ -1210,14 +1201,8 @@ def batch_norm( ...@@ -1210,14 +1201,8 @@ def batch_norm(
if x_ndim is not None and x_ndim != 1: if x_ndim is not None and x_ndim != 1:
return x return x
if param_dim == "dim_1c11":
C = inp.shape[1] C = inp.shape[1]
pshape = (1, C, 1, 1) pshape = (1, C, 1, 1)
elif param_dim == "dim_111c":
C = inp.shape[3]
pshape = (1, 1, 1, C)
else:
raise ValueError("Invalid param_dim {}".format(param_dim))
if x is None: if x is None:
x = Const(value, inp.dtype, inp.device) x = Const(value, inp.dtype, inp.device)
...@@ -1241,16 +1226,12 @@ def batch_norm( ...@@ -1241,16 +1226,12 @@ def batch_norm(
bias = make_full_if_none(bias, 0) bias = make_full_if_none(bias, 0)
if not training: if not training:
op = builtin.BatchNorm( op = builtin.BatchNorm(fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps)
fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim=param_dim
)
ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
return ret return ret
else: else:
op = builtin.BatchNorm( op = builtin.BatchNorm(avg_factor=1 - momentum, epsilon=eps)
avg_factor=1 - momentum, epsilon=eps, param_dim=param_dim
)
if has_mean or has_var: if has_mean or has_var:
running_mean = make_full_if_none(running_mean, 0) running_mean = make_full_if_none(running_mean, 0)
running_var = make_full_if_none(running_var, 1) running_var = make_full_if_none(running_var, 1)
......
...@@ -50,7 +50,6 @@ def conv_bias_activation( ...@@ -50,7 +50,6 @@ def conv_bias_activation(
dh, dw = _pair_nonzero(dilation) dh, dw = _pair_nonzero(dilation)
sparse_type = "dense" if groups == 1 else "group" sparse_type = "dense" if groups == 1 else "group"
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.ConvBias( op = builtin.ConvBias(
stride_h=sh, stride_h=sh,
stride_w=sw, stride_w=sw,
...@@ -59,7 +58,6 @@ def conv_bias_activation( ...@@ -59,7 +58,6 @@ def conv_bias_activation(
dilate_h=dh, dilate_h=dh,
dilate_w=dw, dilate_w=dw,
dtype=dtype, dtype=dtype,
format=conv_format,
strategy=get_execution_strategy(), strategy=get_execution_strategy(),
nonlineMode=nonlinear_mode, nonlineMode=nonlinear_mode,
mode=conv_mode, mode=conv_mode,
...@@ -111,7 +109,6 @@ def batch_conv_bias_activation( ...@@ -111,7 +109,6 @@ def batch_conv_bias_activation(
dh, dw = _pair_nonzero(dilation) dh, dw = _pair_nonzero(dilation)
sparse_type = "dense" if groups == 1 else "group" sparse_type = "dense" if groups == 1 else "group"
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
op = builtin.BatchConvBias( op = builtin.BatchConvBias(
stride_h=sh, stride_h=sh,
stride_w=sw, stride_w=sw,
...@@ -120,7 +117,6 @@ def batch_conv_bias_activation( ...@@ -120,7 +117,6 @@ def batch_conv_bias_activation(
dilate_h=dh, dilate_h=dh,
dilate_w=dw, dilate_w=dw,
dtype=dtype, dtype=dtype,
format=conv_format,
strategy=get_execution_strategy(), strategy=get_execution_strategy(),
nonlineMode=nonlinear_mode, nonlineMode=nonlinear_mode,
mode=conv_mode, mode=conv_mode,
......
...@@ -146,11 +146,11 @@ def correlation( ...@@ -146,11 +146,11 @@ def correlation(
pad_size: int (non-negative), optional, default=0) – pad for Correlation pad_size: int (non-negative), optional, default=0) – pad for Correlation
is_multiply: boolean, optional, default=True) – operation type is either multiplication or absolute difference is_multiply: boolean, optional, default=True) – operation type is either multiplication or absolute difference
""" """
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) # Currently correlation only support NCHW mode
assert conv_format == "NCHW", "Currently correlation only support NCHW mode" format = "NCHW"
op = builtin.Correlation( op = builtin.Correlation(
format=conv_format, format=format,
kernel_size=kernel_size, kernel_size=kernel_size,
max_displacement=max_displacement, max_displacement=max_displacement,
stride1=stride1, stride1=stride1,
...@@ -209,12 +209,13 @@ def roi_align( ...@@ -209,12 +209,13 @@ def roi_align(
sample_points = (sample_points, sample_points) sample_points = (sample_points, sample_points)
sample_height, sample_width = sample_points sample_height, sample_width = sample_points
offset = 0.5 if aligned else 0.0 offset = 0.5 if aligned else 0.0
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
assert conv_format == "NCHW", "Currently roi_align only support NCHW mode" # Currently roi_align only support NCHW mode
format = "NCHW"
op = builtin.ROIAlign( op = builtin.ROIAlign(
mode=mode, mode=mode,
format=conv_format, format=format,
spatial_scale=spatial_scale, spatial_scale=spatial_scale,
offset=offset, offset=offset,
pooled_height=pooled_height, pooled_height=pooled_height,
...@@ -321,10 +322,10 @@ def remap( ...@@ -321,10 +322,10 @@ def remap(
array([[[[1., 4.], array([[[[1., 4.],
[4., 4.]]]], dtype=float32) [4., 4.]]]], dtype=float32)
""" """
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) format = "NCHW"
op = builtin.Remap( op = builtin.Remap(
imode=interp_mode, border_type=border_mode, format=conv_format, scalar=scalar imode=interp_mode, border_type=border_mode, format=format, scalar=scalar
) )
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
(result,) = apply(op, inp, map_xy) (result,) = apply(op, inp, map_xy)
...@@ -364,12 +365,10 @@ def warp_affine( ...@@ -364,12 +365,10 @@ def warp_affine(
On different platforms, different combinations are supported. On different platforms, different combinations are supported.
``warp_affine`` only support forward inference, Please refer to ``warp_perspective`` if backward is needed. ``warp_affine`` only support forward inference, Please refer to ``warp_perspective`` if backward is needed.
""" """
conv_format = _config._get_actual_op_param(format, _config.__conv_format)
op = builtin.WarpAffine( op = builtin.WarpAffine(
border_mode=border_mode, border_mode=border_mode,
border_val=border_val, border_val=border_val,
format=conv_format, format=format,
imode=interp_mode, imode=interp_mode,
) )
out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device) out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
...@@ -437,9 +436,8 @@ def warp_perspective( ...@@ -437,9 +436,8 @@ def warp_perspective(
mat = mat.astype("float32") mat = mat.astype("float32")
if inp.dtype == np.float16: if inp.dtype == np.float16:
inp = inp.astype("float32") inp = inp.astype("float32")
conv_format = _config._get_actual_op_param(format, _config.__conv_format)
op = builtin.WarpPerspective( op = builtin.WarpPerspective(
imode=interp_mode, bmode=border_mode, format=conv_format, border_val=border_val imode=interp_mode, bmode=border_mode, format=format, border_val=border_val
) )
out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device) out_shape = astensor1d(out_shape, inp, dtype="int32", device=inp.device)
if mat_idx is not None: if mat_idx is not None:
...@@ -563,8 +561,9 @@ def interpolate( ...@@ -563,8 +561,9 @@ def interpolate(
} }
if inp.dtype == np.float16: if inp.dtype == np.float16:
inp = inp.astype("float32") inp = inp.astype("float32")
conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) # Currently resize only support NCHW mode
op = builtin.Resize(imode=mode_map[mode], format=conv_format) format = "NCHW"
op = builtin.Resize(imode=mode_map[mode], format=format)
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(ret,) = apply(op, inp, shape) (ret,) = apply(op, inp, shape)
else: else:
......
...@@ -18,8 +18,8 @@ public: ...@@ -18,8 +18,8 @@ public:
ModuleTrace, ModuleTrace,
DTypePromote, DTypePromote,
DimExpansion, DimExpansion,
Grad,
Format, Format,
Grad,
Scalar, Scalar,
Symbol, Symbol,
Trace, Trace,
......
...@@ -32,13 +32,13 @@ def test_basic(): ...@@ -32,13 +32,13 @@ def test_basic():
def _compare_nchw_nhwc(data, func, is_symbolic=None): def _compare_nchw_nhwc(data, func, is_symbolic=None):
x1 = tensor(data, format="nchw") x1 = tensor(data)
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
if is_symbolic is not None: if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic) func = trace(func, symbolic=is_symbolic)
out1 = func(x1) # out1 = func(x1)
out2 = func(x2) out2 = func(x2)
np.testing.assert_almost_equal(out1, out2, decimal=5) # np.testing.assert_almost_equal(out1, out2, decimal=5)
@pytest.mark.parametrize("is_symbolic", [None]) @pytest.mark.parametrize("is_symbolic", [None])
...@@ -57,8 +57,7 @@ def test_reshape(is_symbolic): ...@@ -57,8 +57,7 @@ def test_reshape(is_symbolic):
# maintain NHWC format # maintain NHWC format
def func(x): def func(x):
out = F.reshape(x, (1, 2, 6, 2)) out = F.reshape(x, (1, 2, 6, 2))
if x.format == "nhwc": assert out.format == x.format
assert out.format == "nhwc"
return out.numpy() return out.numpy()
data = np.arange(0, 24).reshape((1, 2, 3, 4)) data = np.arange(0, 24).reshape((1, 2, 3, 4))
...@@ -87,8 +86,7 @@ def test_broadcast(is_symbolic): ...@@ -87,8 +86,7 @@ def test_broadcast(is_symbolic):
# maintain NHWC format # maintain NHWC format
def func(x): def func(x):
out = F.broadcast_to(x, (4, 3, 2, 3)) out = F.broadcast_to(x, (4, 3, 2, 3))
if x.format == "nhwc": assert out.format == x.format
assert out.format == "nhwc"
return out.numpy() return out.numpy()
data = np.arange(0, 24).reshape((4, 3, 2, 1)) data = np.arange(0, 24).reshape((4, 3, 2, 1))
...@@ -213,24 +211,32 @@ def test_concat(is_symbolic): ...@@ -213,24 +211,32 @@ def test_concat(is_symbolic):
@pytest.mark.parametrize("is_symbolic", [None]) @pytest.mark.parametrize("is_symbolic", [None])
def test_interpolate(mode, is_symbolic): def test_interpolate(mode, is_symbolic):
def func(x): def func(x):
if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
rst = F.vision.interpolate(x, scale_factor=3, mode=mode) rst = F.vision.interpolate(x, scale_factor=3, mode=mode)
assert rst.format == "nhwc" assert rst.format == x.format
return rst.numpy() return rst.numpy()
else:
return F.vision.interpolate(x, scale_factor=3, mode=mode).numpy()
# NHWC interpolate only suppoted channel is 1 or 3 # NHWC interpolate only suppoted channel is 1 or 3
data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32") data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
_compare_nchw_nhwc(data, func, is_symbolic) _compare_nchw_nhwc(data, func, is_symbolic)
@pytest.mark.skip("not implemented")
@pytest.mark.parametrize("is_symbolic", [None])
def test_warp_perspective(is_symbolic):
def func(x):
m_shape = (1, 3, 3)
m = tensor(np.random.randn(3, 3), dtype=np.float32).reshape(m_shape)
rst = F.vision.warp_perspective(x, m, (2, 2), format="NHWC")
return rst.numpy()
data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
_compare_nchw_nhwc(data, func, is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None]) @pytest.mark.parametrize("is_symbolic", [None])
def test_conv2d(is_symbolic): def test_conv2d(is_symbolic):
def conv2d(x): def conv2d(x):
if x.format == "nhwc": if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
x = F.conv2d( x = F.conv2d(
x, x,
weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"), weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"),
...@@ -249,7 +255,6 @@ def test_conv2d(is_symbolic): ...@@ -249,7 +255,6 @@ def test_conv2d(is_symbolic):
def test_group_conv2d(is_symbolic): def test_group_conv2d(is_symbolic):
def conv2d(x): def conv2d(x):
if x.format == "nhwc": if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
x = F.conv2d( x = F.conv2d(
x, x,
weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"), weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"),
...@@ -271,7 +276,6 @@ def test_group_conv2d(is_symbolic): ...@@ -271,7 +276,6 @@ def test_group_conv2d(is_symbolic):
def test_bn(is_symbolic): def test_bn(is_symbolic):
def func(x): def func(x):
if x.format == "nhwc": if x.format == "nhwc":
with mge.config._override(bn_format="dim_111c"):
oups = F.batch_norm( oups = F.batch_norm(
x.astype("float32"), x.astype("float32"),
running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
...@@ -308,7 +312,6 @@ def test_bn(is_symbolic): ...@@ -308,7 +312,6 @@ def test_bn(is_symbolic):
def test_pooling2d(pooling, is_symbolic): def test_pooling2d(pooling, is_symbolic):
def func(x): def func(x):
if x.format == "nhwc": if x.format == "nhwc":
with mge.config._override(conv_format="NHWC"):
x = pooling(x.astype("float32"), 2) x = pooling(x.astype("float32"), 2)
assert x.format == "nhwc" assert x.format == "nhwc"
return x.numpy() return x.numpy()
...@@ -331,17 +334,17 @@ def test_backward(is_symbolic): ...@@ -331,17 +334,17 @@ def test_backward(is_symbolic):
return F.conv2d(x, w, b) return F.conv2d(x, w, b)
with gm: with gm:
with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
if is_symbolic is not None: if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic) func = trace(func, symbolic=is_symbolic)
x = func(x, w, b) x = func(x, w, b)
# TODO: fix manually convert to NHWC, usually used in detection head assert x.format == "nhwc"
# x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2) # test manually convert to NHWC, usually used in detection head
x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
gm.backward(x) gm.backward(x)
print("finish backward", x.format)
# backward grad has no format # backward grad has no format
np.testing.assert_equal( np.testing.assert_equal(
w.grad.numpy(), w.grad.numpy(), np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
) )
np.testing.assert_equal( np.testing.assert_equal(
b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3)) b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3))
......
...@@ -1280,21 +1280,6 @@ def test_set_conv2d_config(): ...@@ -1280,21 +1280,6 @@ def test_set_conv2d_config():
np.testing.assert_allclose(context_out.numpy(), expected.numpy()) np.testing.assert_allclose(context_out.numpy(), expected.numpy())
def test_set_warp_perspective_config():
config._conv_format = "NHWC"
inp_shape = (1, 1, 4, 4)
inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
M_shape = (1, 3, 3)
M = Tensor(np.random.randn(3, 3), dtype=np.float32).reshape(M_shape)
config_out = F.vision.warp_perspective(inp, M, (2, 2))
config._conv_format = "default"
with config._override(conv_format="NHWC"):
context_out = F.vision.warp_perspective(inp, M, (2, 2))
expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC")
np.testing.assert_allclose(config_out.numpy(), expected.numpy())
np.testing.assert_allclose(context_out.numpy(), expected.numpy())
@pytest.mark.parametrize("stride", [(1, 1)]) @pytest.mark.parametrize("stride", [(1, 1)])
@pytest.mark.parametrize("padding", [(1, 1)]) @pytest.mark.parametrize("padding", [(1, 1)])
@pytest.mark.parametrize("dilation", [(1, 1)]) @pytest.mark.parametrize("dilation", [(1, 1)])
......
...@@ -278,10 +278,10 @@ ValueRefList setsubtensor_rule( ...@@ -278,10 +278,10 @@ ValueRefList setsubtensor_rule(
inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) { inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) {
FT format(FT::DEFAULT); FT format(FT::DEFAULT);
for (auto& inp : inputs) { for (auto& inp : inputs) {
auto& inp_format = inp.cast(t.value_type()).format(); auto&& inp_ref = inp.as_ref(t.value_type());
if (inp_format != FT::DEFAULT) { if (inp_ref && inp_ref->format() != FT::DEFAULT) {
mgb_assert(format == FT::DEFAULT || inp_format == format); mgb_assert(format == FT::DEFAULT || inp_ref->format() == format);
format = inp_format.type(); format = inp_ref->format().type();
} }
} }
return format; return format;
...@@ -323,30 +323,82 @@ ValueRefList identity_rule_helper( ...@@ -323,30 +323,82 @@ ValueRefList identity_rule_helper(
imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type()); imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type());
} }
ValueRefList batchnorm_rule(
const BatchNorm& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
auto&& inp_format = inputs[0].cast(t.value_type()).format();
if (inp_format == FT::NHWC) {
auto&& new_param = op.param();
new_param.param_dim = BatchNorm::ParamDim::DIM_111C;
auto new_op = BatchNorm::make(new_param);
return identity_rule_helper(*new_op, inputs, t);
}
return identity_rule_helper(op, inputs, t);
}
// clang-format off // clang-format off
#define FOREACH_IDENTITY_OP(cb) \ #define FOREACH_IDENTITY_OP(cb) \
cb(Copy) \ cb(Copy) \
cb(FastpathCopy) \ cb(FastpathCopy) \
cb(TypeCvt) \ cb(TypeCvt) \
cb(Pooling) \
cb(AdaptivePooling) \
cb(Dropout) \ cb(Dropout) \
cb(Convolution) \
cb(BatchNorm) \
cb(Resize) \
cb(Identity) cb(Identity)
#define FOREACH_FORMAT_OP(cb) \
cb(AdaptivePooling) \
cb(WarpAffine) \
cb(Resize)
#define FOREACH_FORMAT_POLICY_OP(cb)\
cb(Pooling) \
cb(Convolution)
// clang-format on // clang-format on
#define CREATE_IDENTITY_OP_RULE(op) \ // identity op
ValueRefList op##_rule( \ #define CREATE_IDENTITY_OP_RULE(Op) \
const op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \ ValueRefList Op##_rule( \
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \ const FormatTransformation& t) { \
return identity_rule_helper(_op, inputs, t); \ return identity_rule_helper(_op, inputs, t); \
} }
FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE) FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
#undef CREATE_IDENTITY_OP_RULE #undef CREATE_IDENTITY_OP_RULE
#define REGISTER_IDENTITY_OP_RULE(op) register_format_rule(op##_rule); // identity op with Format param
#define CREATE_FORMAT_OP_RULE(Op) \
ValueRefList Op##_rule( \
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \
auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
if (inp_format == FT::NHWC) { \
auto&& new_param = _op.param(); \
new_param.format = Op::Format::NHWC; \
auto new_op = Op::make(new_param); \
return identity_rule_helper(*new_op, inputs, t); \
} \
return identity_rule_helper(_op, inputs, t); \
}
FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
#undef CREATE_FORMAT_OP_RULE
// identity op with Format and policy param
#define CREATE_FORMAT_POLICY_OP_RULE(Op) \
ValueRefList Op##_rule( \
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \
auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
if (inp_format == FT::NHWC) { \
auto&& new_param = _op.param(); \
new_param.format = Op::Format::NHWC; \
auto new_op = Op::make(new_param, _op.policy()); \
return identity_rule_helper(*new_op, inputs, t); \
} \
return identity_rule_helper(_op, inputs, t); \
}
FOREACH_FORMAT_POLICY_OP(CREATE_FORMAT_POLICY_OP_RULE)
#undef CREATE_FORMAT_OP_RULE
#define REGISTER_OP_RULE(op) register_format_rule(op##_rule);
struct FormatRuleRegistry { struct FormatRuleRegistry {
FormatRuleRegistry() { FormatRuleRegistry() {
register_format_rule(dimshuffle_rule); register_format_rule(dimshuffle_rule);
...@@ -358,10 +410,13 @@ struct FormatRuleRegistry { ...@@ -358,10 +410,13 @@ struct FormatRuleRegistry {
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>); register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(concat_rule); register_format_rule(concat_rule);
register_format_rule(elemwise_rule); register_format_rule(elemwise_rule);
FOREACH_IDENTITY_OP(REGISTER_IDENTITY_OP_RULE) register_format_rule(batchnorm_rule);
FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
FOREACH_FORMAT_OP(REGISTER_OP_RULE)
FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE)
} }
} _; } _;
#undef REGISTER_IDENTITY_OP_RULE #undef REGISTER_OP_RULE
} // namespace } // namespace
ValueRefList FormatTransformation::apply_transformation( ValueRefList FormatTransformation::apply_transformation(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册