From 442b4f6c26b1995bd387352058873d1709e42d5f Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 23 Jun 2021 10:56:33 +0800
Subject: [PATCH] test(traced_module): add some testcases for traced module

GitOrigin-RevId: 0d6bb20b2b5110b5ecd280ec055bb14aed74ebfc
---
 .../experimental/traced_module/expr.py         |  3 +-
 .../experimental/traced_module/pytree.py       | 31 ++++++++--
 .../python/test/integration/test_converge.py  | 13 ++--
 .../test_converge_with_gradient_clip.py        | 12 ++--
 .../test/integration/test_trace_dump.py        |  1 +
 .../python/test/unit/module/test_module.py     | 48 ++++++++++++---
 .../test/unit/traced_module/test_jit_trace.py  | 59 +++++++++++++++++++
 7 files changed, 147 insertions(+), 20 deletions(-)
 create mode 100644 imperative/python/test/unit/traced_module/test_jit_trace.py

diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py
index 63318e2d..c4dce926 100644
--- a/imperative/python/megengine/experimental/traced_module/expr.py
+++ b/imperative/python/megengine/experimental/traced_module/expr.py
@@ -201,7 +201,8 @@ class Apply(Expr):
                 NodeMixin.wrap_safe(i, Constant.make(i))
         apply_node = cls.make(opdef)
         for i in inputs:
-            apply_node.add_input(NodeMixin.get(i))
+            assert isinstance(i, RawTensor)
+            apply_node.inputs.append(NodeMixin.get(i))
 
         unset_module_tracing()
         outputs = apply(opdef, *inputs)
diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py
index 73b2c05d..d3cb9fed 100644
--- a/imperative/python/megengine/experimental/traced_module/pytree.py
+++ b/imperative/python/megengine/experimental/traced_module/pytree.py
@@ -1,3 +1,13 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import collections
 from typing import Callable, NamedTuple
 
 SUPPORTED_TYPE = {}
@@ -9,11 +19,22 @@ def register_supported_type(type, flatten, unflatten):
     SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
 
 
+def _dict_flatten(inp):
+    aux_data = []
+    results = []
+    for key, value in sorted(inp.items()):
+        results.append(value)
+        aux_data.append(key)
+    return results, aux_data
+
+
+def _dict_unflatten(inps, aux_data):
+    return dict(zip(aux_data, inps))
+
+
 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
 register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(
-    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
-)
+register_supported_type(dict, _dict_flatten, _dict_unflatten)
 register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
@@ -68,6 +89,8 @@ class TreeDef:
 
 class LeafDef(TreeDef):
     def __init__(self, type):
+        if not isinstance(type, collections.abc.Sequence):
+            type = (type,)
         super().__init__(type, None, [])
         self.num_leaves = 1
 
@@ -77,4 +100,4 @@ class LeafDef(TreeDef):
         return leaves[0]
 
     def __repr__(self):
-        return "Leaf({})".format(self.type.__name__)
+        return "Leaf({})".format(", ".join(t.__name__ for t in self.type))
diff --git a/imperative/python/test/integration/test_converge.py b/imperative/python/test/integration/test_converge.py
index 08c0cb6d..ab32aaa9 100644
--- a/imperative/python/test/integration/test_converge.py
+++ b/imperative/python/test/integration/test_converge.py
@@ -14,6 +14,7 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
 
@@ -71,8 +72,13 @@ class XORNet(Module):
         return x
 
 
-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
+
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
 
@@ -105,9 +111,8 @@
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
     )
diff --git a/imperative/python/test/integration/test_converge_with_gradient_clip.py b/imperative/python/test/integration/test_converge_with_gradient_clip.py
index fd6c642b..6ec9fafe 100644
--- a/imperative/python/test/integration/test_converge_with_gradient_clip.py
+++ b/imperative/python/test/integration/test_converge_with_gradient_clip.py
@@ -15,6 +15,7 @@ import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.optimizer as optim
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -73,8 +74,12 @@ class XORNet(Module):
         return x
 
 
-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
 
@@ -110,9 +115,8 @@
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     print("precision=", precision)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py
index c719ee94..e256354a 100644
--- a/imperative/python/test/integration/test_trace_dump.py
+++ b/imperative/python/test/integration/test_trace_dump.py
@@ -19,6 +19,7 @@ import megengine.module as M
 import megengine.optimizer as optim
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
 
 
diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py
index 235a54d1..083b2a52 100644
--- a/imperative/python/test/unit/module/test_module.py
+++ b/imperative/python/test/unit/module/test_module.py
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.functional as F
 from megengine import Parameter, Tensor, tensor
+from megengine.experimental.traced_module import TracedModule, trace_module
 from megengine.module import (
     BatchNorm1d,
     BatchNorm2d,
@@ -67,8 +68,18 @@ class MyModule(Module):
         return x
 
 
-def test_module_api():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        buff = m.buff
+        param = m.param
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
+        assert "buff" not in m.__dict__
+        assert "param" not in m.__dict__
+        m.buff = buff
+        m.param = param
+
     assert list(m.children()) == [m.bn, m.i]
     assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
@@ -141,8 +152,11 @@
     assert m.bn.training == False and m.i.bn.training == False
 
 
-def test_module_api_reuse_submodule():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_reuse_submodule(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     m.h = m.i  # pylint: disable=attribute-defined-outside-init
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
     assert list(m.named_modules()) == [
@@ -153,15 +167,21 @@
     ]
 
 
-def test_module_api_iterable_stability():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_iterable_stability(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     l = list(m.modules())
     for _ in range(100):
         assert list(m.modules()) == l
 
 
-def test_module_api_hooks():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_hooks(test_traced_module):
     net = MyModule()
+    if test_traced_module:
+        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
     pre_hook_num = 0
     post_hook_num = 0
     hooks = []
@@ -383,11 +403,16 @@ class Simple(Module):
         self.conv1.weight = self.conv0.weight
 
     def forward(self, inputs):
-        pass
+        x = self.conv0(inputs)
+        y = self.conv1(inputs)
+        return x + y
 
 
-def test_shared_param():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_shared_param(test_traced_module):
     net = Simple()
+    if test_traced_module:
+        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
     assert net.conv0.weight is net.conv1.weight
     data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
     np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
@@ -449,15 +474,21 @@ def test_shared_param_1d():
     np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
 
 
-def test_pickle_module():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_pickle_module(test_traced_module):
     data_shape = (2, 28)
     data = tensor(np.random.random(data_shape))
     mlp = MLP()
+    pred_gt = mlp(data)
+    if test_traced_module:
+        mlp = trace_module(mlp, data)
     # pickle before forward
     with BytesIO() as fout:
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred0 = mlp1(data)
 
     pred1 = mlp(data)
@@ -467,8 +498,11 @@
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred2 = mlp1(data)
 
+    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
 
diff --git a/imperative/python/test/unit/traced_module/test_jit_trace.py b/imperative/python/test/unit/traced_module/test_jit_trace.py
new file mode 100644
index 00000000..0cd3bcae
--- /dev/null
+++ b/imperative/python/test/unit/traced_module/test_jit_trace.py
@@ -0,0 +1,59 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import io
+
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+import megengine.utils.comp_graph_tools as cgtools
+from megengine.experimental.traced_module import trace_module
+from megengine.jit import trace
+from megengine.module import Module
+
+
+class MyBlock(Module):
+    def __init__(self, in_channels, channels):
+        super(MyBlock, self).__init__()
+        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
+        self.bn1 = M.BatchNorm2d(channels)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x) + 1
+        return x
+
+
+class MyModule(Module):
+    def __init__(self):
+        super(MyModule, self).__init__()
+        self.block0 = MyBlock(8, 4)
+        self.block1 = MyBlock(4, 2)
+
+    def forward(self, x):
+        x = self.block0(x)
+        x = self.block1(x)
+        return x
+
+
+def test_jit_trace():
+    module = MyModule()
+    module.eval()
+    x = F.ones((1, 8, 14, 14))
+    expect = module(x)
+    traced_module = trace_module(module, x)
+    func = trace(traced_module, capture_as_const=True)
+    np.testing.assert_array_equal(func(x), expect)
+    model = io.BytesIO()
+    func.dump(model)
+    model.seek(0)
+    infer_cg = cgtools.GraphInference(model)
+    np.testing.assert_allclose(
+        list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
+    )
-- 
GitLab
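
A note on the pytree change: the dict flatten/unflatten lambdas are replaced by key-sorted helpers, so two dicts holding the same key/value pairs now flatten to the same leaves and auxiliary data regardless of insertion order. The snippet below is a standalone sketch, not part of the patch; it copies the two helpers verbatim from the diff and checks the round trip using only plain Python.

# Standalone sketch: mirrors _dict_flatten/_dict_unflatten from the diff above.
# The rest of the pytree machinery (TreeDef, register_supported_type) is omitted.


def _dict_flatten(inp):
    aux_data = []
    results = []
    # Sorting the items makes the leaf order a function of the keys alone,
    # not of the dict's insertion order.
    for key, value in sorted(inp.items()):
        results.append(value)
        aux_data.append(key)
    return results, aux_data


def _dict_unflatten(inps, aux_data):
    return dict(zip(aux_data, inps))


if __name__ == "__main__":
    a = {"x": 1, "y": 2}
    b = {"y": 2, "x": 1}  # same mapping, different insertion order

    # The old lambdas used list(x.values()) / list(x.keys()), which preserve
    # insertion order, so a and b would have flattened differently.
    assert _dict_flatten(a) == _dict_flatten(b) == ([1, 2], ["x", "y"])

    # Unflattening the leaves with the recorded keys restores the mapping.
    assert _dict_unflatten(*_dict_flatten(a)) == a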
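
The LeafDef change follows the same pattern: __init__ now normalizes a bare type into a 1-tuple so that __repr__ can always treat self.type as a sequence of candidate leaf types. A minimal sketch with the class stripped down to the two patched methods (the TreeDef base class and num_leaves bookkeeping from the real module are omitted):

import collections.abc


class LeafDef:
    """Reduced sketch of the patched LeafDef: only the two changed methods."""

    def __init__(self, type):
        # Wrap a single type into a 1-tuple so self.type is always a sequence.
        if not isinstance(type, collections.abc.Sequence):
            type = (type,)
        self.type = type

    def __repr__(self):
        # Joining the names handles both the single- and multi-type cases.
        return "Leaf({})".format(", ".join(t.__name__ for t in self.type))


print(LeafDef(int))           # Leaf(int)
print(LeafDef((int, float)))  # Leaf(int, float)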