# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""One-iteration training correctness test for a small MNIST CNN.

A stored checkpoint holds initial weights, an SGD learning rate, one batch of
input data/labels, and the reference loss and weights after one SGD step.
The test re-runs that step (eagerly, traced, and traced-symbolic) and compares
the results against the stored reference values.
"""
import os
import sys

import numpy as np

import megengine as mge
import megengine.functional as F
from megengine import jit, tensor
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.module import BatchNorm2d, Conv2d, Linear, MaxPool2d, Module
from megengine.optimizer import SGD
from megengine.test import assertTensorClose


class MnistNet(Module):
    """LeNet-style CNN: two (conv -> [bn] -> relu -> maxpool) stages, then two FC layers.

    ``fc0`` expects ``20 * 4 * 4`` features. That matches a single-channel
    28x28 input after two (5x5 conv, 2x2 pool) stages. The input is presumably
    MNIST images shaped (N, 1, 28, 28); confirm against the stored test data.

    Args:
        has_bn: if True, insert a ``BatchNorm2d`` after each conv layer.
    """

    def __init__(self, has_bn=False):
        super().__init__()
        self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True)
        self.pool0 = MaxPool2d(2)
        self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True)
        self.pool1 = MaxPool2d(2)
        self.fc0 = Linear(20 * 4 * 4, 500, bias=True)
        self.fc1 = Linear(500, 10, bias=True)
        # Optional batch-norm layers; ``None`` means "skip" in forward().
        self.bn0 = None
        self.bn1 = None
        if has_bn:
            self.bn0 = BatchNorm2d(20)
            self.bn1 = BatchNorm2d(20)

    def forward(self, x):
        """Return the 10-way output logits for the input batch ``x``."""
        x = self.conv0(x)
        # The explicit None check states the intent ("layer present?") and
        # does not depend on how Module defines truthiness.
        if self.bn0 is not None:
            x = self.bn0(x)
        x = F.relu(x)
        x = self.pool0(x)
        x = self.conv1(x)
        if self.bn1 is not None:
            x = self.bn1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = F.flatten(x, 1)
        x = self.fc0(x)
        x = F.relu(x)
        x = self.fc1(x)
        return x


def train(data, label, net, opt):
    """Run a forward pass and back-propagate the softmax cross-entropy loss.

    The optimizer step itself is NOT taken here; callers are expected to
    call ``opt.zero_grad()`` before and ``opt.step()`` after.

    Returns:
        The loss tensor.
    """
    pred = net(data)
    loss = F.cross_entropy_with_softmax(pred, label)
    opt.backward(loss)
    return loss


def _prepare_training(model_path):
    """Load a test checkpoint and build the net, SGD optimizer and input tensors.

    Reads the checkpoint keys ``"net_init"``, ``"sgd_lr"``, ``"data"`` and
    ``"label"``.

    Returns:
        A tuple ``(checkpoint, net, opt, data, label)``.
    """
    net = MnistNet(has_bn=True)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)
    data = tensor(dtype=np.float32)
    label = tensor(dtype=np.int32)
    data.set_value(checkpoint["data"])
    label.set_value(checkpoint["label"])
    return checkpoint, net, opt, data, label


def update_model(model_path):
    """Regenerate the reference values stored in the test checkpoint.

    Loads the model with pre-trained weights, trains it for one iteration on
    the attached test data, and writes the resulting loss (``"loss"``) and
    updated state dict (``"net_updated"``) back into ``model_path``,
    overwriting the file.
    """
    checkpoint, net, opt, data, label = _prepare_training(model_path)

    opt.zero_grad()
    loss = train(data, label, net=net, opt=opt)
    opt.step()

    checkpoint.update({"net_updated": net.state_dict(), "loss": loss.numpy()})
    mge.save(checkpoint, model_path)


def run_test(model_path, use_jit, use_symbolic):
    """Train for one iteration from the checkpoint and compare with the reference.

    The loss and the updated weights are compared with the stored reference
    values to verify correctness.

    If you believe the test fails because of numerical rounding rather than a
    bug, you can dump a new reference file by calling ``update_model``.
    Please think twice before doing so.

    Args:
        model_path: path to the checkpoint containing inputs and references.
        use_jit: if True, wrap ``train`` with ``jit.trace``.
        use_symbolic: passed to ``jit.trace`` as ``symbolic``; only used when
            ``use_jit`` is True.
    """
    checkpoint, net, opt, data, label = _prepare_training(model_path)

    # Loose tolerance: the same checkpoint is reused across eager/traced/
    # symbolic execution.
    max_err = 1e-1

    train_func = train
    if use_jit:
        train_func = jit.trace(train_func, symbolic=use_symbolic)

    opt.zero_grad()
    loss = train_func(data, label, net=net, opt=opt)
    opt.step()

    assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err)

    state = net.state_dict()
    ref_state = checkpoint["net_updated"]
    # zip() stops at the shorter input, so a missing or extra parameter would
    # otherwise go unnoticed; check the entry counts first.
    assert len(state) == len(ref_state), (
        "state dict size mismatch: %d vs reference %d" % (len(state), len(ref_state))
    )
    for (name, value), (ref_name, ref_value) in zip(state.items(), ref_state.items()):
        assert name == ref_name
        assertTensorClose(value, ref_value, max_err=max_err)


def test_correctness():
    """Check one training step in eager, traced and traced-symbolic modes.

    Uses the GPU reference checkpoint when CUDA is available, and the CPU
    reference checkpoint otherwise.
    """
    if mge.is_cuda_available():
        model_name = "mnist_model_with_test.mge"
    else:
        model_name = "mnist_model_with_test_cpu.mge"
    model_path = os.path.join(os.path.dirname(__file__), model_name)
    # Select the "HEURISTIC_REPRODUCIBLE" conv execution strategy. Per its
    # name, it is intended to give results comparable with the stored
    # reference; confirm in the MegEngine docs.
    set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")

    run_test(model_path, False, False)
    run_test(model_path, True, False)
    run_test(model_path, True, True)