# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import argparse
import paddle

from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.vision.datasets.mnist import MNIST

from paddle.vision.models import LeNet
from paddle.static import InputSpec as Input


def main():
    """Train (or only evaluate) LeNet on MNIST with the high-level ``paddle.Model`` API.

    Configuration is read from the module-level ``FLAGS`` namespace, which the
    ``__main__`` guard fills from the command line before calling this function.
    When ``FLAGS.resume`` is set, weights are loaded from that path first. With
    ``FLAGS.eval_only`` the model is only evaluated on the test split; otherwise
    it is trained and checkpoints are written to ``FLAGS.output_dir``.
    """
    device = paddle.set_device(FLAGS.device)
    # Pick the execution mode explicitly with a statement (not a conditional
    # expression evaluated only for its side effect). Calling enable_static()
    # in the other branch makes ``--dynamic`` mean the same thing regardless
    # of which mode the installed Paddle version uses by default.
    if FLAGS.dynamic:
        paddle.disable_static(device)
    else:
        paddle.enable_static()

    train_dataset = MNIST(mode='train')
    val_dataset = MNIST(mode='test')

    # Input specs for the model: 1x28x28 float32 images and int64 labels;
    # the leading ``None`` leaves the batch dimension unspecified.
    inputs = [Input(shape=[None, 1, 28, 28], dtype='float32', name='image')]
    labels = [Input(shape=[None, 1], dtype='int64', name='label')]

    net = LeNet()
    model = paddle.Model(net, inputs, labels)

    optim = Momentum(
        learning_rate=FLAGS.lr, momentum=.9, parameter_list=model.parameters())

    model.prepare(
        optim,
        paddle.nn.CrossEntropyLoss(),
        # Report both top-1 and top-2 accuracy.
        paddle.metric.Accuracy(topk=(1, 2)))

    if FLAGS.resume is not None:
        model.load(FLAGS.resume)

    if FLAGS.eval_only:
        model.evaluate(val_dataset, batch_size=FLAGS.batch_size)
        return

    model.fit(train_dataset,
              val_dataset,
              epochs=FLAGS.epoch,
              batch_size=FLAGS.batch_size,
              save_dir=FLAGS.output_dir)


def build_parser():
    """Build the command-line parser for the MNIST training script.

    Returns:
        argparse.ArgumentParser: parser whose parsed namespace is stored in the
        module-level ``FLAGS`` global and read by ``main``.
    """
    # The text is a description of the script, not a program name, so pass it
    # as ``description`` and let argparse derive ``prog`` from sys.argv[0].
    parser = argparse.ArgumentParser(description="CNN training on MNIST")
    parser.add_argument(
        "--device", type=str, default='gpu', help="device to use, gpu or cpu")
    parser.add_argument(
        "-d", "--dynamic", action='store_true', help="enable dygraph mode")
    parser.add_argument(
        "-e", "--epoch", default=10, type=int, help="number of epoch")
    parser.add_argument(
        '--lr',
        '--learning-rate',
        default=1e-3,
        type=float,
        metavar='LR',
        help='initial learning rate')
    parser.add_argument(
        "-b",
        "--batch_size",
        # Hyphenated alias to match --output-dir / --eval-only / --learning-rate.
        # ``dest`` is still ``batch_size`` because argparse derives it from the
        # first long option string.
        "--batch-size",
        default=128,
        type=int,
        help="batch size")
    parser.add_argument(
        "--output-dir",
        type=str,
        default='mnist_checkpoint',
        help="checkpoint save dir")
    parser.add_argument(
        "-r",
        "--resume",
        default=None,
        type=str,
        help="checkpoint path to resume")
    parser.add_argument(
        "--eval-only", action='store_true', help="only evaluate the model")
    return parser


if __name__ == '__main__':
    FLAGS = build_parser().parse_args()
    main()