#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import sys
import logging

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import reader
from utils import *
import models
from build_model import create_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TimeAverager(object):
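    """Track the running average of recorded durations, in seconds.

    A minimal usage sketch (values are illustrative):

        averager = TimeAverager()
        averager.record(0.05)         # one batch took 0.05s
        averager.record(0.07)
        avg = averager.get_average()  # -> 0.06
        averager.reset()              # start a new averaging window
    """
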
    def __init__(self):
        self.reset()

    def reset(self):
        self._cnt = 0
        self._total_time = 0

    def record(self, usetime):
        self._cnt += 1
        self._total_time += usetime

    def get_average(self):
        if self._cnt == 0:
            return 0
        return self._total_time / self._cnt


def build_program(is_train, main_prog, startup_prog, args):
    """build program, and add backward op in program accroding to different mode

    Parameters:
        is_train: whether to build the program in train mode or test mode
        main_prog: main program
        startup_prog: startup program
        args: arguments

    Returns:
        train mode: [Loss, global_lr, data_loader]
        test mode: [Loss, data_loader]
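
    A hypothetical call, mirroring how train() below uses it:

        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        train_out, optimizer = build_program(
            is_train=True,
            main_prog=train_prog,
            startup_prog=startup_prog,
            args=args)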
    """
    if args.model.startswith('EfficientNet'):
        override_params = {"drop_connect_rate": args.drop_connect_rate}
        padding_type = args.padding_type
        use_se = args.use_se
        model = models.__dict__[args.model](is_test=not is_train,
                                            override_params=override_params,
                                            padding_type=padding_type,
                                            use_se=use_se)
    else:
        model = models.__dict__[args.model]()
    optimizer = None
    with fluid.program_guard(main_prog, startup_prog):
        if args.random_seed or args.enable_ce:
            main_prog.random_seed = args.random_seed
            startup_prog.random_seed = args.random_seed
        with fluid.unique_name.guard():
            data_loader, loss_out = create_model(model, args, is_train)
            # add backward ops to the program in train mode
            if is_train:
                optimizer = create_optimizer(args)
                avg_cost = loss_out[0]
                #XXX: fetch the learning rate here; a better implementation is needed.
                global_lr = optimizer._global_learning_rate()
                global_lr.persistable = True
                loss_out.append(global_lr)

                if args.use_amp:
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        init_loss_scaling=args.scale_loss,
                        use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
                        use_pure_fp16=args.use_pure_fp16,
                        use_fp16_guard=True)

                optimizer.minimize(avg_cost)
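
                # Optionally keep an exponential moving average (EMA) of the
                # weights; validate() applies the averaged weights during eval.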
                if args.use_ema:
                    global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter(
                    )
                    ema = ExponentialMovingAverage(
                        args.ema_decay, thres_steps=global_steps)
                    ema.update()
                    loss_out.append(ema)
            loss_out.append(data_loader)
    return loss_out, optimizer


def validate(args,
             test_iter,
             exe,
             test_prog,
             test_fetch_list,
             pass_id,
             train_batch_metrics_record,
             train_batch_time_record=None,
             train_prog=None):
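    """Run one evaluation pass over test_iter and print metrics.

    Per-batch metrics are printed during the pass; afterwards the epoch
    summary combines the recorded training metrics with the test metrics,
    and a CE line is printed when args.enable_ce is set.
    """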
    test_batch_time_record = []
    test_batch_metrics_record = []
    test_batch_id = 0

    # With multiple trainers, run the distributed test_prog as-is; otherwise
    # compile it with the best single-process strategy.
    if int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) > 1:
        compiled_program = test_prog
    else:
        compiled_program = best_strategy_compiled(
            args,
            test_prog,
            test_fetch_list[0],
            exe,
            mode="val",
            share_prog=train_prog)
    for batch in test_iter:
        t1 = time.time()
        test_batch_metrics = exe.run(program=compiled_program,
                                     feed=batch,
                                     fetch_list=test_fetch_list)
        t2 = time.time()
        test_batch_elapse = t2 - t1
        test_batch_time_record.append(test_batch_elapse)

        test_batch_metrics_avg = np.mean(np.array(test_batch_metrics), axis=1)
        test_batch_metrics_record.append(test_batch_metrics_avg)

        print_info("batch", test_batch_metrics_avg, test_batch_elapse, pass_id,
                   test_batch_id, args.print_step, args.class_dim)
        sys.stdout.flush()
        test_batch_id += 1

    train_epoch_metrics_avg = np.mean(
        np.array(train_batch_metrics_record), axis=0)

    test_epoch_time_avg = np.mean(np.array(test_batch_time_record))
    test_epoch_metrics_avg = np.mean(
        np.array(test_batch_metrics_record), axis=0)

    print_info(
        "epoch",
        list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
        test_epoch_time_avg,
        pass_id=pass_id,
        class_dim=args.class_dim)
    if args.enable_ce:
        device_num = fluid.core.get_cuda_device_count() if args.use_gpu else 1
        print_info(
            "ce",
            list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
            train_batch_time_record,
            device_num=device_num)


def train(args):
    """Train model

    Args:
        args: all arguments.
    """
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    train_out, optimizer = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args)
    train_data_loader = train_out[-1]
    if args.use_ema:
        train_fetch_vars = train_out[:-2]
        ema = train_out[-2]
    else:
        train_fetch_vars = train_out[:-1]

    train_fetch_list = [var.name for var in train_fetch_vars]

    if args.validate:
        test_prog = fluid.Program()
        test_out, _ = build_program(
            is_train=False,
            main_prog=test_prog,
            startup_prog=startup_prog,
            args=args)
        test_data_loader = test_out[-1]
        test_fetch_vars = test_out[:-1]

        test_fetch_list = [var.name for var in test_fetch_vars]

        #Clone test_prog for inference and set layers' is_test attribute to True
        test_prog = test_prog.clone(for_test=True)

    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))

    #Initialize the model from a checkpoint or a pretrained model.
    init_model(exe, args, train_prog)

    if args.use_amp:
        # Initialize AMP parameter casting after the weights are loaded.
        optimizer.amp_init(
            place,
            scope=paddle.static.global_scope(),
            test_program=test_prog if args.validate else None)

    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
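    # Build the input pipeline: NVIDIA DALI when requested, otherwise the
    # Python ImageNet reader feeding the DataLoader.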
    if args.use_dali:
        import dali
        train_iter = dali.train(settings=args)
        if trainer_id == 0:
            test_iter = dali.val(settings=args)
    else:
        imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None)
        train_reader = imagenet_reader.train(settings=args)
        # With a single trainer, feed all visible devices; with multiple
        # trainers, each process drives only its own assigned place.
        if args.use_gpu:
            if num_trainers <= 1:
                places = fluid.framework.cuda_places()
            else:
                places = place
        else:
            if num_trainers <= 1:
                places = fluid.framework.cpu_places()
            else:
                places = place

        train_data_loader.set_sample_list_generator(train_reader, places)

        if args.validate:
            test_reader = imagenet_reader.val(settings=args)
            test_data_loader.set_sample_list_generator(test_reader, places)

    compiled_train_prog = best_strategy_compiled(args, train_prog,
                                                 train_fetch_vars[0], exe)

    #NOTE: this is for benchmarking
    total_batch_num = 0
    batch_cost_averager = TimeAverager()
    reader_cost_averager = TimeAverager()
    for pass_id in range(args.num_epochs):
        # Reseed shuffling each epoch so data shards stay aligned across trainers.
        if num_trainers > 1 and not args.use_dali:
            imagenet_reader.set_shuffle_seed(pass_id + (
                args.random_seed if args.random_seed else 0))

        train_batch_id = 0
        train_batch_time_record = []
        train_batch_metrics_record = []

        if not args.use_dali:
            train_iter = train_data_loader()
            if args.validate:
                test_iter = test_data_loader()

        batch_start = time.time()
        for batch in train_iter:
            #NOTE: this is for benchmarking
            if args.max_iter and total_batch_num == args.max_iter:
                return

            reader_cost_averager.record(time.time() - batch_start)

            train_batch_metrics = exe.run(compiled_train_prog,
                                          feed=batch,
                                          fetch_list=train_fetch_list)

            train_batch_metrics_avg = np.mean(
                np.array(train_batch_metrics), axis=1)
            train_batch_metrics_record.append(train_batch_metrics_avg)

            # Record the time for ce and benchmark
            train_batch_elapse = time.time() - batch_start
            train_batch_time_record.append(train_batch_elapse)
            batch_cost_averager.record(train_batch_elapse)

            if trainer_id == 0:
                # Throughput over the current logging window, in images/sec.
                ips = float(args.batch_size) / batch_cost_averager.get_average()
                print_info(
                    "batch",
                    train_batch_metrics_avg,
                    batch_cost_averager.get_average(),
                    pass_id,
                    train_batch_id,
                    args.print_step,
                    reader_cost=reader_cost_averager.get_average(),
                    ips=ips)
                sys.stdout.flush()
                if train_batch_id % args.print_step == 0:
                    batch_cost_averager.reset()
                    reader_cost_averager.reset()

            train_batch_id += 1
            total_batch_num += 1
            batch_start = time.time()

            #NOTE: this is for the benchmark profiler
            if args.is_profiler and pass_id == 0 and train_batch_id == args.print_step:
                profiler.start_profiler("All")
            elif args.is_profiler and pass_id == 0 and train_batch_id == args.print_step + 5:
                profiler.stop_profiler("total", args.profiler_path)
                return

        if args.use_dali:
            train_iter.reset()

        if trainer_id == 0 and args.validate:
            if args.use_ema:
                logger.info('ExponentialMovingAverage validate start...')
                with ema.apply(exe):
                    validate(args, test_iter, exe, test_prog, test_fetch_list,
                             pass_id, train_batch_metrics_record,
                             compiled_train_prog)
                logger.info('ExponentialMovingAverage validate over!')

            validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
                     train_batch_metrics_record, train_batch_time_record,
                     compiled_train_prog)

            if args.use_dali:
                test_iter.reset()

        if trainer_id == 0 and pass_id % args.save_step == 0:
            save_model(args, exe, train_prog, pass_id)


def main():
    args = parse_args()
    if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
        print_arguments(args)
    check_args(args)
    train(args)


if __name__ == '__main__':
    paddle.enable_static()
    main()