Commit a7a3bf2d authored by Megvii Engine Team

test(subgraph): simple test for subgraph

GitOrigin-RevId: 3d6ecd5db73a662e0b25b254657401d54dd322e0
Parent d063d577
@@ -222,45 +222,40 @@ def _normalize_axis(
     raise


+_opr_map = {
+    ("-", 1): builtin.Elemwise(mode="negate"),
+    ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
+    ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+}
+
+for name, mode in [
+    ("+", "add"),
+    ("-", "sub"),
+    ("*", "mul"),
+    ("/", "true_div"),
+    ("//", "floor_div"),
+    ("**", "pow"),
+    ("max", "max"),
+    ("additive", "add"),
+]:
+    _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
+
+
 def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile

-    binary_ops = {
-        "+": lambda: builtin.Elemwise(mode="add"),
-        "-": lambda: builtin.Elemwise(mode="sub"),
-        "*": lambda: builtin.Elemwise(mode="mul"),
-        "/": lambda: builtin.Elemwise(mode="true_div"),
-        "//": lambda: builtin.Elemwise(mode="floor_div"),
-        "**": lambda: builtin.Elemwise(mode="pow"),
-        "√": lambda: builtin.Elemwise(mode="expm1"),
-        "max": lambda: builtin.Elemwise(mode="max"),
-        "additive": lambda: builtin.Elemwise(mode="add"),
-    }
-
-    unary_ops = {
-        "-": lambda: builtin.Elemwise(mode="negate"),
-    }
-
-    ternary_ops = {
-        "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"),
-    }
-
-    quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")}
+    def as_op(op, nargs):
+        if isinstance(op, str):
+            assert (op, nargs) in _opr_map, "unknown operator"
+            op = _opr_map[(op, nargs)]
+        return op

     def decorator(func):
         builder = _SubgraphBuilder(name)

         def apply_expr(op, *args, nr_out=None):
-            if isinstance(op, str):
-                if len(args) == 2:
-                    op = binary_ops[op]()
-                elif len(args) == 1:
-                    op = unary_ops[op]()
-                elif len(args) == 3:
-                    op = ternary_ops[op]()
-                elif len(args) == 4:
-                    op = quaternary_ops[op]()
+            op = as_op(op, len(args))
             results = builder.apply(op, args, 1 if nr_out is None else nr_out)
             if nr_out is None:
                 assert len(results) == 1
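The change above collapses the former per-arity lookup tables (binary_ops, unary_ops, ternary_ops, quaternary_ops) into a single module-level _opr_map keyed by operator symbol and argument count, so apply_expr resolves any string operator with one dictionary lookup. A minimal illustration of that keying (demo_map and as_op_demo are illustrative names, not part of the commit):

from megengine.core.ops import builtin

# mirrors the (symbol, arity) keying that _opr_map introduces
demo_map = {
    ("-", 1): builtin.Elemwise(mode="negate"),
    ("+", 2): builtin.Elemwise(mode="add"),
    ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
}

def as_op_demo(op, nargs):
    # strings resolve by (name, arity); ready-made builtin ops pass through unchanged
    return demo_map[(op, nargs)] if isinstance(op, str) else op

print(as_op_demo("fma3", 3))  # the fused multiply-add Elemwise op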
@@ -282,3 +277,40 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         return lambda: builder.compile(gopt_level)

     return decorator
+
+
+def interpret_subgraph(func, dtype, device):
+    def as_op(op, nargs):
+        if isinstance(op, str) and (op, nargs) in _opr_map:
+            op = _opr_map[(op, nargs)]
+        return op
+
+    def decorated_func(*args):
+        def apply_expr(op, *args, nr_out=None):
+            op = as_op(op, len(args))
+            results = apply(op, *args)
+            if nr_out is None:
+                assert len(results) == 1
+                return results[0]
+            else:
+                assert len(results) == nr_out
+                return results
+
+        def apply_const(value, dtype=dtype, device=device):
+            return Const(value, dtype=dtype, device=device)()[0]
+
+        outputs, outputs_has_grad = func(args, apply_expr, apply_const)
+        return outputs
+
+    return decorated_func
+
+
+def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+    def decorator(func):
+        if not interpret:
+            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            return lambda *args: apply(op(), *args)
+        else:
+            return interpret_subgraph(func, dtype, device)
+
+    return decorator
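Taken together, subgraph compiles the decorated builder function into a single fused op, interpret_subgraph evaluates the same builder calls eagerly through ordinary apply, and subgraph_fn selects between the two with the interpret flag. A minimal usage sketch (the function axpb, its input shapes, and the "cpux" device string are illustrative assumptions, not part of the commit):

import numpy as np
import megengine
from megengine.core.tensor.utils import subgraph_fn
from megengine.device import CompNode

device = CompNode("cpux")  # hypothetical device choice

@subgraph_fn("AXPlusB", dtype="float32", device=device, nr_inputs=3, interpret=True)
def axpb(inputs, f, c):
    # f applies an op (a builtin op or a string resolved via _opr_map), c makes a constant
    a, x, b = inputs[0:3]
    return (f("fma3", a, x, b),), (True,)  # a * x + b via FUSE_MUL_ADD3

a, x, b = (megengine.tensor(np.random.random(4), dtype="float32") for _ in range(3))
(y,) = axpb(a, x, b)

The new test below exercises the same machinery by building an n-dimensional batch norm as a subgraph and comparing the compiled and interpreted paths.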
import functools

import numpy as np
import pytest

import megengine
from megengine.autodiff.grad_manager import GradManager
from megengine.core.ops.builtin import GetVarShape, Reduce, TypeCvt
from megengine.core.tensor.utils import subgraph_fn
from megengine.device import CompNode, get_default_device
from megengine.jit import trace

_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)


@functools.lru_cache(maxsize=None)
def _get_batch_norm_fn(dtype, device, channels, ndim, interpret, gopt_level):
    @subgraph_fn(
        "BatchNormNd",
        dtype=dtype,
        device=device,
        nr_inputs=4,
        interpret=interpret,
        gopt_level=gopt_level,
    )
    def batch_norm_nd(inputs, f, c):
        input, eps, weight, bias = inputs[0:4]
        reduce_shape = c(
            (1, channels) + (1,) * (ndim - 2), dtype="int32", device=device
        )
        input_shape = f(GetVarShape(), input)
        input_elems = f(Reduce(mode="product", axis=0), input_shape)
        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
        reduce_size = f("//", input_elems, reduce_elems)
        reduce_size = f(TypeCvt(dtype=dtype), reduce_size)
        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
        channel_mean = f("/", channel_x1s, reduce_size)
        channel_var = f(
            "-", f("/", channel_x2s, reduce_size), f("*", channel_mean, channel_mean),
        )
        invsqrt_channel_var = f("**", f("+", channel_var, eps), c(-0.5))
        inv_var_wt = f("*", invsqrt_channel_var, weight)
        neg_channel_mean = f("-", channel_mean)
        outvar = f(
            "fma3", input, inv_var_wt, f("fma3", neg_channel_mean, inv_var_wt, bias),
        )
        return (outvar,), (True,)

    return batch_norm_nd
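For reference, batch_norm_nd above computes an ordinary batch normalization: per-channel mean and variance are derived from the summed x and x² statistics, and the normalized output is assembled from two fused multiply-adds, fma3(input, inv_var_wt, fma3(-mean, inv_var_wt, bias)). A NumPy sketch of the same arithmetic (shapes and names are illustrative only):

import numpy as np

def batch_norm_reference(x, eps, weight, bias):
    # reduce over every axis except the channel axis (axis 1)
    axes = (0,) + tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = (x * x).mean(axis=axes, keepdims=True) - mean * mean  # E[x^2] - E[x]^2
    inv_var_wt = weight * (var + eps) ** -0.5
    # equivalent to (x - mean) * inv_var_wt + bias
    return x * inv_var_wt + (-mean) * inv_var_wt + bias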
@pytest.mark.parametrize("device", [get_default_device(), "cpux"])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("channels", [3])
@pytest.mark.parametrize(
"use_trace, symbolic", [(False, None), (True, False), (True, True)]
)
@pytest.mark.parametrize("gopt_level", [None, 1, 2])
@pytest.mark.parametrize("dtype", ["float32"])
def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level, dtype):
device = CompNode(device)
def subgraph_batch_norm(inp, weight, bias, eps, diff):
inp = inp.detach()
with GradManager().attach(inp) as gm:
batch_norm_fn = _get_batch_norm_fn(
dtype, device, channels, ndim, interpret=False, gopt_level=gopt_level
)
out, *_ = batch_norm_fn(inp, eps, weight, bias)
gm.backward(out * 1e3 + 1e3, diff)
return out, inp.grad
def primitive_batch_norm(inp, weight, bias, eps, diff):
inp = inp.detach()
with GradManager().attach(inp) as gm:
batch_norm_fn = _get_batch_norm_fn(
dtype, device, channels, ndim, interpret=True, gopt_level=gopt_level
)
(out,) = batch_norm_fn(inp, eps, weight, bias)
gm.backward(out * 1e3 + 1e3, diff)
return out, inp.grad
if use_trace:
subgraph_batch_norm = trace(symbolic=symbolic)(subgraph_batch_norm)
primitive_batch_norm = trace(symbolic=symbolic)(primitive_batch_norm)
def rand_tensor(shape, dtype=dtype, device=device):
return megengine.tensor(np.random.random(shape), dtype=dtype, device=device)
# test shape change
for image_shape in [(223, 223), (10, 20)]:
ndim = len(image_shape) + 2
input_shape = (batch_size, channels) + image_shape
param_shape = (1, channels) + (1,) * len(image_shape)
inp = rand_tensor(input_shape) * 1e3 + 1e3
weight = rand_tensor(param_shape)
bias = rand_tensor(param_shape)
eps = megengine.tensor(1e-5, dtype=dtype, device=device)
diff = rand_tensor(input_shape)
out1, grad1 = subgraph_batch_norm(inp, weight, bias, eps, diff)
out2, grad2 = primitive_batch_norm(inp, weight, bias, eps, diff)
_assert_allclose(out1.numpy(), out2.numpy())
_assert_allclose(grad1.numpy(), grad2.numpy())
@@ -15,6 +15,7 @@ import pytest

 import megengine as mge
 import megengine.distributed as dist
 from megengine import Tensor
+from megengine.autodiff.grad_manager import GradManager
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
@@ -337,3 +338,33 @@ def test_syncbn2d_no_stats():
     yv_expect = (xv - mean) / sd

     _assert_allclose(yv.numpy(), yv_expect)
+
+
+def test_syncbn2d_grad():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    syncbn = SyncBatchNorm(8, track_running_stats=False)
+    bn = BatchNorm2d(8, track_running_stats=False)
+
+    for i in range(4):
+        if i == 2:
+            syncbn.training = False
+            bn.training = False
+
+        inp = Tensor(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
+        diff = Tensor(np.random.normal(size=data_shape).astype(np.float32))
+
+        with GradManager().attach(inp) as gm:
+            oup = syncbn(inp)
+            gm.backward(oup, diff)
+        grad = inp.grad
+        inp.grad = None
+
+        with GradManager().attach(inp) as gm:
+            oup_expect = bn(inp)
+            gm.backward(oup_expect, diff)
+        grad_expect = inp.grad
+        inp.grad = None
+
+        _assert_allclose(oup.numpy(), oup_expect.numpy())
+        _assert_allclose(grad.numpy(), grad_expect.numpy())