# mnist_train.py
# (Blame-view header: the lines up to the module-level logger setup were
# committed by zheng-huanhuan.)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import sys

import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy

from mindarmour.utils.logger import LogUtil

from lenet5_net import LeNet5

# Extend the import search path before importing data_processing; presumably
# that module lives one directory up. NOTE: ".." is resolved against the
# current working directory, not against this file's location.
sys.path.append("..")
from data_processing import generate_mnist_dataset

# Shared logger instance and the tag passed as the first argument of every
# LOGGER call in this script.
LOGGER = LogUtil.get_instance()
TAG = "Lenet5_train"


def mnist_train(epoch_size, batch_size, lr, momentum):
    """Train LeNet5 on MNIST, then evaluate the last saved checkpoint.

    The training data is read from ``./MNIST_unzip/train`` and the test data
    from ``./MNIST_unzip/test`` (both relative to the current working
    directory). One checkpoint is written per epoch into
    ``./trained_ckpt_file/``; after training, the checkpoint of the final
    epoch is loaded back into the network and its accuracy on the test set
    is logged.

    Args:
        epoch_size: Number of training epochs.
        batch_size: Batch size used for both the training and test datasets.
        lr: Learning rate handed to the Momentum optimizer.
        momentum: Momentum coefficient handed to the Momentum optimizer.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                        enable_mem_reuse=False)

    mnist_path = "./MNIST_unzip/"
    ckpt_dir = "./trained_ckpt_file/"
    ckpt_prefix = "checkpoint_lenet"

    ds = generate_mnist_dataset(os.path.join(mnist_path, "train"),
                                batch_size=batch_size, repeat_size=1)
    # Number of batches per epoch. Used instead of a hard-coded 1875 (which
    # is only right for batch_size=32) so that exactly one checkpoint is
    # saved at the end of each epoch for any batch size.
    steps_per_epoch = ds.get_dataset_size()

    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
    config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                 keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix=ckpt_prefix,
                                 directory=ckpt_dir,
                                 config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    LOGGER.info(TAG, "============== Starting Training ==============")
    model.train(epoch_size, ds, callbacks=[ckpoint_cb, LossMonitor()],
                dataset_sink_mode=False)

    LOGGER.info(TAG, "============== Starting Testing ==============")
    # Checkpoint files are named "<prefix>-<epoch>_<step>.ckpt"; build the
    # name of the final epoch's checkpoint from the actual arguments instead
    # of assuming 10 epochs of 1875 steps.
    ckpt_file_name = os.path.join(
        ckpt_dir,
        "{}-{}_{}.ckpt".format(ckpt_prefix, epoch_size, steps_per_epoch))
    param_dict = load_checkpoint(ckpt_file_name)
    load_param_into_net(network, param_dict)
    ds_eval = generate_mnist_dataset(os.path.join(mnist_path, "test"),
                                     batch_size=batch_size)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)


if __name__ == '__main__':
    # Script entry point: run training and evaluation with the default
    # hyper-parameters.
    mnist_train(epoch_size=10, batch_size=32, lr=0.01, momentum=0.9)