diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 73842709177ef488c9cbd2bc2b89ac85e8ab1777..9c734a99473fdfbda0660a4df55c639b84f485ce 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -222,45 +222,40 @@ def _normalize_axis(
         raise


+_opr_map = {
+    ("-", 1): builtin.Elemwise(mode="negate"),
+    ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
+    ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+}
+
+for name, mode in [
+    ("+", "add"),
+    ("-", "sub"),
+    ("*", "mul"),
+    ("/", "true_div"),
+    ("//", "floor_div"),
+    ("**", "pow"),
+    ("max", "max"),
+    ("additive", "add"),
+]:
+    _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
+
+
 def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile

-    binary_ops = {
-        "+": lambda: builtin.Elemwise(mode="add"),
-        "-": lambda: builtin.Elemwise(mode="sub"),
-        "*": lambda: builtin.Elemwise(mode="mul"),
-        "/": lambda: builtin.Elemwise(mode="true_div"),
-        "//": lambda: builtin.Elemwise(mode="floor_div"),
-        "**": lambda: builtin.Elemwise(mode="pow"),
-        "√": lambda: builtin.Elemwise(mode="expm1"),
-        "max": lambda: builtin.Elemwise(mode="max"),
-        "additive": lambda: builtin.Elemwise(mode="add"),
-    }
-
-    unary_ops = {
-        "-": lambda: builtin.Elemwise(mode="negate"),
-    }
-
-    ternary_ops = {
-        "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"),
-    }
-
-    quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")}
+    def as_op(op, nargs):
+        if isinstance(op, str):
+            assert (op, nargs) in _opr_map, "unknown operator"
+            op = _opr_map[(op, nargs)]
+        return op

     def decorator(func):
         builder = _SubgraphBuilder(name)

         def apply_expr(op, *args, nr_out=None):
-            if isinstance(op, str):
-                if len(args) == 2:
-                    op = binary_ops[op]()
-                elif len(args) == 1:
-                    op = unary_ops[op]()
-                elif len(args) == 3:
-                    op = ternary_ops[op]()
-                elif len(args) == 4:
-                    op = quaternary_ops[op]()
+            op = as_op(op, len(args))
             results = builder.apply(op, args, 1 if nr_out is None else nr_out)
             if nr_out is None:
                 assert len(results) == 1
@@ -282,3 +277,40 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
             return lambda: builder.compile(gopt_level)

     return decorator
+
+
+def interpret_subgraph(func, dtype, device):
+    def as_op(op, nargs):
+        if isinstance(op, str) and (op, nargs) in _opr_map:
+            op = _opr_map[(op, nargs)]
+        return op
+
+    def decorated_func(*args):
+        def apply_expr(op, *args, nr_out=None):
+            op = as_op(op, len(args))
+            results = apply(op, *args)
+            if nr_out is None:
+                assert len(results) == 1
+                return results[0]
+            else:
+                assert len(results) == nr_out
+                return results
+
+        def apply_const(value, dtype=dtype, device=device):
+            return Const(value, dtype=dtype, device=device)()[0]
+
+        outputs, outputs_has_grad = func(args, apply_expr, apply_const)
+        return outputs
+
+    return decorated_func
+
+
+def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+    def decorator(func):
+        if not interpret:
+            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            return lambda *args: apply(op(), *args)
+        else:
+            return interpret_subgraph(func, dtype, device)
+
+    return decorator
diff --git a/imperative/python/test/unit/core/test_subgraph.py b/imperative/python/test/unit/core/test_subgraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9405e53d6d0023af805b911c0b1384d716ab17
--- /dev/null
+++ b/imperative/python/test/unit/core/test_subgraph.py
@@ -0,0 +1,108 @@
+import functools
+
+import numpy as np
+import pytest
+
+import megengine
+from megengine.autodiff.grad_manager import GradManager
+from megengine.core.ops.builtin import GetVarShape, Reduce, TypeCvt
+from megengine.core.tensor.utils import subgraph_fn
+from megengine.device import CompNode, get_default_device
+from megengine.jit import trace
+
+_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
+
+
+@functools.lru_cache(maxsize=None)
+def _get_batch_norm_fn(dtype, device, channels, ndim, interpret, gopt_level):
+    @subgraph_fn(
+        "BatchNormNd",
+        dtype=dtype,
+        device=device,
+        nr_inputs=4,
+        interpret=interpret,
+        gopt_level=gopt_level,
+    )
+    def batch_norm_nd(inputs, f, c):
+        input, eps, weight, bias = inputs[0:4]
+        reduce_shape = c(
+            (1, channels) + (1,) * (ndim - 2), dtype="int32", device=device
+        )
+        input_shape = f(GetVarShape(), input)
+        input_elems = f(Reduce(mode="product", axis=0), input_shape)
+        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
+        reduce_size = f("//", input_elems, reduce_elems)
+        reduce_size = f(TypeCvt(dtype=dtype), reduce_size)
+        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
+        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
+        channel_mean = f("/", channel_x1s, reduce_size)
+        channel_var = f(
+            "-", f("/", channel_x2s, reduce_size), f("*", channel_mean, channel_mean),
+        )
+        invsqrt_channel_var = f("**", f("+", channel_var, eps), c(-0.5))
+        inv_var_wt = f("*", invsqrt_channel_var, weight)
+        neg_channel_mean = f("-", channel_mean)
+        outvar = f(
+            "fma3", input, inv_var_wt, f("fma3", neg_channel_mean, inv_var_wt, bias),
+        )
+        return (outvar,), (True,)
+
+    return batch_norm_nd
+
+
+@pytest.mark.parametrize("device", [get_default_device(), "cpux"])
+@pytest.mark.parametrize("batch_size", [1, 8])
+@pytest.mark.parametrize("channels", [3])
+@pytest.mark.parametrize(
+    "use_trace, symbolic", [(False, None), (True, False), (True, True)]
+)
+@pytest.mark.parametrize("gopt_level", [None, 1, 2])
+@pytest.mark.parametrize("dtype", ["float32"])
+def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level, dtype):
+    device = CompNode(device)
+
+    def subgraph_batch_norm(inp, weight, bias, eps, diff):
+        inp = inp.detach()
+        with GradManager().attach(inp) as gm:
+            batch_norm_fn = _get_batch_norm_fn(
+                dtype, device, channels, ndim, interpret=False, gopt_level=gopt_level
+            )
+            out, *_ = batch_norm_fn(inp, eps, weight, bias)
+            gm.backward(out * 1e3 + 1e3, diff)
+        return out, inp.grad
+
+    def primitive_batch_norm(inp, weight, bias, eps, diff):
+        inp = inp.detach()
+        with GradManager().attach(inp) as gm:
+            batch_norm_fn = _get_batch_norm_fn(
+                dtype, device, channels, ndim, interpret=True, gopt_level=gopt_level
+            )
+            (out,) = batch_norm_fn(inp, eps, weight, bias)
+            gm.backward(out * 1e3 + 1e3, diff)
+        return out, inp.grad
+
+    if use_trace:
+        subgraph_batch_norm = trace(symbolic=symbolic)(subgraph_batch_norm)
+        primitive_batch_norm = trace(symbolic=symbolic)(primitive_batch_norm)
+
+    def rand_tensor(shape, dtype=dtype, device=device):
+        return megengine.tensor(np.random.random(shape), dtype=dtype, device=device)
+
+    # test shape change
+    for image_shape in [(223, 223), (10, 20)]:
+        ndim = len(image_shape) + 2
+        input_shape = (batch_size, channels) + image_shape
+        param_shape = (1, channels) + (1,) * len(image_shape)
+
+        inp = rand_tensor(input_shape) * 1e3 + 1e3
+        weight = rand_tensor(param_shape)
+        bias = rand_tensor(param_shape)
+        eps = megengine.tensor(1e-5, dtype=dtype, device=device)
+
+        diff = rand_tensor(input_shape)
+
+        out1, grad1 = subgraph_batch_norm(inp, weight, bias, eps, diff)
+        out2, grad2 = primitive_batch_norm(inp, weight, bias, eps, diff)
+
+        _assert_allclose(out1.numpy(), out2.numpy())
+        _assert_allclose(grad1.numpy(), grad2.numpy())
diff --git a/imperative/python/test/unit/module/test_batchnorm.py b/imperative/python/test/unit/module/test_batchnorm.py
index 12d61a4f3467de460b3958f67b222c3834ffb719..de659cc02626f2e565ceeb5cdc4008312d74ebae 100644
--- a/imperative/python/test/unit/module/test_batchnorm.py
+++ b/imperative/python/test/unit/module/test_batchnorm.py
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine import Tensor
+from megengine.autodiff.grad_manager import GradManager
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm

@@ -337,3 +338,33 @@ def test_syncbn2d_no_stats():
         yv_expect = (xv - mean) / sd

         _assert_allclose(yv.numpy(), yv_expect)
+
+
+def test_syncbn2d_grad():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    syncbn = SyncBatchNorm(8, track_running_stats=False)
+    bn = BatchNorm2d(8, track_running_stats=False)
+    for i in range(4):
+        if i == 2:
+            syncbn.training = False
+            bn.training = False
+        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
+        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))
+
+        with GradManager().attach(inp) as gm:
+            oup = syncbn(inp)
+            gm.backward(oup, diff)
+
+        grad = inp.grad
+        inp.grad = None
+
+        with GradManager().attach(inp) as gm:
+            oup_expect = bn(inp)
+            gm.backward(oup_expect, diff)
+
+        grad_expect = inp.grad
+        inp.grad = None
+
+        _assert_allclose(oup.numpy(), oup_expect.numpy())
+        _assert_allclose(grad.numpy(), grad_expect.numpy())
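
Usage sketch (illustrative, not part of the patch): a minimal sketch of how the subgraph_fn decorator added to megengine/core/tensor/utils.py above could be used, following the calling pattern exercised in test_subgraph.py. The function name axpy_plus_one, the inputs a, x, y, and the added constant 1.0 are hypothetical.

# A minimal sketch, assuming subgraph_fn behaves as defined in the patch above;
# all names below are illustrative, not part of the patch.
import megengine
from megengine.core.tensor.utils import subgraph_fn
from megengine.device import CompNode, get_default_device

dtype = "float32"
device = CompNode(get_default_device())


@subgraph_fn("AxpyPlusOne", dtype=dtype, device=device, nr_inputs=3, interpret=False)
def axpy_plus_one(inputs, f, c):
    # f applies an operator: either a string key from _opr_map ("+", "*",
    # "fma3", ...) or a builtin op descriptor; c creates a constant on the
    # subgraph's dtype/device. The function returns outputs and has-grad flags.
    a, x, y = inputs[0:3]
    out = f("fma3", a, x, y)   # fused a * x + y (FUSE_MUL_ADD3)
    out = f("+", out, c(1.0))  # add a hypothetical constant 1.0
    return (out,), (True,)


a = megengine.tensor(2.0, dtype=dtype, device=device)
x = megengine.tensor(3.0, dtype=dtype, device=device)
y = megengine.tensor(4.0, dtype=dtype, device=device)
out, *_ = axpy_plus_one(a, x, y)  # expected value: 2.0 * 3.0 + 4.0 + 1.0 = 11.0

With interpret=True, the same decorated function would instead be evaluated eagerly through interpret_subgraph, which test_subgraph.py uses as the reference implementation against the compiled subgraph.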