From f2111248654d0fe92f91418690e556100272d2d8 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 1 Feb 2021 10:24:02 +0800
Subject: [PATCH] feat(mge/trace): support dict return value processing in
 trace

GitOrigin-RevId: 5b1c08848b41eaeac1e4066bce3119c90506be9f
---
 imperative/python/megengine/jit/tracing.py  |  22 ++-
 imperative/python/test/unit/test_tracing.py | 199 ++++++++++----------
 2 files changed, 115 insertions(+), 106 deletions(-)

diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 8d5ae719..dca50c47 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -642,22 +642,24 @@ class trace:
         if self._capture_as_const:
             self._process_inputs(*args, **kwargs)
         outputs = self.__wrapped__(*args, **kwargs)
-        transform = False
-        # outputs can be None
+        if self._capture_as_const:
+            self._process_outputs(outputs)
+
+        # outputs could be None
         if outputs is not None:
-            if not isinstance(outputs, collections.abc.Sequence):
-                transform = True
-                outputs = (outputs,)
-            for o in outputs:
+            list_outputs = outputs
+            if isinstance(outputs, collections.abc.Mapping):
+                _, list_outputs = zip(*sorted(outputs.items()))
+            elif not isinstance(outputs, collections.abc.Sequence):
+                list_outputs = (outputs,)
+
+            for o in list_outputs:
                 # if outputs are copied, then use the newest info in trace data structure
                 if o._copied:
                     self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
                     if self._untraced and self._symbolic:
                         self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
-        if self._capture_as_const:
-            self._process_outputs(outputs)
-            if transform:
-                outputs = outputs[0]
+
         return outputs

     def dump(
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
index 9e637af6..4a8b1352 100644
--- a/imperative/python/test/unit/test_tracing.py
+++ b/imperative/python/test/unit/test_tracing.py
@@ -28,18 +28,32 @@ from megengine.module import Module
 from megengine.random import normal, uniform


-def test_trace():
-    for symbolic in [False, True]:
-
-        @trace(symbolic=symbolic)
-        def f(x):
+@pytest.mark.parametrize("trace_mode", [False, True])
+@pytest.mark.parametrize("return_mode", ["Value", "Tuple", "List", "Dict"])
+def test_trace(trace_mode, return_mode):
+    @trace(symbolic=trace_mode)
+    def f(x):
+        if return_mode == "Tuple":
+            return (-x,)
+        elif return_mode == "List":
+            return [-x]
+        elif return_mode == "Dict":
+            return {"neg": -x}
+        else:
             return -x

-        x = tensor([1])
-        y = f(x).numpy()
+    def get_numpy(y):
+        if return_mode == "Tuple" or return_mode == "List":
+            return y[0].numpy()
+        elif return_mode == "Dict":
+            return y["neg"].numpy()
+        return y.numpy()

-        for i in range(3):
-            np.testing.assert_equal(f(x).numpy(), y)
+    x = tensor([1])
+    y = get_numpy(f(x))
+
+    for i in range(3):
+        np.testing.assert_equal(get_numpy(f(x)), y)


 def test_output_copy_trace():
@@ -54,51 +68,46 @@ def test_output_copy_trace():
         x = F.exp(x)
         return x

-    net = Simple()
-
-    gm = GradManager().attach(net.parameters())
-    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
-    data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
+    ys = {False: [], True: []}

-    @trace(symbolic=False)
-    def train_f1(d):
-        with gm:
-            loss = net(d)
-            gm.backward(loss)
-            opt.step().clear_grad()
-        return loss
+    for symbolic in [False, True]:
+        net = Simple()
+        gm = GradManager().attach(net.parameters())
+        opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
+        data = tensor(np.arange(4).reshape(2, 2), dtype="float32")

-    @trace(symbolic=True)
-    def train_f2(d):
-        with gm:
-            loss = net(d)
-            gm.backward(loss)
-            opt.step().clear_grad()
-        return loss
+        @trace(symbolic=symbolic)
+        def train_func(d):
+            with gm:
+                loss = net(d)
+                gm.backward(loss)
+                opt.step().clear_grad()
+            return loss

-    for i in range(2):
-        y1 = train_f1(data).numpy()
-        y2 = train_f2(data).numpy()
-        np.testing.assert_equal(y1, y2)
+        for i in range(3):
+            y = train_func(data).numpy()
+            ys[symbolic].append(y)
+
+    for i in range(3):
+        np.testing.assert_equal(ys[False][i], ys[True][i])


-def test_exclude_from_trace():
-    for symbolic in [False, True]:
-
-        @trace(symbolic=symbolic)
-        def f(x):
-            x = -x
-            with exclude_from_trace():
-                if i % 2:
-                    x = -x
-            x = -x
-            return x
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_exclude_from_trace(trace_mode):
+    @trace(symbolic=trace_mode)
+    def f(x):
+        x = -x
+        with exclude_from_trace():
+            if i % 2:
+                x = -x
+        x = -x
+        return x

-        x = tensor([1])
+    x = tensor([1])

-        for i in range(3):
-            y = f(x).numpy()
-            np.testing.assert_equal(f(x).numpy(), y)
+    for i in range(3):
+        y = f(x).numpy()
+        np.testing.assert_equal(f(x).numpy(), y)


 def test_print_in_trace():
@@ -191,21 +200,20 @@ def test_dump_volatile():
     )


-def test_trace_profiler():
-    for symbolic in [False, True]:
-
-        @trace(symbolic=symbolic, profiling=True)
-        def f(x):
-            return -x
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_trace_profiler(trace_mode):
+    @trace(symbolic=trace_mode, profiling=True)
+    def f(x):
+        return -x

-        x = tensor([1])
-        y = f(x).numpy()
+    x = tensor([1])
+    y = f(x).numpy()

-        f(x)
-        f(x)  # XXX: has to run twice
+    f(x)
+    f(x)  # XXX: has to run twice

-        out = f.get_profile()
-        assert out.get("profiler")
+    out = f.get_profile()
+    assert out.get("profiler")


 @pytest.mark.skip(reason="force opt_level=0 when building graph")
@@ -306,20 +314,20 @@ def test_trace_cvt_bool():
         np.testing.assert_equal(f(x).numpy(), False)


-def test_trace_reshape():
-    for symbolic in [False, True]:
-        x1 = tensor(np.random.randn(2, 10, 10))
-        x2 = tensor(np.random.randn(4, 10, 10))
-        x3 = tensor(np.random.randn(8, 10, 10))
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_trace_reshape(trace_mode):
+    x1 = tensor(np.random.randn(2, 10, 10))
+    x2 = tensor(np.random.randn(4, 10, 10))
+    x3 = tensor(np.random.randn(8, 10, 10))

-        @trace(symbolic=symbolic, capture_as_const=True)
-        def f(x):
-            y = x.reshape(x.shape[0], 100)
-            return y
+    @trace(symbolic=trace_mode, capture_as_const=True)
+    def f(x):
+        y = x.reshape(x.shape[0], 100)
+        return y

-        f(x1)
-        f(x2)
-        f(x3)
+    f(x1)
+    f(x2)
+    f(x3)


 def test_trace_topk():
@@ -387,20 +395,20 @@ def test_raise_on_trace():
     assert catch_count == 1


-def test_trace_broadcast():
-    for symbolic in [False, True]:
-        x1 = tensor(np.random.randn(3, 1, 1))
-        x2 = tensor(np.random.randn(1, 4, 1))
-        x3 = tensor(np.random.randn(1, 1, 5))
+@pytest.mark.parametrize("trace_mode", [False, True])
+def test_trace_broadcast(trace_mode):
+    x1 = tensor(np.random.randn(3, 1, 1))
+    x2 = tensor(np.random.randn(1, 4, 1))
+    x3 = tensor(np.random.randn(1, 1, 5))

-        @trace(symbolic=symbolic, capture_as_const=True)
-        def f(x):
-            y = F.broadcast_to(x, (3, 4, 5))
-            return y
+    @trace(symbolic=trace_mode, capture_as_const=True)
+    def f(x):
+        y = F.broadcast_to(x, (3, 4, 5))
+        return y

-        f(x1)
-        f(x2)
-        f(x3)
+    f(x1)
+    f(x2)
+    f(x3)


 def test_trace_nms():
@@ -466,21 +474,20 @@ def test_slice():
    y + y


-def test_random():
+@pytest.mark.parametrize("shape_mode", [False, True])
+def test_random(shape_mode):
     def run_test(op):
-        for symbolic_shape in [True, False]:
-
-            @trace(symbolic=True, symbolic_shape=symbolic_shape)
-            def f():
-                out = op(size=[10, 10])
-                out_shape = out.shape
-                assert out_shape is not None
-                if not isinstance(out_shape, tuple):
-                    assert out.shape.numpy() is not None
-                return out
-
-            for _ in range(3):
-                f()
+        @trace(symbolic=True, symbolic_shape=shape_mode)
+        def f():
+            out = op(size=[10, 10])
+            out_shape = out.shape
+            assert out_shape is not None
+            if not isinstance(out_shape, tuple):
+                assert out.shape.numpy() is not None
+            return out
+
+        for _ in range(3):
+            f()

     run_test(uniform)
     run_test(normal)
--
GitLab
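
Note: a quick way to exercise the behavior this patch adds. After the change to
trace.__call__, a traced function may return a dict: each value is registered
as a trace output (values are collected sorted by key, via
zip(*sorted(outputs.items()))), and the caller receives the original dict back
unchanged. Previously only bare tensors and sequences round-tripped. A minimal
sketch, assuming a MegEngine build that includes this commit; the function name
f and the key "neg" are illustrative only:

    import numpy as np
    from megengine import tensor
    from megengine.jit import trace

    @trace(symbolic=True)
    def f(x):
        # dict return values are now supported by trace
        return {"neg": -x}

    x = tensor([1])
    y = f(x)["neg"].numpy()  # first call records the graph
    for _ in range(3):
        # subsequent calls replay the trace; results must match
        np.testing.assert_equal(f(x)["neg"].numpy(), y)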