train.py 9.1 KB
Newer Older
R
ruri 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

R
root 已提交
15 16 17
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
R
ruri 已提交
18

19 20 21
import os
import time
import sys
22

R
ruri 已提交
23
import numpy as np
24
import paddle
25
import paddle.fluid as fluid
26
from paddle.fluid import profiler
R
ruri 已提交
27 28
import reader
from utils import *
29
import models
R
ruri 已提交
30 31
from build_model import create_model

32

R
ruri 已提交
33
def build_program(is_train, main_prog, startup_prog, args):
    """Build the train or test program; in train mode also attach the optimizer.

    Parameters:
        is_train: True to build the training program (backward and optimizer
            ops are added), False to build the evaluation program.
        main_prog: the fluid.Program that receives the model ops.
        startup_prog: the fluid.Program that receives initialization ops.
        args: parsed arguments.

    Returns:
        The list of outputs produced by ``create_model`` with extra entries
        appended:
        train mode: [loss outputs..., global_lr, data_loader], or
            [loss outputs..., global_lr, ema, data_loader] when
            ``args.use_ema`` is set.
        test mode: [loss outputs..., data_loader]
    """
    if args.model.startswith('EfficientNet'):
        # EfficientNet variants take extra construction arguments.
        efficientnet_kwargs = dict(
            is_test=not is_train,
            override_params={"drop_connect_rate": args.drop_connect_rate},
            padding_type=args.padding_type,
            use_se=args.use_se)
        model = models.__dict__[args.model](**efficientnet_kwargs)
    else:
        model = models.__dict__[args.model]()

    with fluid.program_guard(main_prog, startup_prog):
        seed = args.random_seed
        if seed:
            main_prog.random_seed = seed
            startup_prog.random_seed = seed

        with fluid.unique_name.guard():
            data_loader, outputs = create_model(model, args, is_train)

            if is_train:
                # Add backward and optimizer ops for the first loss output.
                optimizer = create_optimizer(args)
                optimizer.minimize(outputs[0])

                # XXX: fetches the learning rate through the optimizer's
                # private accessor; a public API would be preferable.
                lr_var = optimizer._global_learning_rate()
                lr_var.persistable = True
                outputs.append(lr_var)

                if args.use_ema:
                    step_counter = (
                        fluid.layers.learning_rate_scheduler.
                        _decay_step_counter())
                    ema = ExponentialMovingAverage(
                        args.ema_decay, thres_steps=step_counter)
                    ema.update()
                    outputs.append(ema)

            # The data loader is always the last element.
            outputs.append(data_loader)

    return outputs
R
ruri 已提交
80

R
ruri 已提交
81

82
def validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
             train_batch_metrics_record):
    """Run one evaluation pass and print per-batch and per-epoch metrics.

    Parameters:
        args: parsed arguments; ``args.print_step`` is forwarded to
            ``print_info`` for the per-batch output.
        test_iter: iterable of feed batches passed to ``exe.run``.
        exe: the fluid.Executor used to run ``test_prog``.
        test_prog: the evaluation program.
        test_fetch_list: names of the variables fetched for every batch.
        pass_id: current epoch index (used for printing).
        train_batch_metrics_record: per-batch training metric averages
            collected during this epoch; their mean is printed together with
            the mean of the test metrics.

    If either the test pass or the training record produced no batches, the
    epoch summary is skipped (averaging an empty record would yield a nan
    scalar that cannot be turned into a list).
    """
    test_batch_time_record = []
    test_batch_metrics_record = []
    for test_batch_id, batch in enumerate(test_iter):
        t1 = time.time()
        test_batch_metrics = exe.run(program=test_prog,
                                     feed=batch,
                                     fetch_list=test_fetch_list)
        t2 = time.time()
        test_batch_elapse = t2 - t1
        test_batch_time_record.append(test_batch_elapse)

        # Average each fetched value along axis 1 (assumes every fetched
        # value is 1-D with the same length -- TODO confirm).
        test_batch_metrics_avg = np.mean(np.array(test_batch_metrics), axis=1)
        test_batch_metrics_record.append(test_batch_metrics_avg)

        print_info(pass_id, test_batch_id, args.print_step,
                   test_batch_metrics_avg, test_batch_elapse, "batch")
        sys.stdout.flush()

    if not test_batch_metrics_record or not train_batch_metrics_record:
        # Nothing to average: np.mean over an empty record returns a nan
        # scalar, and list() on it would raise TypeError.
        print("Pass {0}: no train or test batches recorded, "
              "skipping epoch summary.".format(pass_id))
        sys.stdout.flush()
        return

    train_epoch_metrics_avg = np.mean(
        np.array(train_batch_metrics_record), axis=0)

    test_epoch_time_avg = np.mean(np.array(test_batch_time_record))
    test_epoch_metrics_avg = np.mean(
        np.array(test_batch_metrics_record), axis=0)

    print_info(pass_id, 0, 0,
               list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
               test_epoch_time_avg, "epoch")
R
ruri 已提交
115

116

R
ruri 已提交
117
def train(args):
    """Train the model, validating and checkpointing as configured.

    Builds the train and test programs, runs the startup program, calls
    ``init_model`` (checkpoint / pretrained weights), sets up either the DALI
    pipeline or the ImageNet reader, and then runs ``args.num_epochs`` epochs.
    Returns early once ``args.max_iter`` batches have been run, or after the
    profiler window closes when ``args.is_profiler`` is set.

    Per epoch, trainer 0 validates when ``args.validate`` is set (first with
    the EMA weights applied when ``args.use_ema`` is set), and saves a model
    every ``args.save_step`` epochs.

    Args:
        args: all arguments.
    """
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    test_prog = fluid.Program()

    # build_program returns [fetch vars..., (ema,) data_loader].
    train_out = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args)
    train_data_loader = train_out[-1]
    if args.use_ema:
        train_fetch_vars = train_out[:-2]
        ema = train_out[-2]
    else:
        train_fetch_vars = train_out[:-1]

    train_fetch_list = [var.name for var in train_fetch_vars]

    test_out = build_program(
        is_train=False,
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args)
    test_data_loader = test_out[-1]
    test_fetch_vars = test_out[:-1]

    test_fetch_list = [var.name for var in test_fetch_vars]

    # Create test_prog and set layers' is_test params to True.
    test_prog = test_prog.clone(for_test=True)

    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))

    # Init model by checkpoint or pretrained model.
    init_model(exe, args, train_prog)
    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    if args.use_dali:
        import dali
        train_iter = dali.train(settings=args)
        if trainer_id == 0:
            test_iter = dali.val(settings=args)
    else:
        imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None)
        train_reader = imagenet_reader.train(settings=args)
        test_reader = imagenet_reader.val(settings=args)
        places = place
        if num_trainers <= 1 and args.use_gpu:
            places = fluid.framework.cuda_places()
        train_data_loader.set_sample_list_generator(train_reader, places)
        test_data_loader.set_sample_list_generator(test_reader, place)

    compiled_train_prog = best_strategy_compiled(args, train_prog,
                                                 train_fetch_vars[0], exe)
    # NOTE: this is for benchmark
    total_batch_num = 0
    for pass_id in range(args.num_epochs):
        if num_trainers > 1 and not args.use_dali:
            imagenet_reader.set_shuffle_seed(pass_id + (
                args.random_seed if args.random_seed else 0))
        train_batch_id = 0
        train_batch_time_record = []
        train_batch_metrics_record = []

        if not args.use_dali:
            train_iter = train_data_loader()
            test_iter = test_data_loader()

        t1 = time.time()
        for batch in train_iter:
            # NOTE: this is for benchmark
            if args.max_iter and total_batch_num == args.max_iter:
                return
            train_batch_metrics = exe.run(compiled_train_prog,
                                          feed=batch,
                                          fetch_list=train_fetch_list)
            t2 = time.time()
            train_batch_elapse = t2 - t1
            train_batch_time_record.append(train_batch_elapse)
            train_batch_metrics_avg = np.mean(
                np.array(train_batch_metrics), axis=1)
            train_batch_metrics_record.append(train_batch_metrics_avg)
            if trainer_id == 0:
                print_info(pass_id, train_batch_id, args.print_step,
                           train_batch_metrics_avg, train_batch_elapse, "batch")
                sys.stdout.flush()
            train_batch_id += 1
            t1 = time.time()
            # NOTE: this is for benchmark profiler
            total_batch_num += 1
            if args.is_profiler and pass_id == 0 and train_batch_id == args.print_step:
                profiler.start_profiler("All")
            elif args.is_profiler and pass_id == 0 and train_batch_id == args.print_step + 5:
                profiler.stop_profiler("total", args.profiler_path)
                return

        if args.use_dali:
            train_iter.reset()

        if trainer_id == 0 and args.validate:
            if args.use_ema:
                print('ExponentialMovingAverage validate start...')
                with ema.apply(exe):
                    validate(args, test_iter, exe, test_prog, test_fetch_list,
                             pass_id, train_batch_metrics_record)
                print('ExponentialMovingAverage validate over!')
                if args.use_dali:
                    # The EMA pass consumed the DALI iterator; rewind it so
                    # the regular validation below sees the full val set.
                    test_iter.reset()

            validate(args, test_iter, exe, test_prog, test_fetch_list, pass_id,
                     train_batch_metrics_record)

            if args.use_dali:
                test_iter.reset()

        # For now, save model per epoch. Only trainer 0 saves; saving no
        # longer depends on args.validate being set.
        if trainer_id == 0 and pass_id % args.save_step == 0:
            save_model(args, exe, train_prog, pass_id)
242

R
ruri 已提交
243

244
def main():
    """Entry point: parse, print (on trainer 0), check the arguments, then train."""
    args = parse_args()
    # Only the first trainer prints the arguments to avoid duplicated logs.
    is_first_trainer = int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0
    if is_first_trainer:
        print_arguments(args)
    check_args(args)
    train(args)


if __name__ == '__main__':
    main()