diff --git a/imperative/python/test/integration/test_correctness.py b/imperative/python/test/integration/test_correctness.py
index 73d3fbed247021c1a708a5645bd8814f09144bb2..7519c06a71053cac9c6f74efce4c7867a32ac54f 100644
--- a/imperative/python/test/integration/test_correctness.py
+++ b/imperative/python/test/integration/test_correctness.py
@@ -16,6 +16,8 @@ import pytest
 
 import megengine as mge
 import megengine.functional as F
+from megengine import jit
+from megengine.core._trace_option import set_tensor_shape
 from megengine.functional.debug_param import set_conv_execution_strategy
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
@@ -129,7 +131,7 @@ def update_model(model_path):
     mge.save(checkpoint, model_path)
 
 
-def run_test(
+def run_train(
     model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
 ):
 
@@ -175,6 +177,37 @@ def run_test(
         assertTensorClose(param[1], param_ref[1], max_err=max_err)
 
 
+def run_eval(
+    model_path, use_symbolic, sublinear_memory_config=None, max_err=None,
+):
+
+    """
+    Load the model with test cases and run the inference for multiple iters.
+    The outputs are compared with reference values to verify the correctness.
+
+    Dump a new file with updated results by calling update_model
+    if you think the test fails due to numerical rounding errors instead of bugs.
+    Please think twice before you do so.
+
+    """
+    net = MnistNet(has_bn=True)
+    checkpoint = mge.load(model_path)
+    net.load_state_dict(checkpoint["net_init"])
+
+    data = Tensor(checkpoint["data"], dtype=np.float32)
+
+    def eval_fun(data, *, net=None):
+        pred = net(data)
+        return pred
+
+    refer_value = eval_fun(data, net=net)
+    eval_fun = jit.trace(eval_fun, symbolic=use_symbolic)
+
+    for _ in range(3):
+        new_value = eval_fun(data, net=net)
+        assertTensorClose(new_value.numpy(), refer_value.numpy(), max_err=max_err)
+
+
 def test_correctness():
     if mge.is_cuda_available():
         model_name = "mnist_model_with_test.mge"
@@ -183,7 +216,7 @@ def test_correctness():
     model_path = os.path.join(os.path.dirname(__file__), model_name)
     set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
 
-    run_test(model_path, False, False, max_err=1e-5)
+    run_train(model_path, False, False, max_err=1e-5)
     # run_test(model_path, True, False)
     # run_test(model_path, True, True)
 
@@ -192,3 +225,6 @@ def test_correctness():
     #     run_test(
     #         model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
     #     )
+
+    run_eval(model_path, False, max_err=1e-7)
+    # run_eval(model_path, True, max_err=1e-7)  # XXX: fix me
diff --git a/imperative/src/test/opr_utility.cpp b/imperative/src/test/opr_utility.cpp
index 6454489be0b171683430c1b622edb812d7dc3b99..c808d2cd36067cf1a87538888453ffd6b18f66ba 100644
--- a/imperative/src/test/opr_utility.cpp
+++ b/imperative/src/test/opr_utility.cpp
@@ -25,7 +25,7 @@ TEST(TestOprUtility, InputCallback) {
     dv.copy_from(*hv).sync();
     auto graph = ComputingGraph::make();
     auto callback = [dv]() {return dv;};
-    auto outputs = opr::InputCallback::make(*graph, callback, dv.comp_node(), dv.dtype());
+    auto outputs = opr::InputCallback::make(*graph, callback, dv.comp_node(), dv.dtype(), {2, 3});
 
     HostTensorND hout;
     ComputingGraph::OutputSpec outspec{make_callback_copy(outputs[0], hout)};
@@ -99,7 +99,7 @@ TEST(TestOprUtility, CallbackChain) {
         dev_x.storage({});
         return ret;
     };
-    auto out = opr::InputCallback::make(*graph, callback, cn, dev_x.dtype());
+    auto out = opr::InputCallback::make(*graph, callback, cn, dev_x.dtype(), {2, 3});
     x = out[0];
     dummy = out[1];
 }
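
The new run_eval path follows a trace-then-compare pattern: run the function once eagerly to get a reference output, wrap it with jit.trace, then assert that repeated traced runs reproduce the eager result. Below is a minimal self-contained sketch of that pattern; TinyNet and the random input are hypothetical stand-ins for the MnistNet checkpoint the real test loads, and exact import paths may vary across MegEngine versions. `jit.trace(fn, symbolic=...)` is the same API the diff uses.

```python
# Sketch of the traced-eval check added in run_eval (TinyNet is hypothetical).
import numpy as np

import megengine as mge
from megengine import jit
from megengine.module import Linear, Module


class TinyNet(Module):
    def __init__(self):
        super().__init__()
        self.fc = Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def eval_fun(data, *, net=None):
    # Pure inference: no parameter update, so every traced run should
    # reproduce the eager reference output within the tolerance.
    return net(data)


net = TinyNet()
data = mge.tensor(np.random.rand(8, 4).astype(np.float32))

refer_value = eval_fun(data, net=net)             # eager run -> reference output
traced_fun = jit.trace(eval_fun, symbolic=False)  # symbolic=True is the XXX case

for _ in range(3):
    new_value = traced_fun(data, net=net)         # repeated traced runs must agree
    np.testing.assert_allclose(new_value.numpy(), refer_value.numpy(), atol=1e-7)
```

On the C++ side, the extra `{2, 3}` argument appears to hand opr::InputCallback::make an explicit output shape, presumably so shape inference has something to work with at graph-construction time; the tests pin it to the 2x3 tensors they feed in.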