提交 f642b05e 编写于 作者: M Megvii Engine Team

test(mge): update traced_module unit test

GitOrigin-RevId: 3948d50d7901a85737a19795bc1866ceb08bd29d
上级 fb20cb36
...@@ -229,6 +229,7 @@ class GetAttr(Expr): ...@@ -229,6 +229,7 @@ class GetAttr(Expr):
name = None name = None
r"""name: the qualified name of the attribute to be retrieved.""" r"""name: the qualified name of the attribute to be retrieved."""
def __init__(self, module, name, type=None, orig_name=None): def __init__(self, module, name, type=None, orig_name=None):
super().__init__() super().__init__()
assert isinstance(module, ModuleNode) assert isinstance(module, ModuleNode)
...@@ -276,6 +277,7 @@ class CallMethod(Expr): ...@@ -276,6 +277,7 @@ class CallMethod(Expr):
method: the method name. method: the method name.
Default: "__call__" Default: "__call__"
""" """
def __init__(self, node, method="__call__"): def __init__(self, node, method="__call__"):
super().__init__() super().__init__()
if isinstance(node, type): if isinstance(node, type):
...@@ -351,6 +353,7 @@ class Apply(Expr): ...@@ -351,6 +353,7 @@ class Apply(Expr):
opdef: the applied :class:`OpDef`. opdef: the applied :class:`OpDef`.
""" """
opdef = None opdef = None
def __init__(self, opdef): def __init__(self, opdef):
super().__init__() super().__init__()
assert isinstance(opdef, OpDef) assert isinstance(opdef, OpDef)
...@@ -422,6 +425,7 @@ class CallFunction(Expr): ...@@ -422,6 +425,7 @@ class CallFunction(Expr):
Args: Args:
func: a built-in function. func: a built-in function.
""" """
def __init__(self, func): def __init__(self, func):
super().__init__() super().__init__()
assert isinstance(func, Callable) assert isinstance(func, Callable)
......
...@@ -5,12 +5,21 @@ ...@@ -5,12 +5,21 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
import numpy as np import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine.module.identity import Identity
from megengine.traced_module import trace_module from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallFunction, GetAttr from megengine.traced_module.expr import CallFunction, Expr, GetAttr
from megengine.traced_module.node import Node
class IdentityMod(M.Module):
def forward(self, x):
return x
class MyBlock(M.Module): class MyBlock(M.Module):
...@@ -18,11 +27,13 @@ class MyBlock(M.Module): ...@@ -18,11 +27,13 @@ class MyBlock(M.Module):
super(MyBlock, self).__init__() super(MyBlock, self).__init__()
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False) self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
self.bn1 = M.BatchNorm2d(channels) self.bn1 = M.BatchNorm2d(channels)
self.nothing = IdentityMod()
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
x = F.relu(x) + 1 x = F.relu(x) + 1
x = self.nothing(x)
return x return x
...@@ -31,10 +42,24 @@ class MyModule(M.Module): ...@@ -31,10 +42,24 @@ class MyModule(M.Module):
super(MyModule, self).__init__() super(MyModule, self).__init__()
self.block0 = MyBlock() self.block0 = MyBlock()
self.block1 = MyBlock() self.block1 = MyBlock()
self.nothing = IdentityMod()
def forward(self, x): def forward(self, x):
x = self.block0(x) x = self.block0(x)
x = self.block1(x) x = self.block1(x)
x = self.nothing(x)
return x
class NewModule(M.Module):
def __init__(self, traced_module):
super(NewModule, self).__init__()
self.module = traced_module
def forward(self, x):
x = x - 1
x = self.module(x)
x = x + 1
return x return x
...@@ -82,6 +107,12 @@ def test_delete(): ...@@ -82,6 +107,12 @@ def test_delete():
graph.compile() graph.compile()
np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6) np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6)
# clear graph
graph.replace_node({graph.outputs[0]: graph.inputs[1]})
graph.compile()
np.testing.assert_equal(len(list(graph._exprs)), 0)
np.testing.assert_equal(traced_module(x).numpy(), x.numpy())
def test_flatten(): def test_flatten():
traced_module, x, expect = _init_module() traced_module, x, expect = _init_module()
...@@ -89,6 +120,74 @@ def test_flatten(): ...@@ -89,6 +120,74 @@ def test_flatten():
traced_module.graph.compile() traced_module.graph.compile()
assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs) assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
assert len(traced_module.graph._exprs) == 12 assert len(traced_module.graph._exprs) == 12
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())
def test_id_and_name():
def _check_id(traced_module):
_total_ids = traced_module.graph._total_ids
node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
assert len(set(node_ids)) == len(node_ids)
assert max(node_ids) + 1 == len(node_ids)
expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
assert len(set(expr_ids)) == len(expr_ids)
assert max(expr_ids) + 1 == _total_ids[1]
def _check_name(flatened_module):
node_names = [n._name for n in flatened_module.graph.nodes().as_list()]
assert len(set(node_names)) == len(node_names)
traced_module, x, expect = _init_module()
_check_id(traced_module)
flattened_module = traced_module.flatten()
_check_id(flattened_module)
_check_name(flattened_module)
# pickle check
obj = pickle.dumps(traced_module)
traced_module = pickle.loads(obj)
Node._set_next_id(159)
Expr._set_next_id(1024)
graph = traced_module.graph
for expr in graph.get_function_by_type(F.relu).as_list():
relu_out = expr.outputs[0]
cur_graph = expr.top_graph
with cur_graph.insert_exprs():
neg_out = F.neg(relu_out)
cur_graph.replace_node({relu_out: neg_out})
cur_graph.compile()
_check_id(traced_module)
flattened_module = traced_module.flatten()
_check_id(flattened_module)
_check_name(flattened_module)
# check trace TracedModule
obj = pickle.dumps(traced_module)
traced_module = pickle.loads(obj)
module = NewModule(traced_module)
traced_module = trace_module(module, x)
_check_id(traced_module)
flattened_module = traced_module.flatten()
_check_id(flattened_module)
_check_name(flattened_module)
def test_set_name():
traced_module, x, expect = _init_module()
graph = traced_module.graph
output_node = graph.outputs[0]
def rename(name):
output_node.name = name
np.testing.assert_raises(AssertionError, rename, "block1_out")
rename("output")
np.testing.assert_equal(str(graph.outputs[0]), "output")
def test_extra_block(): def test_extra_block():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册