train.py 8.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

M
minqiyang 已提交
15
from __future__ import print_function
16 17
import argparse
import ast
M
minqiyang 已提交
18
import numpy as np
19 20
from PIL import Image
import os
M
minqiyang 已提交
21 22 23
import paddle


24 25 26 27 28 29
def parse_args(argv=None):
    """Parse the command-line options for MNIST training.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, so existing callers behave as
            before; passing a list lets tests or other scripts drive it.

    Returns:
        argparse.Namespace with ``use_data_parallel`` and ``use_gpu``
        (parsed with ``ast.literal_eval``, e.g. ``True``/``False``),
        ``epoch`` (int, default 5) and ``ce`` (bool flag).
    """
    # Passed as ``description``: ArgumentParser's first positional parameter
    # is ``prog`` (the program name shown in usage), not a description.
    parser = argparse.ArgumentParser(description="Training for Mnist.")
    parser.add_argument(
        "--use_data_parallel",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use data parallel mode to train the model."
    )
    parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch")
    parser.add_argument("--ce", action="store_true", help="run ce")
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')

    return parser.parse_args(argv)


Z
zhang wenhui 已提交
44
class SimpleImgConvPool(paddle.nn.Layer):
    """A 2-D convolution, an optional activation, then 2-D pooling.

    Args:
        num_channels: input channels of the convolution.
        num_filters: output channels of the convolution.
        filter_size: convolution kernel size.
        pool_size, pool_stride, pool_padding, pool_type, global_pooling:
            forwarded to ``Pool2D``.
        conv_stride, conv_padding, conv_dilation, conv_groups: forwarded to
            ``Conv2d`` as stride / padding / dilation / groups.
        act: name of a function in ``paddle.nn.functional`` (e.g.
            ``"relu"``) applied after the convolution, or ``None``/empty for
            no activation.
        use_cudnn: forwarded to ``Pool2D``.
        param_attr: weight attribute for the convolution.
        bias_attr: bias attribute for the convolution.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 pool_size,
                 pool_stride,
                 pool_padding=0,
                 pool_type='max',
                 global_pooling=False,
                 conv_stride=1,
                 conv_padding=0,
                 conv_dilation=1,
                 conv_groups=1,
                 act=None,
                 use_cudnn=False,
                 param_attr=None,
                 bias_attr=None):
        super(SimpleImgConvPool, self).__init__()

        self._conv2d = paddle.nn.Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            # Honour the caller-supplied attributes instead of discarding
            # them; the defaults (None) keep Paddle's default initialization.
            weight_attr=param_attr,
            bias_attr=bias_attr)
        self._act = act

        self._pool2d = paddle.fluid.dygraph.nn.Pool2D(
            pool_size=pool_size,
            pool_type=pool_type,
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            global_pooling=global_pooling,
            use_cudnn=use_cudnn)

    def forward(self, inputs):
        """Apply conv -> activation (if configured) -> pool to ``inputs``."""
        x = self._conv2d(inputs)
        if self._act:
            # ``act`` is looked up by name, so any paddle.nn.functional
            # activation (relu, sigmoid, ...) can be selected.
            x = getattr(paddle.nn.functional, self._act)(x)
        return self._pool2d(x)


Z
zhang wenhui 已提交
91
class MNIST(paddle.nn.Layer):
    """MNIST classifier: two conv+pool stages followed by a softmax FC layer.

    ``forward(inputs)`` returns the class probabilities.
    ``forward(inputs, label)`` returns ``(probabilities, accuracy)``.
    """

    def __init__(self):
        super(MNIST, self).__init__()

        # Keep these attribute names and this creation order unchanged: they
        # presumably determine the state_dict() keys of saved checkpoints.
        self._simple_img_conv_pool_1 = SimpleImgConvPool(
            1, 20, 5, 2, 2, act="relu")
        self._simple_img_conv_pool_2 = SimpleImgConvPool(
            20, 50, 5, 2, 2, act="relu")

        # For 28x28 inputs: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4,
        # so the second stage emits 50 channels of 4x4.
        self.pool_2_shape = 50 * 4 * 4
        num_classes = 10
        init_std = (2.0 / (self.pool_2_shape**2 * num_classes))**0.5
        normal_init = paddle.nn.initializer.Normal(loc=0.0, scale=init_std)
        self._fc = paddle.nn.Linear(
            in_features=self.pool_2_shape,
            out_features=num_classes,
            weight_attr=paddle.ParamAttr(initializer=normal_init))

    def forward(self, inputs, label=None):
        """Classify ``inputs``; also report accuracy when ``label`` is given."""
        features = self._simple_img_conv_pool_2(
            self._simple_img_conv_pool_1(inputs))
        flat = paddle.fluid.layers.reshape(
            features, shape=[-1, self.pool_2_shape])
        probs = paddle.nn.functional.softmax(self._fc(flat))
        if label is None:
            return probs
        return probs, paddle.metric.accuracy(input=probs, label=label)


124 125 126 127 128 129 130 131 132 133
def reader_decorator(reader):
    """Adapt a raw MNIST sample reader so it yields model-ready numpy arrays.

    Args:
        reader: zero-argument callable returning an iterable of samples,
            where ``sample[0]`` holds 784 pixel values and ``sample[1]``
            holds the class label.

    Returns:
        A zero-argument generator function yielding ``(img, label)``, where
        ``img`` is a float32 array of shape (1, 28, 28) and ``label`` is an
        int64 array of shape (1,).
    """

    def __reader__():
        for sample in reader():
            pixels, target = sample[0], sample[1]
            yield (np.array(pixels).astype('float32').reshape(1, 28, 28),
                   np.array(target).astype('int64').reshape(1))

    return __reader__


134
def test_mnist(reader, model, batch_size):
    """Evaluate ``model`` over every batch produced by ``reader``.

    Args:
        reader: callable returning an iterable of ``(img, label)`` batches.
        model: layer called as ``model(img, label)`` that returns
            ``(prediction, acc)``.
        batch_size: unused; kept for interface compatibility.

    Returns:
        ``(mean_loss, mean_acc)``: unweighted means over batches of the
        cross-entropy loss and of the accuracy.
    """
    batch_accs = []
    batch_losses = []
    for img, label in reader():
        label.stop_gradient = True
        prediction, acc = model(img, label)
        per_sample = paddle.fluid.layers.cross_entropy(
            input=prediction, label=label)
        mean_loss = paddle.mean(per_sample)
        batch_accs.append(float(acc.numpy()))
        batch_losses.append(float(mean_loss.numpy()))

    return np.mean(batch_losses), np.mean(batch_accs)
M
minqiyang 已提交
151 152


153
def inference_mnist(place=None):
    """Load the ``save_temp`` checkpoint and classify ``image/infer_3.png``.

    The predicted digit is printed. Dygraph mode is enabled for the duration
    of the call, and static mode is re-enabled afterwards, even on error.

    Args:
        place: optional Paddle device to run on. When ``None`` (the default),
            the device is chosen from the module-level ``args`` set in the
            ``__main__`` guard, as before. Pass it explicitly when this
            module is imported rather than run as a script.
    """
    if place is None:
        if not args.use_gpu:
            place = paddle.CPUPlace()
        elif not args.use_data_parallel:
            place = paddle.CUDAPlace(0)
        else:
            place = paddle.CUDAPlace(
                paddle.fluid.dygraph.parallel.Env().dev_id)

    paddle.disable_static(place)
    try:
        mnist_infer = MNIST()
        # load checkpoint
        model_dict, _ = paddle.fluid.load_dygraph("save_temp")
        mnist_infer.set_dict(model_dict)
        print("checkpoint loaded")

        # start evaluate mode
        mnist_infer.eval()

        def load_image(file):
            # Close the file handle once the grayscale copy has been made.
            with Image.open(file) as raw:
                im = raw.convert('L')
            # LANCZOS is the filter formerly aliased as ANTIALIAS, which
            # Pillow 10 removed.
            im = im.resize((28, 28), Image.LANCZOS)
            im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
            # Map pixel values from [0, 255] to [-1, 1].
            im = im / 255.0 * 2.0 - 1.0
            return im

        cur_dir = os.path.dirname(os.path.realpath(__file__))
        tensor_img = load_image(os.path.join(cur_dir, 'image', 'infer_3.png'))

        results = mnist_infer(paddle.to_tensor(
            data=tensor_img, dtype=None, place=None, stop_gradient=True))
        # Last index of the ascending argsort is the most probable class.
        lab = np.argsort(results.numpy())
        print("Inference result of image/infer_3.png is: %d" % lab[0][-1])
    finally:
        paddle.enable_static()
185 186 187


def _select_place(args):
    """Return the Paddle device implied by ``use_gpu`` / ``use_data_parallel``."""
    if not args.use_gpu:
        return paddle.CPUPlace()
    if not args.use_data_parallel:
        return paddle.CUDAPlace(0)
    # Data-parallel: use the GPU id reported by the parallel environment.
    return paddle.CUDAPlace(paddle.fluid.dygraph.parallel.Env().dev_id)


def train_mnist(args):
    """Train MNIST in dygraph mode, evaluating on the test set every epoch.

    After training, the single process (or rank 0 in data-parallel mode)
    saves the weights to ``save_temp`` and runs ``inference_mnist``. Static
    mode is re-enabled on exit, including when training raises.

    Args:
        args: namespace from ``parse_args`` providing ``epoch``, ``ce``,
            ``use_gpu`` and ``use_data_parallel``.
    """
    epoch_num = args.epoch
    BATCH_SIZE = 64

    place = _select_place(args)
    paddle.disable_static(place)
    try:
        if args.ce:
            # Fixed seeds so continuous-evaluation (CE) runs are repeatable.
            print("ce mode")
            seed = 33
            np.random.seed(seed)
            paddle.static.default_startup_program().random_seed = seed
            paddle.static.default_main_program().random_seed = seed

        if args.use_data_parallel:
            strategy = paddle.fluid.dygraph.parallel.prepare_context()
        mnist = MNIST()
        adam = paddle.optimizer.Adam(
            learning_rate=0.001, parameters=mnist.parameters())
        if args.use_data_parallel:
            mnist = paddle.fluid.dygraph.parallel.DataParallel(mnist, strategy)

        train_reader = paddle.batch(
            reader_decorator(paddle.dataset.mnist.train()),
            batch_size=BATCH_SIZE,
            drop_last=True)
        if args.use_data_parallel:
            # Presumably splits batches across trainer processes.
            train_reader = paddle.fluid.contrib.reader.distributed_batch_reader(
                train_reader)

        test_reader = paddle.batch(
            reader_decorator(paddle.dataset.mnist.test()),
            batch_size=BATCH_SIZE,
            drop_last=True)

        train_loader = paddle.io.DataLoader.from_generator(capacity=10)
        train_loader.set_sample_list_generator(train_reader, places=place)

        test_loader = paddle.io.DataLoader.from_generator(capacity=10)
        test_loader.set_sample_list_generator(test_reader, places=place)

        for epoch in range(epoch_num):
            for batch_id, data in enumerate(train_loader()):
                img, label = data
                label.stop_gradient = True

                # MNIST.forward already applies softmax, and fluid's
                # cross_entropy takes probabilities rather than logits.
                cost, acc = mnist(img, label)

                loss = paddle.fluid.layers.cross_entropy(cost, label)
                avg_loss = paddle.mean(loss)

                if args.use_data_parallel:
                    avg_loss = mnist.scale_loss(avg_loss)
                    avg_loss.backward()
                    mnist.apply_collective_grads()
                else:
                    avg_loss.backward()

                adam.minimize(avg_loss)
                mnist.clear_gradients()
                if batch_id % 100 == 0:
                    print("Loss at epoch {} step {}: {:}".format(
                        epoch, batch_id, avg_loss.numpy()))

            mnist.eval()
            test_cost, test_acc = test_mnist(test_loader, mnist, BATCH_SIZE)
            mnist.train()
            if args.ce:
                print("kpis\ttest_acc\t%s" % test_acc)
                print("kpis\ttest_cost\t%s" % test_cost)
            print("Loss at epoch {} , Test avg_loss is: {}, acc is: {}".format(
                epoch, test_cost, test_acc))

        # Only one process writes the checkpoint in data-parallel mode.
        save_parameters = (not args.use_data_parallel or
                           paddle.fluid.dygraph.parallel.Env().local_rank == 0)
        if save_parameters:
            paddle.fluid.save_dygraph(mnist.state_dict(), "save_temp")

            print("checkpoint saved")

            inference_mnist()
    finally:
        paddle.enable_static()
M
minqiyang 已提交
276 277 278


if __name__ == '__main__':
    # ``args`` is bound at module level on purpose: inference_mnist() reads
    # the global ``args`` to choose its device.
    args = parse_args()
    train_mnist(args=args)