# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import order: standard library, third-party, local modules.
import argparse
import math
import os
import sys
import time

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import framework
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.initializer import MSRA
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.param_attr import ParamAttr

import reader
from utils import *
from mobilenet_v1 import *
from mobilenet_v2 import *

args = parse_args()
# Echo the parsed configuration only when PADDLE_TRAINER_ID is 0 (or unset),
# so that just one trainer process prints it.
if not int(os.getenv("PADDLE_TRAINER_ID", 0)):
    print_arguments(args)


def eval(net, test_data_loader, eop):
    """Evaluate ``net`` on every batch yielded by ``test_data_loader()``.

    For each batch, prints the mean cross-entropy loss, top-1/top-5 accuracy,
    compute time and data-read time. Afterwards, prints the averages over
    all batches.

    Note: the name shadows the ``eval`` builtin inside this module; it is
    kept because callers use it.

    Args:
        net: model called as ``net(img)`` to produce class scores.
        test_data_loader: callable returning an iterable of ``(img, label)``
            batches.
        eop: epoch id, used only in log messages.
    """
    total_loss = 0.0
    total_acc1 = 0.0
    total_acc5 = 0.0
    total_sample = 0
    # Start the read timer here. Starting it from 0 made the first batch
    # report seconds-since-epoch as its read time.
    t_last = time.time()
    for img, label in test_data_loader():
        t1 = time.time()
        # Let numpy infer the batch dimension. Deriving it from
        # args.batch_size // get_cuda_device_count() breaks when the device
        # count is 0 (CPU-only) or the batch is not exactly that size.
        label = to_variable(label.numpy().astype('int64').reshape(-1, 1))
        out = net(img)
        softmax_out = fluid.layers.softmax(out, use_cudnn=False)
        loss = fluid.layers.cross_entropy(input=softmax_out, label=label)
        avg_loss = fluid.layers.mean(x=loss)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
        t2 = time.time()
        print("test | epoch id: %d, avg_loss %0.5f acc_top1 %0.5f acc_top5 %0.5f %2.4f sec read_t:%2.4f" %
              (eop, avg_loss.numpy(), acc_top1.numpy(), acc_top5.numpy(),
               t2 - t1, t1 - t_last))
        sys.stdout.flush()
        total_loss += avg_loss.numpy()
        total_acc1 += acc_top1.numpy()
        total_acc5 += acc_top5.numpy()
        total_sample += 1
        t_last = time.time()
    if total_sample == 0:
        # An empty loader previously raised ZeroDivisionError below.
        print("final eval: no test batches")
        sys.stdout.flush()
        return
    print("final eval loss %0.3f acc1 %0.3f acc5 %0.3f" %
          (total_loss / total_sample, total_acc1 / total_sample,
           total_acc5 / total_sample))
    sys.stdout.flush()


def _create_place():
    """Return the device place to run on, chosen from ``args``.

    CPU when ``args.use_gpu`` is false; GPU 0 for single-card training; and,
    under data parallelism, the GPU id reported by the parallel environment
    for this process.
    """
    if not args.use_gpu:
        return fluid.CPUPlace()
    if not args.use_data_parallel:
        return fluid.CUDAPlace(0)
    return fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)


def _build_model():
    """Instantiate the network named by ``args.model``.

    Returns:
        tuple: ``(net, model_path_pre)``, where ``model_path_pre`` is the name
        prefix used for saved checkpoints.

    Exits the process with status 1 if ``args.model`` is not supported.
    """
    if args.model == "MobileNetV1":
        return MobileNetV1(class_dim=args.class_dim, scale=1.0), 'mobilenet_v1'
    if args.model == "MobileNetV2":
        return MobileNetV2(class_dim=args.class_dim, scale=1.0), 'mobilenet_v2'
    print("wrong model name, please try model = MobileNetV1 or MobileNetV2")
    # Non-zero status so callers/launchers see the failure. The bare exit()
    # used previously reported success.
    sys.exit(1)


def _label_to_variable(label):
    """Convert a label batch into an int64 dygraph variable shaped ``(N, 1)``.

    ``N`` is inferred from the data. Computing it as
    ``args.batch_size // get_cuda_device_count()`` breaks when the device
    count is 0 (CPU-only) or a batch has a different size.
    """
    return to_variable(label.numpy().astype('int64').reshape(-1, 1))


def _is_checkpoint_writer():
    """Return True if this process should write checkpoints.

    Always True without data parallelism; otherwise only for local rank 0.
    """
    return (not args.use_data_parallel or
            fluid.dygraph.parallel.Env().local_rank == 0)


def _checkpoint_path(model_path_pre, suffix):
    """Return ``<args.model_save_dir>/_<model_path_pre><suffix>``.

    Creates ``args.model_save_dir`` if needed. This lets the final save
    succeed even when no epoch ran and nothing else created the directory.
    """
    if not os.path.isdir(args.model_save_dir):
        os.makedirs(args.model_save_dir)
    return os.path.join(args.model_save_dir, "_" + model_path_pre + suffix)


def train_mobilenet():
    """Train the MobileNet variant selected by ``args.model`` in dygraph mode.

    The steps are:

    1. Pick the device, seed RNGs in CE mode, and build the model and
       optimizer.
    2. Optionally resume from ``args.checkpoint`` (``.pdparams`` +
       ``.pdopt``).
    3. For each of ``args.num_epochs`` epochs, train on the ImageNet reader,
       save a parameter + optimizer checkpoint (writer process only), and
       evaluate on the validation reader.
    4. Save the final parameters as ``_<model>_final``.

    Raises:
        IOError: if ``args.checkpoint`` is set but its files are missing.
    """
    place = _create_place()
    with fluid.dygraph.guard(place):
        # 1. init net and optimizer
        if args.ce:
            print("ce mode")
            seed = 33
            np.random.seed(seed)
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed
        if args.use_data_parallel:
            strategy = fluid.dygraph.parallel.prepare_context()

        net, model_path_pre = _build_model()
        optimizer = create_optimizer(args=args, parameter_list=net.parameters())
        if args.use_data_parallel:
            net = fluid.dygraph.parallel.DataParallel(net, strategy)

        # 2. load checkpoint
        if args.checkpoint:
            # Explicit check instead of assert, which is stripped under -O.
            for ext in (".pdparams", ".pdopt"):
                if not os.path.exists(args.checkpoint + ext):
                    raise IOError("Checkpoint file {}{} does not exist.".format(
                        args.checkpoint, ext))
            para_dict, opti_dict = fluid.dygraph.load_dygraph(args.checkpoint)
            net.set_dict(para_dict)
            optimizer.set_dict(opti_dict)

        # 3. reader
        train_data_loader, _ = utility.create_data_loader(
            is_train=True, args=args)
        test_data_loader, _ = utility.create_data_loader(
            is_train=False, args=args)
        num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
        imagenet_reader = reader.ImageNetReader(0)
        train_reader = imagenet_reader.train(settings=args)
        test_reader = imagenet_reader.val(settings=args)
        train_data_loader.set_sample_list_generator(train_reader, place)
        test_data_loader.set_sample_list_generator(test_reader, place)

        # 4. train loop
        for eop in range(args.num_epochs):
            if num_trainers > 1:
                imagenet_reader.set_shuffle_seed(eop + (args.random_seed or 0))
            net.train()
            total_loss = 0.0
            total_acc1 = 0.0
            total_acc5 = 0.0
            total_sample = 0
            batch_id = 0
            # Defined up front so the epoch summary works even with no batches.
            train_batch_elapse = 0.0
            # Start the read timer now. Starting from 0 made the first batch
            # report seconds-since-epoch as read_t.
            t_last = time.time()
            # 4.1 for each batch, call net(), backward(), and minimize()
            for img, label in train_data_loader():
                t1 = time.time()
                label = _label_to_variable(label)
                t_start = time.time()

                # 4.1.1 call net()
                out = net(img)
                t_end = time.time()
                softmax_out = fluid.layers.softmax(out, use_cudnn=False)
                loss = fluid.layers.cross_entropy(
                    input=softmax_out, label=label)
                avg_loss = fluid.layers.mean(x=loss)
                acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
                acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
                t_start_back = time.time()

                # 4.1.2 call backward()
                # Back-propagate through train_loss but log avg_loss.
                # net.scale_loss() is assumed to return a separate, rescaled
                # loss for gradient averaging (confirm against the Paddle
                # DataParallel docs). Logging that copy made multi-card
                # losses incomparable with single-card runs.
                if args.use_data_parallel:
                    train_loss = net.scale_loss(avg_loss)
                    train_loss.backward()
                    net.apply_collective_grads()
                else:
                    train_loss = avg_loss
                    train_loss.backward()
                t_end_back = time.time()

                # 4.1.3 call minimize()
                optimizer.minimize(train_loss)
                net.clear_gradients()
                t2 = time.time()
                train_batch_elapse = t2 - t1
                if batch_id % args.print_step == 0:
                    print("epoch id: %d, batch step: %d,  avg_loss %0.5f acc_top1 %0.5f acc_top5 %0.5f %2.4f sec net_t:%2.4f back_t:%2.4f read_t:%2.4f" %
                          (eop, batch_id, avg_loss.numpy(), acc_top1.numpy(),
                           acc_top5.numpy(), train_batch_elapse,
                           t_end - t_start, t_end_back - t_start_back,
                           t1 - t_last))
                    sys.stdout.flush()
                total_loss += avg_loss.numpy()
                total_acc1 += acc_top1.numpy()
                total_acc5 += acc_top5.numpy()
                total_sample += 1
                batch_id += 1
                t_last = time.time()

            # Avoid ZeroDivisionError when the loader yielded no batches.
            num_batches = max(total_sample, 1)
            if args.ce:
                print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / num_batches))
                print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / num_batches))
                print("kpis\ttrain_loss\t%0.3f" % (total_loss / num_batches))
            print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f %2.4f sec" %
                  (eop, batch_id, total_loss / num_batches,
                   total_acc1 / num_batches, total_acc5 / num_batches,
                   train_batch_elapse))

            # 4.2 save checkpoint
            if _is_checkpoint_writer():
                model_path = _checkpoint_path(model_path_pre,
                                              "_epoch{}".format(eop))
                fluid.dygraph.save_dygraph(net.state_dict(), model_path)
                fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)

            # 4.3 validation
            net.eval()
            eval(net, test_data_loader, eop)

        # 5. save final results
        if _is_checkpoint_writer():
            model_path = _checkpoint_path(model_path_pre, "_final")
            fluid.dygraph.save_dygraph(net.state_dict(), model_path)


if __name__ == "__main__":
    # Script entry point: run training with per-epoch evaluation.
    train_mobilenet()