train.py 7.0 KB
Newer Older
R
ruri 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

R
root 已提交
15 16 17
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
R
ruri 已提交
18

19 20 21 22
import os
import numpy as np
import time
import sys
R
root 已提交
23 24
import functools
import math
25

26

27 28 29 30 31 32 33 34 35 36 37 38 39
def set_paddle_flags(flags):
    """Export flags as environment variables without overriding existing ones.

    Args:
        flags: mapping of environment variable name to value. Each value is
            converted with ``str()`` before being stored.
    """
    for name in flags:
        # A variable already present in the environment takes precedence.
        if name in os.environ:
            continue
        os.environ[name] = str(flags[name])


# NOTE(paddle-dev): these flags must be exported before `import paddle`;
# setting them afterwards has no effect.
set_paddle_flags(
    dict(
        FLAGS_eager_delete_tensor_gb=0,  # enable gc
        FLAGS_fraction_of_gpu_memory_to_use=0.98,
    ))
R
ruri 已提交
40

R
ruri 已提交
41 42 43
import argparse
import functools
import subprocess
R
ruri 已提交
44

45
import paddle
46
import paddle.fluid as fluid
R
ruri 已提交
47 48
import reader
from utils import *
49
import models
R
ruri 已提交
50 51
from build_model import create_model

R
ruri 已提交
52 53

def build_program(is_train, main_prog, startup_prog, args):
    """Build the program and, in train mode, add the backward/optimizer ops.

    Args:
        is_train: True to build the training program, False for testing.
        main_prog: main program the model ops are added to.
        startup_prog: startup program the initializers are added to.
        args: parsed arguments; ``args.model`` names the entry looked up in
            ``models`` and ``args.random_seed`` (when truthy) seeds both
            programs.

    Returns:
        The list produced by ``create_model`` with extra items appended:
            train mode: [..., global_lr, py_reader]
            test mode:  [..., py_reader]
        The first element is used as the loss to minimize in train mode.
    """
    model = models.__dict__[args.model]()
    with fluid.program_guard(main_prog, startup_prog):
        if args.random_seed:
            main_prog.random_seed = args.random_seed
            startup_prog.random_seed = args.random_seed
        with fluid.unique_name.guard():
            py_reader, loss_out = create_model(model, args, is_train)
            # Add the backward ops to the program.
            if is_train:
                optimizer = create_optimizer(args)
                avg_cost = loss_out[0]
                optimizer.minimize(avg_cost)
                # XXX: the learning rate is fetched through a private
                # optimizer method; a better implementation is needed here.
                global_lr = optimizer._global_learning_rate()
                global_lr.persistable = True
                loss_out.append(global_lr)
            # The reader is always the last element so callers can split it
            # off with [-1] / [:-1].
            loss_out.append(py_reader)
    return loss_out
R
ruri 已提交
84 85 86


def _run_one_pass(exe, program, fetch_list, py_reader, pass_id, print_step):
    """Run ``program`` on every batch of ``py_reader`` and collect metrics.

    Args:
        exe: executor used to run the program.
        program: program (plain or compiled) to execute for each batch.
        fetch_list: names of the variables to fetch after each batch.
        py_reader: reader feeding the program; started here and reset once
            it is exhausted.
        pass_id: index of the current epoch, forwarded to ``print_info``.
        print_step: print interval, forwarded to ``print_info``.

    Returns:
        A list with one entry per batch: the fetched metrics averaged over
        axis 1.
    """
    metrics_record = []
    batch_id = 0

    py_reader.start()
    try:
        while True:
            started = time.time()
            fetched = exe.run(program, fetch_list=fetch_list)
            elapsed = time.time() - started

            metrics_avg = np.mean(np.array(fetched), axis=1)
            metrics_record.append(metrics_avg)

            print_info(pass_id, batch_id, print_step, metrics_avg, elapsed,
                       "batch")
            sys.stdout.flush()
            batch_id += 1
    except fluid.core.EOFException:
        # The reader signals the end of the data with EOFException; reset it
        # so it can be started again for the next epoch.
        py_reader.reset()

    return metrics_record


def train(args):
    """Train the model and evaluate it after every epoch.

    Args:
        args: all arguments.
    """
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    test_prog = fluid.Program()

    # build_program returns the fetch variables followed by the py_reader.
    train_out = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args)
    train_py_reader = train_out[-1]
    train_fetch_vars = train_out[:-1]
    train_fetch_list = [var.name for var in train_fetch_vars]

    test_out = build_program(
        is_train=False,
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args)
    test_py_reader = test_out[-1]
    test_fetch_vars = test_out[:-1]
    test_fetch_list = [var.name for var in test_fetch_vars]

    # Clone test_prog so that layers' is_test params are set to True.
    test_prog = test_prog.clone(for_test=True)

    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    # Init the model from a checkpoint or a pretrained model.
    init_model(exe, args, train_prog)

    train_reader = reader.train(settings=args)
    # Clamp to 1 so a zero device count cannot raise ZeroDivisionError.
    device_count = max(fluid.core.get_cuda_device_count(), 1)
    train_reader = paddle.batch(
        train_reader,
        batch_size=int(args.batch_size / device_count),
        drop_last=True)

    test_reader = reader.val(settings=args)
    test_reader = paddle.batch(
        test_reader, batch_size=args.test_batch_size, drop_last=True)

    train_py_reader.decorate_sample_list_generator(train_reader, place)
    test_py_reader.decorate_sample_list_generator(test_reader, place)

    # The first fetch variable is the loss (see build_program).
    compiled_train_prog = best_strategy_compiled(args, train_prog,
                                                 train_fetch_vars[0])

    for pass_id in range(args.num_epochs):
        train_batch_metrics_record = _run_one_pass(
            exe, compiled_train_prog, train_fetch_list, train_py_reader,
            pass_id, args.print_step)
        test_batch_metrics_record = _run_one_pass(
            exe, test_prog, test_fetch_list, test_py_reader, pass_id,
            args.print_step)

        train_epoch_metrics_avg = np.mean(
            np.array(train_batch_metrics_record), axis=0)
        test_epoch_metrics_avg = np.mean(
            np.array(test_batch_metrics_record), axis=0)

        print_info(pass_id, 0, 0,
                   list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
                   0, "epoch")

        # Save the model every `args.save_step` epochs.
        if pass_id % args.save_step == 0:
            save_model(args, exe, train_prog, pass_id)
209

210

211
def main():
    """Parse, print and check the arguments, then run training."""
    args = parse_args()
    print_arguments(args)
    check_args(args)
    train(args)
216

217 218 219

# Run training only when this file is executed as a script.
if __name__ == '__main__':
    main()