提交 c9830d20 编写于 作者: M Megvii Engine Team

fix(mge/test): fix tolerance

GitOrigin-RevId: 58c029b394edc17e1d6b1abb261cd76d58c6ab4f
上级 20450817
......@@ -63,7 +63,10 @@ def train(data, label, net, opt):
def update_model(model_path):
"""
Update the dumped model with test cases for new reference values
Update the dumped model with test cases for new reference values.
The model with pre-trained weights is trained for one iter with the test data attached.
The loss and updated net state dict is dumped.
"""
net = MnistNet(has_bn=True)
checkpoint = mge.load(model_path)
......@@ -89,9 +92,6 @@ def run_test(model_path, use_jit, use_symbolic):
"""
Load the model with test cases and run the training for one iter.
The loss and updated weights are compared with reference value to verify the correctness.
The model with pre-trained weights is trained for one iter and the net state dict is dumped.
The test cases is appended to the model file. The reference result is obtained
by running the train for one iter.
Dump a new file with updated result by calling update_model
if you think the test fails due to numerical rounding errors instead of bugs.
......@@ -109,7 +109,7 @@ def run_test(model_path, use_jit, use_symbolic):
data.set_value(checkpoint["data"])
label.set_value(checkpoint["label"])
max_err = 0.0
max_err = 1e-1
train_func = train
if use_jit:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册