# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import mindspore.nn as nn
from mindspore import context
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from lenet5_net import LeNet5
from mindarmour.utils.logger import LogUtil

sys.path.append("..")
from data_processing import generate_mnist_dataset
# Shared MindArmour logger instance used by this script.
LOGGER = LogUtil.get_instance()
# Tag passed as the first argument to every LOGGER call in this script.
TAG = "Lenet5_train"
Z
zheng-huanhuan 已提交
31 32 33 34 35 36 37 38


def mnist_train(epoch_size, batch_size, lr, momentum):
    """Train LeNet5 on MNIST, then reload the last checkpoint and evaluate it.

    Training data is read from ``./MNIST_unzip/train`` and evaluation data
    from ``./MNIST_unzip/test``. Checkpoints are written to
    ``./trained_ckpt_file/`` with the prefix ``checkpoint_lenet``.

    Args:
        epoch_size: Number of training epochs.
        batch_size: Mini-batch size used for both training and evaluation.
        lr: Learning rate for the Momentum optimizer.
        momentum: Momentum coefficient for the optimizer.
    """
    mnist_path = "./MNIST_unzip/"
    ckpt_dir = "./trained_ckpt_file/"
    ckpt_prefix = "checkpoint_lenet"

    ds = generate_mnist_dataset(os.path.join(mnist_path, "train"),
                                batch_size=batch_size, repeat_size=1)
    # The number of steps per epoch depends on batch_size. The previous
    # hard-coded value (1875) was only right for batch_size=32, so derive
    # it from the dataset instead.
    steps_per_epoch = ds.get_dataset_size()

    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
    # Save one checkpoint per epoch and keep at most the 10 most recent.
    config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                 keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix=ckpt_prefix,
                                 directory=ckpt_dir,
                                 config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    LOGGER.info(TAG, "============== Starting Training ==============")
    model.train(epoch_size, ds, callbacks=[ckpoint_cb, LossMonitor()],
                dataset_sink_mode=False)

    LOGGER.info(TAG, "============== Starting Testing ==============")
    # ModelCheckpoint names files "<prefix>-<epoch>_<step_in_epoch>.ckpt".
    # The final save therefore lands at (epoch_size, steps_per_epoch).
    # Building the name from the actual arguments keeps it valid for any
    # epoch_size / batch_size.
    # NOTE(review): if files with this prefix already exist in ckpt_dir,
    # ModelCheckpoint may rename the prefix (e.g. "checkpoint_lenet_1"), and
    # this path would then load a stale file from an earlier run. Consider
    # clearing ckpt_dir first; confirm against the MindSpore version in use.
    ckpt_file_name = os.path.join(
        ckpt_dir,
        "{}-{}_{}.ckpt".format(ckpt_prefix, epoch_size, steps_per_epoch))
    param_dict = load_checkpoint(ckpt_file_name)
    load_param_into_net(network, param_dict)
    ds_eval = generate_mnist_dataset(os.path.join(mnist_path, "test"),
                                     batch_size=batch_size)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)


if __name__ == '__main__':
    # Run in graph mode on CPU with memory reuse disabled.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU",
                        enable_mem_reuse=False)
    mnist_train(epoch_size=10, batch_size=32, lr=0.01, momentum=0.9)