#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import argparse
import ast
import logging

import numpy as np
import paddle.fluid as fluid

from utils.train_utils import train_with_dataloader
import models
from utils.config_utils import *
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
from utils.utility import check_version

# Drop any handlers already attached to the root logger: logging.basicConfig
# is a no-op when the root logger has handlers, so without this reset the
# format/stream chosen below could be silently ignored.
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
# Module-level logger used by the script entry point.
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Define and parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace holding: model_name, config, batch_size,
        learning_rate, pretrain, resume, use_gpu, no_memory_optimize, epoch,
        valid_interval, save_dir, log_interval and fix_random_seed.
    """
    parser = argparse.ArgumentParser("Paddle Video train script")
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to train.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='training batch size. None to use config file setting.')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=None,
        help='learning rate use for training. None to use config file setting.')
    parser.add_argument(
        '--pretrain',
        type=str,
        default=None,
        help='path to pretrain weights. None to use default weights path in  ~/.paddle/weights.'
    )
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='path to resume training based on previous checkpoints. '
        'None for not resuming any checkpoints.')
    # Boolean-valued options are parsed with ast.literal_eval so that the
    # strings "True"/"False" on the command line become real booleans.
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')
    parser.add_argument(
        '--no_memory_optimize',
        action='store_true',
        default=False,
        help='whether to use memory optimize in train')
    parser.add_argument(
        '--epoch',
        type=int,
        default=None,
        help='epoch number, 0 for read from config file')
    parser.add_argument(
        '--valid_interval',
        type=int,
        default=1,
        help='validation epoch interval, 0 for no validation.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default=os.path.join('data', 'checkpoints'),
        help='directory name to save train snapshoot')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='mini-batch interval to log.')
    parser.add_argument(
        '--fix_random_seed',
        type=ast.literal_eval,
        default=False,
        help='If set True, enable continuous evaluation job.')
    args = parser.parse_args(argv)
    return args


def train(args):
    """Build the train/valid programs for ``args.model_name`` and run training.

    The function, in order: parses the config file and merges command-line
    overrides into it, builds the train and valid fluid programs against a
    shared startup program, runs the startup program, loads either resume
    or pretrain weights, compiles both programs with ``with_data_parallel``,
    creates readers and metrics, and hands everything to
    ``train_with_dataloader``.

    Args:
        args: namespace as returned by ``parse_args``. Attributes read here:
            config, model_name, fix_random_seed, use_gpu, resume, pretrain,
            epoch, log_interval, valid_interval and save_dir. The whole
            namespace is also passed to ``merge_configs`` via ``vars(args)``.

    Raises:
        AssertionError: if ``args.resume`` or ``args.pretrain`` is given but
            the path does not exist, or if the number of devices listed in
            ``CUDA_VISIBLE_DEVICES`` differs from ``TRAIN.num_gpus`` in the
            config.
    """
    # parse config
    config = parse_config(args.config)
    train_config = merge_configs(config, 'train', vars(args))
    valid_config = merge_configs(config, 'valid', vars(args))
    print_configs(train_config, 'Train')
    train_model = models.get_model(args.model_name, train_config, mode='train')
    valid_model = models.get_model(args.model_name, valid_config, mode='valid')

    # build model
    startup = fluid.Program()
    train_prog = fluid.Program()
    if args.fix_random_seed:
        startup.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup):
        with fluid.unique_name.guard():
            train_model.build_input(use_dataloader=True)
            train_model.build_model()
            # for the input, has the form [data1, data2,..., label], so train_feeds[-1] is label
            # NOTE(review): train_feeds is not used below; the call is kept in
            # case feeds() has side effects on the model -- confirm and drop.
            train_feeds = train_model.feeds()
            train_fetch_list = train_model.fetches()
            # The first fetch is used as the loss to minimize.
            train_loss = train_fetch_list[0]
            for item in train_fetch_list:
                item.persistable = True
            optimizer = train_model.optimizer()
            optimizer.minimize(train_loss)
            train_dataloader = train_model.dataloader()

    # The valid program is built under its own unique_name guard but against
    # the same startup program as the train program.
    valid_prog = fluid.Program()
    with fluid.program_guard(valid_prog, startup):
        with fluid.unique_name.guard():
            valid_model.build_input(use_dataloader=True)
            valid_model.build_model()
            # NOTE(review): valid_feeds is not used below; see train_feeds.
            valid_feeds = valid_model.feeds()
            valid_fetch_list = valid_model.fetches()
            valid_dataloader = valid_model.dataloader()
            for item in valid_fetch_list:
                item.persistable = True

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    if args.resume:
        # if resume weights is given, load resume weights directly
        assert os.path.exists(args.resume), \
                "Given resume weight dir {} not exist.".format(args.resume)

        fluid.io.load_persistables(
            exe, '', main_program=train_prog, filename=args.resume)
    else:
        # if not in resume mode, load pretrain weights
        if args.pretrain:
            assert os.path.exists(args.pretrain), \
                    "Given pretrain weight dir {} not exist.".format(args.pretrain)
        pretrain = args.pretrain or train_model.get_pretrain_weights()
        if pretrain:
            train_model.load_pretrain_params(exe, pretrain, train_prog, place)

    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_inplace = True
    if args.model_name in ['CTCN']:
        build_strategy.enable_sequential_execution = True

    compiled_train_prog = fluid.compiler.CompiledProgram(
        train_prog).with_data_parallel(
            loss_name=train_loss.name, build_strategy=build_strategy)
    compiled_valid_prog = fluid.compiler.CompiledProgram(
        valid_prog).with_data_parallel(
            share_vars_from=compiled_train_prog, build_strategy=build_strategy)

    # get reader
    # The configured batch size is divided by the GPU count before the
    # readers are created; on CPU the divisor stays 1.
    bs_denominator = 1
    if args.use_gpu:
        # check number of GPUs (only when CUDA_VISIBLE_DEVICES is set)
        gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
        if gpus != "":
            gpus = gpus.split(",")
            num_gpus = len(gpus)
            assert num_gpus == train_config.TRAIN.num_gpus, \
                   "num_gpus({}) set by CUDA_VISIBLE_DEVICES " \
                   "shoud be the same as that " \
                   "set in {}({})".format(
                   num_gpus, args.config, train_config.TRAIN.num_gpus)
        bs_denominator = train_config.TRAIN.num_gpus

    train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
                                        bs_denominator)
    valid_config.VALID.batch_size = int(valid_config.VALID.batch_size /
                                        bs_denominator)
    train_reader = get_reader(args.model_name.upper(), 'train', train_config)
    valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)

    # get metrics
    train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
    valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)

    # A falsy --epoch (None or 0) falls back to the model's own epoch count.
    epochs = args.epoch or train_model.epoch_num()

    exe_places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
    train_dataloader.set_sample_list_generator(
        train_reader, places=exe_places)
    valid_dataloader.set_sample_list_generator(
        valid_reader, places=exe_places)

    train_with_dataloader(
        exe,
        train_prog,
        compiled_train_prog,
        train_dataloader,
        train_fetch_list,
        train_metrics,
        epochs=epochs,
        log_interval=args.log_interval,
        valid_interval=args.valid_interval,
        save_dir=args.save_dir,
        save_model_name=args.model_name,
        fix_random_seed=args.fix_random_seed,
        compiled_test_prog=compiled_valid_prog,
        test_dataloader=valid_dataloader,
        test_fetch_list=valid_fetch_list,
        test_metrics=valid_metrics)
236 237 238 239


# Script entry point: parse options, validate the Paddle installation,
# make sure the checkpoint directory exists, then start training.
if __name__ == "__main__":
    args = parse_args()
    # check whether the installed paddle is compiled with GPU
    check_cuda(args.use_gpu)
    check_version()
    logger.info(args)

    # Create the checkpoint directory up front if it is missing.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    train(args)