train.py 9.2 KB
Newer Older
R
ruri 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

R
root 已提交
15 16 17
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
R
ruri 已提交
18

19 20 21 22
import os
import numpy as np
import time
import sys
23

24

25 26 27 28 29 30 31 32 33 34 35 36 37
def set_paddle_flags(flags):
    """Export each flag to the process environment unless it is already set.

    Args:
        flags: mapping of environment variable name to value; values are
            stored as ``str(value)``.
    """
    for name in flags:
        # setdefault keeps whatever the user exported before launching.
        os.environ.setdefault(name, str(flags[name]))


# NOTE(paddle-dev): All of these flags must be set before `import paddle`;
# otherwise they have no effect. Flags that are already present in the
# environment are left untouched by set_paddle_flags().
set_paddle_flags({
    'FLAGS_eager_delete_tensor_gb': 0,  # enable gc
    'FLAGS_fraction_of_gpu_memory_to_use': 0.98
})
R
ruri 已提交
38

39
import paddle
40
import paddle.fluid as fluid
41
from paddle.fluid import profiler
R
ruri 已提交
42 43
import reader
from utils import *
44
import models
R
ruri 已提交
45 46
from build_model import create_model

47

R
ruri 已提交
48
def build_program(is_train, main_prog, startup_prog, args):
    """Build the network in ``main_prog`` and, in train mode, add the
    optimizer (backward) ops.

    Args:
        is_train: True builds the training graph, False the test graph.
        main_prog: main program
        startup_prog: startup program
        args: arguments

    Returns:
        The outputs returned by ``create_model`` with extra entries appended:
        train mode: [Loss..., global_lr, (ema when args.use_ema), data_loader]
        test mode: [Loss..., data_loader]
    """
    if args.model.startswith('EfficientNet'):
        # EfficientNet constructors take extra build-time options.
        efficientnet_kwargs = dict(
            is_test=not is_train,
            override_params={"drop_connect_rate": args.drop_connect_rate},
            padding_type=args.padding_type,
            use_se=args.use_se)
        model = models.__dict__[args.model](**efficientnet_kwargs)
    else:
        model = models.__dict__[args.model]()

    with fluid.program_guard(main_prog, startup_prog):
        seed = args.random_seed
        if seed:
            main_prog.random_seed = seed
            startup_prog.random_seed = seed

        with fluid.unique_name.guard():
            data_loader, outputs = create_model(model, args, is_train)

            if is_train:
                # Add the backward ops; the first output is the cost that
                # gets minimized.
                optimizer = create_optimizer(args)
                optimizer.minimize(outputs[0])

                # XXX: the learning rate is fetched through a private
                # optimizer method; a better implementation is required.
                lr_var = optimizer._global_learning_rate()
                lr_var.persistable = True
                outputs.append(lr_var)

                if args.use_ema:
                    step_counter = (fluid.layers.learning_rate_scheduler
                                    ._decay_step_counter())
                    ema = ExponentialMovingAverage(
                        args.ema_decay, thres_steps=step_counter)
                    ema.update()
                    outputs.append(ema)

            # The data loader is always the last entry.
            outputs.append(data_loader)

    return outputs
R
ruri 已提交
96

97 98 99

def validate(args, test_data_loader, exe, test_prog, test_fetch_list, pass_id,
             train_batch_metrics_record):
    """Run the test program over the whole test loader and print a summary.

    Args:
        args: arguments; only ``args.print_step`` is read here.
        test_data_loader: loader that is started before the loop and reset
            once ``fluid.core.EOFException`` is raised.
        exe: executor used to run ``test_prog``.
        test_prog: program to execute for each test batch.
        test_fetch_list: fetch list passed to ``exe.run``.
        pass_id: epoch index, forwarded to ``print_info``.
        train_batch_metrics_record: per-batch training metrics of the current
            epoch; averaged along axis 0 and printed next to the test metrics.
    """
    elapsed_per_batch = []
    metrics_per_batch = []
    batch_idx = 0

    test_data_loader.start()
    try:
        while True:
            started = time.time()
            fetched = exe.run(program=test_prog, fetch_list=test_fetch_list)
            elapsed = time.time() - started
            elapsed_per_batch.append(elapsed)

            # Average every fetched value along axis 1.
            fetched_avg = np.mean(np.array(fetched), axis=1)
            metrics_per_batch.append(fetched_avg)

            print_info(pass_id, batch_idx, args.print_step, fetched_avg,
                       elapsed, "batch")
            sys.stdout.flush()
            batch_idx += 1
    except fluid.core.EOFException:
        # The loader is exhausted; reset it so it can be started again.
        test_data_loader.reset()

    train_epoch_avg = np.mean(np.array(train_batch_metrics_record), axis=0)
    test_epoch_time = np.mean(np.array(elapsed_per_batch))
    test_epoch_avg = np.mean(np.array(metrics_per_batch), axis=0)

    epoch_metrics = list(train_epoch_avg) + list(test_epoch_avg)
    print_info(pass_id, 0, 0, epoch_metrics, test_epoch_time, "epoch")
R
ruri 已提交
135

136

R
ruri 已提交
137
def train(args):
    """Build the train and test programs, then run the epoch loop.

    Returns early (before finishing all epochs) when ``args.max_iter``
    batches have been run or when the profiler window has been recorded.

    Args:
        args: all arguments.
    """
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    test_prog = fluid.Program()

    train_out = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args)
    # build_program puts the data loader last and, with EMA enabled, the EMA
    # object right before it; everything ahead of those is fetched.
    train_data_loader = train_out[-1]
    trailing = 2 if args.use_ema else 1
    train_fetch_vars = train_out[:-trailing]
    if args.use_ema:
        ema = train_out[-2]
    train_fetch_list = [var.name for var in train_fetch_vars]

    test_out = build_program(
        is_train=False,
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args)
    test_data_loader = test_out[-1]
    test_fetch_list = [var.name for var in test_out[:-1]]

    # Clone test_prog so that the layers' is_test params are set to True.
    test_prog = test_prog.clone(for_test=True)

    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    # Init the model from a checkpoint or a pretrained model.
    init_model(exe, args, train_prog)

    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None)
    train_reader = imagenet_reader.train(settings=args)
    test_reader = imagenet_reader.val(settings=args)

    train_data_loader.set_sample_list_generator(train_reader, place)
    test_data_loader.set_sample_list_generator(test_reader, place)

    compiled_train_prog = best_strategy_compiled(args, train_prog,
                                                 train_fetch_vars[0], exe)
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
    total_batch_num = 0  # counts batches across epochs; used for benchmark

    for pass_id in range(args.num_epochs):
        if num_trainers > 1:
            # Reseed the shuffle every epoch when running multiple trainers.
            imagenet_reader.set_shuffle_seed(pass_id + (args.random_seed or
                                                        0))

        train_batch_id = 0
        train_batch_time_record = []
        train_batch_metrics_record = []

        train_data_loader.start()
        try:
            while True:
                if args.max_iter and total_batch_num == args.max_iter:
                    return

                started = time.time()
                train_batch_metrics = exe.run(compiled_train_prog,
                                              fetch_list=train_fetch_list)
                train_batch_elapse = time.time() - started
                train_batch_time_record.append(train_batch_elapse)

                train_batch_metrics_avg = np.mean(
                    np.array(train_batch_metrics), axis=1)
                train_batch_metrics_record.append(train_batch_metrics_avg)

                # Only trainer 0 prints per-batch information.
                if trainer_id == 0:
                    print_info(pass_id, train_batch_id, args.print_step,
                               train_batch_metrics_avg, train_batch_elapse,
                               "batch")
                    sys.stdout.flush()

                train_batch_id += 1
                total_batch_num += 1  # this is for benchmark

                # Profiler tools: record batches 100..150 of the first epoch,
                # then stop training.
                if args.is_profiler and pass_id == 0:
                    if train_batch_id == 100:
                        profiler.start_profiler("All")
                    elif train_batch_id == 150:
                        profiler.stop_profiler("total", args.profiler_path)
                        return

        except fluid.core.EOFException:
            # End of the epoch's data; reset the loader for the next epoch.
            train_data_loader.reset()

        if trainer_id == 0 and args.validate:
            if args.use_ema:
                print('ExponentialMovingAverage validate start...')
                with ema.apply(exe):
                    validate(args, test_data_loader, exe, test_prog,
                             test_fetch_list, pass_id,
                             train_batch_metrics_record)
                print('ExponentialMovingAverage validate over!')

            validate(args, test_data_loader, exe, test_prog, test_fetch_list,
                     pass_id, train_batch_metrics_record)

            # For now, save the model per epoch (every args.save_step epochs).
            # NOTE(review): saving sits inside the validation gate, so nothing
            # is saved when args.validate is off -- confirm this is intended.
            if pass_id % args.save_step == 0:
                save_model(args, exe, train_prog, pass_id)
247

248

249
def main():
    """Parse the arguments, check them and start training.

    The parsed arguments are printed only when PADDLE_TRAINER_ID is 0
    (its default when unset).
    """
    args = parse_args()
    is_first_trainer = int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0
    if is_first_trainer:
        print_arguments(args)
    check_args(args)
    train(args)
255

256 257 258

# Run training only when this file is executed as a script.
if __name__ == '__main__':
    main()