提交 fd83e025 编写于 作者: M Megvii Engine Team

fix(mge/test): fix reproducibility of test_correctness.py

GitOrigin-RevId: 70bd43bbabd022ce9538d92291ddb1c93a596e0a
上级 b8d27934
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os import os
import re
import subprocess
import sys import sys
import numpy as np import numpy as np
...@@ -15,18 +17,48 @@ import megengine as mge ...@@ -15,18 +17,48 @@ import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine import jit, tensor from megengine import jit, tensor
from megengine.functional.debug_param import set_conv_execution_strategy from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.module import BatchNorm2d, Conv2d, Linear, MaxPool2d, Module from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD from megengine.optimizer import SGD
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
def get_gpu_name():
try:
gpu_info = subprocess.check_output(
["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"]
)
gpu_info = gpu_info.decode("ascii").split("\n")[0]
except:
gpu_info = "None"
return gpu_info
def get_cpu_name():
cpu_info = "None"
try:
cpu_info = subprocess.check_output(["cat", "/proc/cpuinfo"]).decode("ascii")
for line in cpu_info.split("\n"):
if "model name" in line:
return re.sub(".*model name.*:", "", line, 1).strip()
except:
pass
return cpu_info
def get_xpu_name():
if mge.is_cuda_available():
return get_gpu_name()
else:
return get_cpu_name()
class MnistNet(Module): class MnistNet(Module):
def __init__(self, has_bn=False): def __init__(self, has_bn=False):
super().__init__() super().__init__()
self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True) self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True)
self.pool0 = MaxPool2d(2) self.pool0 = AvgPool2d(2)
self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True) self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True)
self.pool1 = MaxPool2d(2) self.pool1 = AvgPool2d(2)
self.fc0 = Linear(20 * 4 * 4, 500, bias=True) self.fc0 = Linear(20 * 4 * 4, 500, bias=True)
self.fc1 = Linear(500, 10, bias=True) self.fc1 = Linear(500, 10, bias=True)
self.bn0 = None self.bn0 = None
...@@ -67,6 +99,13 @@ def update_model(model_path): ...@@ -67,6 +99,13 @@ def update_model(model_path):
The model with pre-trained weights is trained for one iter with the test data attached. The model with pre-trained weights is trained for one iter with the test data attached.
The loss and updated net state dict is dumped. The loss and updated net state dict is dumped.
.. code-block:: python
from test_correctness import update_model
update_model('mnist_model_with_test.mge') # for gpu
update_model('mnist_model_with_test_cpu.mge') # for cpu
""" """
net = MnistNet(has_bn=True) net = MnistNet(has_bn=True)
checkpoint = mge.load(model_path) checkpoint = mge.load(model_path)
...@@ -83,7 +122,11 @@ def update_model(model_path): ...@@ -83,7 +122,11 @@ def update_model(model_path):
loss = train(data, label, net=net, opt=opt) loss = train(data, label, net=net, opt=opt)
opt.step() opt.step()
checkpoint.update({"net_updated": net.state_dict(), "loss": loss.numpy()}) xpu_name = get_xpu_name()
checkpoint.update(
{"net_updated": net.state_dict(), "loss": loss.numpy(), "xpu": xpu_name}
)
mge.save(checkpoint, model_path) mge.save(checkpoint, model_path)
...@@ -109,7 +152,7 @@ def run_test(model_path, use_jit, use_symbolic): ...@@ -109,7 +152,7 @@ def run_test(model_path, use_jit, use_symbolic):
data.set_value(checkpoint["data"]) data.set_value(checkpoint["data"])
label.set_value(checkpoint["label"]) label.set_value(checkpoint["label"])
max_err = 1e-1 max_err = 1e-5
train_func = train train_func = train
if use_jit: if use_jit:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册