main.py 3.0 KB
Newer Older
D
dyonghan 已提交
1 2 3 4 5 6 7 8 9 10
# LeNet5 mnist

import os
# os.environ['DEVICE_ID'] = '0'

import mindspore as ms
import mindspore.context as context
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV

D
dyonghan 已提交
11
from mindspore import nn
D
dyonghan 已提交
12
from mindspore.train import Model
D
dyonghan 已提交
13
from mindspore.train.callback import LossMonitor
D
dyonghan 已提交
14 15 16 17 18 19 20 21 22 23

# Select graph mode and the Ascend device target for MindSpore.
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')

# Dataset directories passed to MnistDataset by create_dataset().
DATA_DIR_TRAIN = "MNIST/train" # training set location
DATA_DIR_TEST = "MNIST/test" # test set location


def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),
                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
    """Build the MNIST input pipeline.

    Args:
        training: read from DATA_DIR_TRAIN when true, DATA_DIR_TEST otherwise.
        num_epoch: how many times the batched dataset is repeated.
        batch_size: samples per batch; an incomplete last batch is dropped.
        resize: target size handed to ``CV.Resize``.
        rescale: multiplicative factor handed to ``CV.Rescale``.
        shift: additive offset handed to ``CV.Rescale``.
        buffer_size: shuffle buffer size.

    Returns:
        The dataset after mapping, shuffling, batching and repeating.

    With the default ``rescale``/``shift`` the image transform is
    arithmetically equal to ``(pixel / 255 - 0.1307) / 0.3081``.
    """
    data_dir = DATA_DIR_TRAIN if training else DATA_DIR_TEST
    dataset = ms.dataset.MnistDataset(data_dir)

    # Image column: resize, rescale/shift, then move channels first.
    image_ops = [CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()]
    dataset = dataset.map(input_columns="image", operations=image_ops)

    # Label column: cast to int32.
    label_op = C.TypeCast(ms.int32)
    dataset = dataset.map(input_columns="label", operations=label_op)

    # Same order as before: shuffle -> batch -> repeat.
    dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(num_epoch)

    return dataset


D
dyonghan 已提交
31
class LeNet5(nn.Cell):
    """LeNet5-style network: two conv/ReLU/max-pool stages, then three dense layers.

    The first dense layer expects 400 input features, which equals the
    16 * 5 * 5 values produced by the two valid 5x5 convolutions and 2x2
    poolings for a 1x32x32 input (the default ``resize`` in create_dataset).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is so that
        # parameter naming stays unchanged.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, input_x):
        """Run the forward pass and return the output of the last dense layer."""
        # Stage 1: conv -> ReLU -> pool.
        features = self.pool(self.relu(self.conv1(input_x)))
        # Stage 2: conv -> ReLU -> pool.
        features = self.pool(self.relu(self.conv2(features)))

        # Classifier head.
        # NOTE(review): no activation is applied between fc1, fc2 and fc3;
        # confirm whether that is intentional.
        flat = self.flatten(features)
        hidden = self.fc2(self.fc1(flat))
        return self.fc3(hidden)


def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"):
    """Train LeNet5 on the training set, evaluate on the test set, print metrics.

    Args:
        lr: learning rate passed to ``nn.Momentum``.
        momentum: momentum passed to ``nn.Momentum``.
        num_epoch: passed both to ``create_dataset`` (dataset repeat count)
            and to ``Model.train``.
        ckpt_name: accepted but not used by this function.
    """
    train_set = create_dataset(num_epoch=num_epoch)
    eval_set = create_dataset(training=False)

    network = LeNet5()
    criterion = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
    optimizer = nn.Momentum(network.trainable_params(), lr, momentum)

    # Report the loss after every step.
    monitor = LossMonitor(per_print_times=1)

    trainer = Model(network, criterion, optimizer, metrics={'acc', 'loss'})
    trainer.train(num_epoch, train_set, callbacks=[monitor])

    results = trainer.eval(eval_set)
    print('Metrics:', results)

D
dyonghan 已提交
73

D
dyonghan 已提交
74 75 76 77 78 79 80 81 82 83 84 85
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, copy the dataset from
    # --data_url into the local 'MNIST/' directory, then train and evaluate.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
    # --train_url is required by the CLI but is not read anywhere in this script.
    parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
    parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')
    # parse_known_args tolerates extra arguments; the unrecognised ones are ignored.
    args, _unknown = parser.parse_known_args()

    # Imported here rather than at file top, as in the original script.
    import moxing as mox
    # DATA_DIR_TRAIN / DATA_DIR_TEST point inside this local 'MNIST/' directory.
    mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')

    # Forward the parsed epoch count; previously --num_epochs was parsed but
    # ignored and test_train always ran with its own default.
    test_train(num_epoch=args.num_epochs)