提交 afddefb6 编写于 作者: M Megvii Engine Team

feat(mge/imperative): add more trace test

GitOrigin-RevId: b02e420a8a4ef7290fa103aa45487f89ed83db0e
上级 a085b71c
...@@ -16,6 +16,8 @@ import pytest ...@@ -16,6 +16,8 @@ import pytest
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine import jit
from megengine.core._trace_option import set_tensor_shape
from megengine.functional.debug_param import set_conv_execution_strategy from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD from megengine.optimizer import SGD
...@@ -129,7 +131,7 @@ def update_model(model_path): ...@@ -129,7 +131,7 @@ def update_model(model_path):
mge.save(checkpoint, model_path) mge.save(checkpoint, model_path)
def run_test( def run_train(
model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None, model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
): ):
...@@ -175,6 +177,37 @@ def run_test( ...@@ -175,6 +177,37 @@ def run_test(
assertTensorClose(param[1], param_ref[1], max_err=max_err) assertTensorClose(param[1], param_ref[1], max_err=max_err)
def run_eval(
model_path, use_symbolic, sublinear_memory_config=None, max_err=None,
):
"""
Load the model with test cases and run the training for one iter.
The loss and updated weights are compared with reference value to verify the correctness.
Dump a new file with updated result by calling update_model
if you think the test fails due to numerical rounding errors instead of bugs.
Please think twice before you do so.
"""
net = MnistNet(has_bn=True)
checkpoint = mge.load(model_path)
net.load_state_dict(checkpoint["net_init"])
data = Tensor(checkpoint["data"], dtype=np.float32)
def eval_fun(data, *, net=None):
pred = net(data)
return pred
refer_value = eval_fun(data, net=net)
eval_fun = jit.trace(eval_fun, symbolic=use_symbolic)
for _ in range(3):
new_value = eval_fun(data, net=net)
assertTensorClose(new_value.numpy(), refer_value.numpy(), max_err=max_err)
def test_correctness(): def test_correctness():
if mge.is_cuda_available(): if mge.is_cuda_available():
model_name = "mnist_model_with_test.mge" model_name = "mnist_model_with_test.mge"
...@@ -183,7 +216,7 @@ def test_correctness(): ...@@ -183,7 +216,7 @@ def test_correctness():
model_path = os.path.join(os.path.dirname(__file__), model_name) model_path = os.path.join(os.path.dirname(__file__), model_name)
set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE") set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
run_test(model_path, False, False, max_err=1e-5) run_train(model_path, False, False, max_err=1e-5)
# run_test(model_path, True, False) # run_test(model_path, True, False)
# run_test(model_path, True, True) # run_test(model_path, True, True)
...@@ -192,3 +225,6 @@ def test_correctness(): ...@@ -192,3 +225,6 @@ def test_correctness():
# run_test( # run_test(
# model_path, True, True, sublinear_memory_config=config, max_err=1e-5, # model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
# ) # )
run_eval(model_path, False, max_err=1e-7)
# run_eval(model_path, True, max_err=1e-7) # XXX: fix me
...@@ -25,7 +25,7 @@ TEST(TestOprUtility, InputCallback) { ...@@ -25,7 +25,7 @@ TEST(TestOprUtility, InputCallback) {
dv.copy_from(*hv).sync(); dv.copy_from(*hv).sync();
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
auto callback = [dv]() {return dv;}; auto callback = [dv]() {return dv;};
auto outputs = opr::InputCallback::make(*graph, callback, dv.comp_node(), dv.dtype()); auto outputs = opr::InputCallback::make(*graph, callback, dv.comp_node(), dv.dtype(), {2, 3});
HostTensorND hout; HostTensorND hout;
ComputingGraph::OutputSpec outspec{make_callback_copy(outputs[0], hout)}; ComputingGraph::OutputSpec outspec{make_callback_copy(outputs[0], hout)};
...@@ -99,7 +99,7 @@ TEST(TestOprUtility, CallbackChain) { ...@@ -99,7 +99,7 @@ TEST(TestOprUtility, CallbackChain) {
dev_x.storage({}); dev_x.storage({});
return ret; return ret;
}; };
auto out = opr::InputCallback::make(*graph, callback, cn, dev_x.dtype()); auto out = opr::InputCallback::make(*graph, callback, cn, dev_x.dtype(), {2, 3});
x = out[0]; x = out[0];
dummy = out[1]; dummy = out[1];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册