train.py 10.6 KB
Newer Older
R
ruri 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

R
root 已提交
15 16 17
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
R
ruri 已提交
18

19 20 21
import os
import time
import sys
22
import logging
23

R
ruri 已提交
24
import numpy as np
25
import paddle
26
import paddle.fluid as fluid
27
from paddle.fluid import profiler
R
ruri 已提交
28 29
import reader
from utils import *
30
import models
R
ruri 已提交
31 32
from build_model import create_model

33 34 35
# Make INFO-level messages visible, then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

36

R
ruri 已提交
37
def build_program(is_train, main_prog, startup_prog, args):
    """Build the network into ``main_prog``; in train mode also add backward/optimizer ops.

    Parameters:
        is_train: True to build the training graph, False for the evaluation graph
            (EfficientNet models additionally receive ``is_test=not is_train``).
        main_prog: program that receives the network ops.
        startup_prog: program that receives the initialization ops.
        args: parsed command-line arguments.

    Returns:
        A list.
        Train mode: the outputs of ``create_model``, then the persistable global
        learning-rate variable, then the ExponentialMovingAverage object (only when
        ``args.use_ema`` is set), and finally the data loader.
        Test mode: the outputs of ``create_model`` followed by the data loader.
    """
    # EfficientNet variants take extra constructor options; other models take none.
    if args.model.startswith('EfficientNet'):
        model_kwargs = {
            "is_test": not is_train,
            "override_params": {"drop_connect_rate": args.drop_connect_rate},
            "padding_type": args.padding_type,
            "use_se": args.use_se,
        }
    else:
        model_kwargs = {}
    model = models.__dict__[args.model](**model_kwargs)

    with fluid.program_guard(main_prog, startup_prog):
        if args.random_seed or args.enable_ce:
            main_prog.random_seed = args.random_seed
            startup_prog.random_seed = args.random_seed

        with fluid.unique_name.guard():
            data_loader, outputs = create_model(model, args, is_train)

            # Only the training graph gets backward / optimizer ops.
            if is_train:
                optimizer = create_optimizer(args)
                loss = outputs[0]
                # XXX: reads the learning rate through a private optimizer API so it
                # can be fetched; a public accessor would be preferable.
                lr_var = optimizer._global_learning_rate()
                lr_var.persistable = True
                outputs.append(lr_var)

                if args.use_fp16:
                    optimizer = fluid.contrib.mixed_precision.decorate(
                        optimizer,
                        init_loss_scaling=args.scale_loss,
                        use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)

                optimizer.minimize(loss)

                if args.use_ema:
                    step_counter = fluid.layers.learning_rate_scheduler._decay_step_counter(
                    )
                    ema = ExponentialMovingAverage(
                        args.ema_decay, thres_steps=step_counter)
                    ema.update()
                    outputs.append(ema)

            outputs.append(data_loader)
    return outputs
R
ruri 已提交
91

R
ruri 已提交
92

R
ruri 已提交
93 94 95 96 97 98 99
def validate(args,
             test_iter,
             exe,
             test_prog,
             test_fetch_list,
             pass_id,
             train_batch_metrics_record,
             train_batch_time_record=None,
             train_prog=None):
    """Run ``test_prog`` over every batch of ``test_iter`` and print metrics.

    Per-batch results are printed with ``print_info("batch", ...)``. At the end,
    the epoch average of the training metrics and of the test metrics are
    printed together with ``print_info("epoch", ...)``. When ``args.enable_ce``
    is set, a "ce" report including ``train_batch_time_record`` is printed too.

    Args:
        args: parsed command-line arguments.
        test_iter: iterable of feed batches.
        exe: executor used to run the program.
        test_prog: evaluation program.
        test_fetch_list: names of the variables to fetch; the first one is
            passed to ``best_strategy_compiled`` as the loss.
        pass_id: current epoch index (used for printing).
        train_batch_metrics_record: per-batch training metric averages of this epoch.
        train_batch_time_record: per-batch training times (only used for the CE report).
        train_prog: program passed as ``share_prog`` when compiling ``test_prog``.
    """
    # In multi-trainer runs the raw program is used; otherwise compile it.
    if int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) <= 1:
        program = best_strategy_compiled(
            args,
            test_prog,
            test_fetch_list[0],
            exe,
            mode="val",
            share_prog=train_prog)
    else:
        program = test_prog

    elapsed_per_batch = []
    metrics_per_batch = []
    for batch_idx, feed in enumerate(test_iter):
        start = time.time()
        fetched = exe.run(program=program,
                          feed=feed,
                          fetch_list=test_fetch_list)
        elapsed = time.time() - start
        elapsed_per_batch.append(elapsed)

        batch_avg = np.mean(np.array(fetched), axis=1)
        metrics_per_batch.append(batch_avg)

        print_info("batch", batch_avg, elapsed, pass_id, batch_idx,
                   args.print_step, args.class_dim)
        sys.stdout.flush()

    train_avg = np.mean(np.array(train_batch_metrics_record), axis=0)
    test_time_avg = np.mean(np.array(elapsed_per_batch))
    test_avg = np.mean(np.array(metrics_per_batch), axis=0)

    print_info(
        "epoch",
        list(train_avg) + list(test_avg),
        test_time_avg,
        pass_id=pass_id,
        class_dim=args.class_dim)

    if args.enable_ce:
        device_num = 1
        if args.use_gpu:
            device_num = fluid.core.get_cuda_device_count()
        print_info(
            "ce",
            list(train_avg) + list(test_avg),
            train_batch_time_record,
            device_num=device_num)
R
ruri 已提交
152

153

R
ruri 已提交
154
def _select_places(use_gpu, place, num_trainers):
    """Return the place(s) a DataLoader should feed.

    Single-trainer runs feed all CUDA (or CPU) places returned by
    ``fluid.framework.cuda_places()`` / ``cpu_places()``; multi-trainer runs
    feed only this process's ``place``.
    """
    if num_trainers > 1:
        return place
    if use_gpu:
        return fluid.framework.cuda_places()
    return fluid.framework.cpu_places()


def train(args):
    """Train the model selected by ``args``, optionally validating after each epoch.

    Args:
        args: all parsed command-line arguments.
    """
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    train_out = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args)
    # build_program returns [fetch vars..., (ema if use_ema), data_loader].
    train_data_loader = train_out[-1]
    if args.use_ema:
        train_fetch_vars = train_out[:-2]
        ema = train_out[-2]
    else:
        train_fetch_vars = train_out[:-1]

    train_fetch_list = [var.name for var in train_fetch_vars]

    if args.validate:
        test_prog = fluid.Program()
        test_out = build_program(
            is_train=False,
            main_prog=test_prog,
            startup_prog=startup_prog,
            args=args)
        test_data_loader = test_out[-1]
        test_fetch_vars = test_out[:-1]

        test_fetch_list = [var.name for var in test_fetch_vars]

        # Create test_prog and set layers' is_test params to True.
        test_prog = test_prog.clone(for_test=True)

    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))

    # Init model from a checkpoint or a pretrained model.
    init_model(exe, args, train_prog)
    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    if args.use_dali:
        import dali
        train_iter = dali.train(settings=args)
        if trainer_id == 0:
            test_iter = dali.val(settings=args)
    else:
        imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None)
        train_reader = imagenet_reader.train(settings=args)
        places = _select_places(args.use_gpu, place, num_trainers)

        train_data_loader.set_sample_list_generator(train_reader, places)

        if args.validate:
            test_reader = imagenet_reader.val(settings=args)
            test_data_loader.set_sample_list_generator(test_reader, places)

    compiled_train_prog = best_strategy_compiled(args, train_prog,
                                                 train_fetch_vars[0], exe)
    # NOTE: this is for benchmark
    total_batch_num = 0
    for pass_id in range(args.num_epochs):
        if num_trainers > 1 and not args.use_dali:
            imagenet_reader.set_shuffle_seed(pass_id + (
                args.random_seed if args.random_seed else 0))
        train_batch_id = 0
        train_batch_time_record = []
        train_batch_metrics_record = []

        if not args.use_dali:
            train_iter = train_data_loader()
            if args.validate:
                test_iter = test_data_loader()

        t1 = time.time()
        for batch in train_iter:
            # NOTE: this is for benchmark
            if args.max_iter and total_batch_num == args.max_iter:
                return
            train_batch_metrics = exe.run(compiled_train_prog,
                                          feed=batch,
                                          fetch_list=train_fetch_list)
            t2 = time.time()
            train_batch_elapse = t2 - t1
            train_batch_time_record.append(train_batch_elapse)

            train_batch_metrics_avg = np.mean(
                np.array(train_batch_metrics), axis=1)
            train_batch_metrics_record.append(train_batch_metrics_avg)
            if trainer_id == 0:
                print_info("batch", train_batch_metrics_avg, train_batch_elapse,
                           pass_id, train_batch_id, args.print_step)
                sys.stdout.flush()
            train_batch_id += 1
            t1 = time.time()
            # NOTE: this is for benchmark profiler
            total_batch_num = total_batch_num + 1
            if args.is_profiler and pass_id == 0 and train_batch_id == args.print_step:
                profiler.start_profiler("All")
            elif args.is_profiler and pass_id == 0 and train_batch_id == args.print_step + 5:
                profiler.stop_profiler("total", args.profiler_path)
                return

        if args.use_dali:
            train_iter.reset()

        if trainer_id == 0 and args.validate:
            if args.use_ema:
                logger.info('ExponentialMovingAverage validate start...')
                with ema.apply(exe):
                    # Previously compiled_train_prog was passed positionally and
                    # landed in train_batch_time_record; pass the real record.
                    # NOTE(review): train_prog stays None here, which is what the
                    # old call effectively did. Sharing the training program's
                    # variables may bypass the EMA-swapped weights on non-primary
                    # devices; confirm before passing compiled_train_prog.
                    validate(args, test_iter, exe, test_prog, test_fetch_list,
                             pass_id, train_batch_metrics_record,
                             train_batch_time_record=train_batch_time_record)
                logger.info('ExponentialMovingAverage validate over!')
                if args.use_dali:
                    # The EMA pass consumed the DALI iterator; rewind it so the
                    # regular validation below sees the data again.
                    test_iter.reset()

            validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
                     train_batch_metrics_record, train_batch_time_record,
                     compiled_train_prog)

            if args.use_dali:
                test_iter.reset()

        if pass_id % args.save_step == 0:
            save_model(args, exe, train_prog, pass_id)

R
ruri 已提交
294

295
def main():
    """Entry point: parse and check the arguments, then start training.

    The parsed arguments are printed only by trainer 0.
    """
    args = parse_args()
    is_first_trainer = int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0
    if is_first_trainer:
        print_arguments(args)
    check_args(args)
    train(args)


if __name__ == '__main__':
    main()