提交 f4b16932 编写于 作者: M Megvii Engine Team

test(mgb/imperative): add adaptive pooling pytest

GitOrigin-RevId: c4dfed1f8047b3c6c6cca428e9790b5a2cff0b4a
上级 37e56f4b
......@@ -19,9 +19,17 @@ import megengine.autodiff as ad
import megengine.functional as F
from megengine import jit
from megengine.core._trace_option import set_tensor_shape
from megengine.core.tensor.utils import make_shape_tuple
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.jit import SublinearMemoryConfig
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.module import (
AdaptiveAvgPool2d,
AvgPool2d,
BatchNorm2d,
Conv2d,
Linear,
Module,
)
from megengine.optimizer import SGD
from megengine.tensor import Tensor
......@@ -57,10 +65,13 @@ def get_xpu_name():
class MnistNet(Module):
def __init__(self, has_bn=False):
def __init__(self, has_bn=False, use_adaptive_pooling=False):
super().__init__()
self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True)
self.pool0 = AvgPool2d(2)
if use_adaptive_pooling:
self.pool0 = AdaptiveAvgPool2d(12)
else:
self.pool0 = AvgPool2d(2)
self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True)
self.pool1 = AvgPool2d(2)
self.fc0 = Linear(20 * 4 * 4, 500, bias=True)
......@@ -134,7 +145,12 @@ def update_model(model_path):
def run_train(
model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
model_path,
use_jit,
use_symbolic,
sublinear_memory_config=None,
max_err=None,
use_adaptive_pooling=False,
):
"""
......@@ -146,7 +162,7 @@ def run_train(
Please think twice before you do so.
"""
net = MnistNet(has_bn=True)
net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
checkpoint = mge.load(model_path)
net.load_state_dict(checkpoint["net_init"])
lr = checkpoint["sgd_lr"]
......@@ -181,7 +197,11 @@ def run_train(
def run_eval(
model_path, use_symbolic, sublinear_memory_config=None, max_err=None,
model_path,
use_symbolic,
sublinear_memory_config=None,
max_err=None,
use_adaptive_pooling=False,
):
"""
......@@ -193,7 +213,7 @@ def run_eval(
Please think twice before you do so.
"""
net = MnistNet(has_bn=True)
net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
checkpoint = mge.load(model_path)
net.load_state_dict(checkpoint["net_init"])
......@@ -231,3 +251,30 @@ def test_correctness():
run_eval(model_path, False, max_err=1e-7)
run_eval(model_path, True, max_err=1e-7)
def test_correctness_use_adaptive_pooling():
if mge.is_cuda_available():
model_name = "mnist_model_with_test.mge"
else:
model_name = "mnist_model_with_test_cpu.mge"
model_path = os.path.join(os.path.dirname(__file__), model_name)
set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")
run_train(model_path, False, False, max_err=1e-5, use_adaptive_pooling=True)
run_train(model_path, True, False, max_err=1e-5, use_adaptive_pooling=True)
run_train(model_path, True, True, max_err=1e-5, use_adaptive_pooling=True)
# sublinear
config = SublinearMemoryConfig(genetic_nr_iter=10)
run_train(
model_path,
True,
True,
sublinear_memory_config=config,
max_err=1e-5,
use_adaptive_pooling=True,
)
run_eval(model_path, False, max_err=1e-7, use_adaptive_pooling=True)
run_eval(model_path, True, max_err=1e-7, use_adaptive_pooling=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册