Commit f2111248 authored by Megvii Engine Team

feat(mge/trace): support dict return value processing in trace

GitOrigin-RevId: 5b1c08848b41eaeac1e4066bce3119c90506be9f
Parent cbff4d7c
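The change lets `trace` handle functions whose return value is a `Mapping` (e.g. a dict of tensors), in addition to the single-tensor and sequence returns it already supported. A minimal usage sketch, assuming a MegEngine build that includes this commit; the function body and values are illustrative, not taken from the diff:

import numpy as np

from megengine import tensor
from megengine.jit import trace


@trace(symbolic=True)
def f(x):
    # dict return values now pass through trace's output bookkeeping
    return {"neg": -x, "double": x * 2}


x = tensor(np.array([1.0], dtype=np.float32))
out = f(x)  # first call records the graph; later calls replay it
print(out["neg"].numpy(), out["double"].numpy())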
@@ -642,22 +642,24 @@ class trace:
         if self._capture_as_const:
             self._process_inputs(*args, **kwargs)
         outputs = self.__wrapped__(*args, **kwargs)
-        transform = False
-        # outputs can be None
+        if self._capture_as_const:
+            self._process_outputs(outputs)
+        # outputs could be None
         if outputs is not None:
-            if not isinstance(outputs, collections.abc.Sequence):
-                transform = True
-                outputs = (outputs,)
-            for o in outputs:
+            list_outputs = outputs
+            if isinstance(outputs, collections.abc.Mapping):
+                _, list_outputs = zip(*sorted(outputs.items()))
+            elif not isinstance(outputs, collections.abc.Sequence):
+                list_outputs = (outputs,)
+            for o in list_outputs:
                 # if outputs are copied, then use the newest info in trace data structure
                 if o._copied:
                     self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
                     if self._untraced and self._symbolic:
                         self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
-        if self._capture_as_const:
-            self._process_outputs(outputs)
-        if transform:
-            outputs = outputs[0]
         return outputs

     def dump(
...
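The core of the change is the new `Mapping` branch above. `sorted(outputs.items())` pins the key order, so the value order is deterministic across the untraced first call and later traced replays; `zip(*...)` then transposes the (key, value) pairs into a keys tuple and a values tuple, the keys are discarded into `_`, and the values are walked exactly like a tuple return. A standalone illustration of the idiom:

# sorted() fixes the ordering; zip(*...) transposes pairs into parallel tuples
outputs = {"b": 2, "c": 3, "a": 1}
keys, values = zip(*sorted(outputs.items()))
assert keys == ("a", "b", "c")
assert values == (1, 2, 3)

Note the idiom assumes a non-empty mapping: `zip()` with no arguments yields nothing to unpack, so an empty dict would raise ValueError here.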
@@ -28,18 +28,32 @@ from megengine.module import Module
 from megengine.random import normal, uniform


-def test_trace():
-    for symbolic in [False, True]:
-
-        @trace(symbolic=symbolic)
-        def f(x):
-            return -x
-
-        x = tensor([1])
-        y = f(x).numpy()
-        for i in range(3):
-            np.testing.assert_equal(f(x).numpy(), y)
+@pytest.mark.parametrize("trace_mode", [False, True])
+@pytest.mark.parametrize("return_mode", ["Value", "Tuple", "List", "Dict"])
+def test_trace(trace_mode, return_mode):
+    @trace(symbolic=trace_mode)
+    def f(x):
+        if return_mode == "Tuple":
+            return (-x,)
+        elif return_mode == "List":
+            return [-x]
+        elif return_mode == "Dict":
+            return {"neg": -x}
+        else:
+            return -x
+
+    def get_numpy(y):
+        if return_mode == "Tuple" or return_mode == "List":
+            return y[0].numpy()
+        elif return_mode == "Dict":
+            return y["neg"].numpy()
+        return y.numpy()
+
+    x = tensor([1])
+    y = get_numpy(f(x))
+
+    for i in range(3):
+        np.testing.assert_equal(get_numpy(f(x)), y)


 def test_output_copy_trace():
@@ -54,51 +68,46 @@ def test_output_copy_trace():
         x = F.exp(x)
         return x

-    net = Simple()
-    gm = GradManager().attach(net.parameters())
-    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
-    data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
-
-    @trace(symbolic=False)
-    def train_f1(d):
-        with gm:
-            loss = net(d)
-            gm.backward(loss)
-            opt.step().clear_grad()
-        return loss
-
-    @trace(symbolic=True)
-    def train_f2(d):
-        with gm:
-            loss = net(d)
-            gm.backward(loss)
-            opt.step().clear_grad()
-        return loss
-
-    for i in range(2):
-        y1 = train_f1(data).numpy()
-        y2 = train_f2(data).numpy()
-        np.testing.assert_equal(y1, y2)
+    ys = {False: [], True: []}
+
+    for symbolic in [False, True]:
+        net = Simple()
+        gm = GradManager().attach(net.parameters())
+        opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
+        data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
+
+        @trace(symbolic=symbolic)
+        def train_func(d):
+            with gm:
+                loss = net(d)
+                gm.backward(loss)
+                opt.step().clear_grad()
+            return loss
+
+        for i in range(3):
+            y = train_func(data).numpy()
+            ys[symbolic].append(y)
+
+    for i in range(3):
+        np.testing.assert_equal(ys[False][i], ys[True][i])


-def test_exclude_from_trace():
-    for symbolic in [False, True]:
-
-        @trace(symbolic=symbolic)
-        def f(x):
-            x = -x
-            with exclude_from_trace():
-                if i % 2:
-                    x = -x
-            x = -x
-            return x
-
-        x = tensor([1])
-        for i in range(3):
-            y = f(x).numpy()
-            np.testing.assert_equal(f(x).numpy(), y)
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_exclude_from_trace(trace_mode):
+    @trace(symbolic=trace_mode)
+    def f(x):
+        x = -x
+        with exclude_from_trace():
+            if i % 2:
+                x = -x
+        x = -x
+        return x
+
+    x = tensor([1])
+    for i in range(3):
+        y = f(x).numpy()
+        np.testing.assert_equal(f(x).numpy(), y)


 def test_print_in_trace():
@@ -191,21 +200,20 @@ def test_dump_volatile():
     )


-def test_trace_profiler():
-    for symbolic in [False, True]:
-
-        @trace(symbolic=symbolic, profiling=True)
-        def f(x):
-            return -x
-
-        x = tensor([1])
-        y = f(x).numpy()
-
-        f(x)
-        f(x)  # XXX: has to run twice
-
-        out = f.get_profile()
-        assert out.get("profiler")
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_trace_profiler(trace_mode):
+    @trace(symbolic=trace_mode, profiling=True)
+    def f(x):
+        return -x
+
+    x = tensor([1])
+    y = f(x).numpy()
+
+    f(x)
+    f(x)  # XXX: has to run twice
+
+    out = f.get_profile()
+    assert out.get("profiler")


 @pytest.mark.skip(reason="force opt_level=0 when building graph")
@@ -306,20 +314,20 @@ def test_trace_cvt_bool():
     np.testing.assert_equal(f(x).numpy(), False)


-def test_trace_reshape():
-    for symbolic in [False, True]:
-        x1 = tensor(np.random.randn(2, 10, 10))
-        x2 = tensor(np.random.randn(4, 10, 10))
-        x3 = tensor(np.random.randn(8, 10, 10))
-
-        @trace(symbolic=symbolic, capture_as_const=True)
-        def f(x):
-            y = x.reshape(x.shape[0], 100)
-            return y
-
-        f(x1)
-        f(x2)
-        f(x3)
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_trace_reshape(trace_mode):
+    x1 = tensor(np.random.randn(2, 10, 10))
+    x2 = tensor(np.random.randn(4, 10, 10))
+    x3 = tensor(np.random.randn(8, 10, 10))
+
+    @trace(symbolic=trace_mode, capture_as_const=True)
+    def f(x):
+        y = x.reshape(x.shape[0], 100)
+        return y
+
+    f(x1)
+    f(x2)
+    f(x3)


 def test_trace_topk():
@@ -387,20 +395,20 @@ def test_raise_on_trace():
     assert catch_count == 1


-def test_trace_broadcast():
-    for symbolic in [False, True]:
-        x1 = tensor(np.random.randn(3, 1, 1))
-        x2 = tensor(np.random.randn(1, 4, 1))
-        x3 = tensor(np.random.randn(1, 1, 5))
-
-        @trace(symbolic=symbolic, capture_as_const=True)
-        def f(x):
-            y = F.broadcast_to(x, (3, 4, 5))
-            return y
-
-        f(x1)
-        f(x2)
-        f(x3)
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_trace_broadcast(trace_mode):
+    x1 = tensor(np.random.randn(3, 1, 1))
+    x2 = tensor(np.random.randn(1, 4, 1))
+    x3 = tensor(np.random.randn(1, 1, 5))
+
+    @trace(symbolic=trace_mode, capture_as_const=True)
+    def f(x):
+        y = F.broadcast_to(x, (3, 4, 5))
+        return y
+
+    f(x1)
+    f(x2)
+    f(x3)


 def test_trace_nms():
@@ -466,21 +474,20 @@ def test_slice():
     y + y


-def test_random():
-    def run_test(op):
-        for symbolic_shape in [True, False]:
-
-            @trace(symbolic=True, symbolic_shape=symbolic_shape)
-            def f():
-                out = op(size=[10, 10])
-                out_shape = out.shape
-                assert out_shape is not None
-                if not isinstance(out_shape, tuple):
-                    assert out.shape.numpy() is not None
-                return out
-
-            for _ in range(3):
-                f()
+@pytest.mark.parametrize("shape_mode", [False, True])
+def test_random(shape_mode):
+    def run_test(op):
+        @trace(symbolic=True, symbolic_shape=shape_mode)
+        def f():
+            out = op(size=[10, 10])
+            out_shape = out.shape
+            assert out_shape is not None
+            if not isinstance(out_shape, tuple):
+                assert out.shape.numpy() is not None
+            return out
+
+        for _ in range(3):
+            f()

     run_test(uniform)
     run_test(normal)
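The test-side edits follow one mechanical pattern: hand-rolled `for symbolic in [False, True]:` loops become `@pytest.mark.parametrize` parameters, so each mode runs as a separate test case that is collected, reported, and failed independently. Stacking two decorators parametrizes over the cross-product, which is how the new `test_trace` covers 2 trace modes x 4 return modes = 8 cases. A minimal sketch of the pattern (the test name here is illustrative):

import pytest


@pytest.mark.parametrize("trace_mode", [False, True])
@pytest.mark.parametrize("return_mode", ["Value", "Tuple", "List", "Dict"])
def test_cross_product(trace_mode, return_mode):
    # pytest generates one test item per (return_mode, trace_mode) combination
    assert isinstance(trace_mode, bool)
    assert return_mode in ("Value", "Tuple", "List", "Dict")

`test_output_copy_trace` is the one exception: it compares results across the two trace modes inside a single test, which parametrization cannot express directly, so it keeps an explicit loop and collects per-mode results in the new `ys` dict for the final cross-mode assertions.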