#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import argparse
import ast
import logging
import numpy as np
import paddle.fluid as fluid

from utils.train_utils import train_with_dataloader
import models
from utils.config_utils import *
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
from utils.utility import check_version

logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
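
# Example invocation (values shown are the argparse defaults; adjust to your setup):
#   python train.py --model_name=AttentionCluster \
#       --config=configs/attention_cluster.txt --use_gpu=True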


def parse_args():
    parser = argparse.ArgumentParser("Paddle Video train script")
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to train.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='training batch size. None to use config file setting.')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=None,
        help='learning rate used for training. None to use config file setting.')
    parser.add_argument(
        '--pretrain',
        type=str,
        default=None,
        help='path to pretrained weights. None to use the default weights path in ~/.paddle/weights.'
    )
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='path to resume training based on previous checkpoints. '
        'None for not resuming any checkpoints.')
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='whether to use GPU. Default True.')
    parser.add_argument(
        '--no_memory_optimize',
        action='store_true',
        default=False,
        help='whether to disable memory optimization during training.')
    parser.add_argument(
        '--epoch',
        type=int,
        default=None,
        help='number of training epochs. None to use config file setting.')
    parser.add_argument(
        '--valid_interval',
        type=int,
        default=1,
        help='validation epoch interval, 0 for no validation.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default=os.path.join('data', 'checkpoints'),
        help='directory name to save training snapshots')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='mini-batch interval to log.')
    parser.add_argument(
        '--fix_random_seed',
        type=ast.literal_eval,
        default=False,
        help='If set True, fix the random seed to enable continuous evaluation jobs.')
    args = parser.parse_args()
    return args


def train(args):
    # parse config
    config = parse_config(args.config)
    train_config = merge_configs(config, 'train', vars(args))
    valid_config = merge_configs(config, 'valid', vars(args))
    print_configs(train_config, 'Train')
    train_model = models.get_model(args.model_name, train_config, mode='train')
    valid_model = models.get_model(args.model_name, valid_config, mode='valid')

    # build model
    startup = fluid.Program()
    train_prog = fluid.Program()
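    # a fixed random seed makes runs reproducible, e.g. for continuous evaluation jobs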
    if args.fix_random_seed:
        startup.random_seed = 1000
        train_prog.random_seed = 1000
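    # build the training graph; unique_name.guard keeps parameter names consistent
    # with the validation program so both programs share the same weights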
    with fluid.program_guard(train_prog, startup):
        with fluid.unique_name.guard():
            train_model.build_input(use_dataloader=True)
            train_model.build_model()
            # the inputs have the form [data1, data2, ..., label], so train_feeds[-1] is the label
            train_feeds = train_model.feeds()
            train_fetch_list = train_model.fetches()
            train_loss = train_fetch_list[0]
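            # fetch targets are marked persistable so their values are kept in scope
            # and can still be fetched after graph optimization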
            for item in train_fetch_list:
                item.persistable = True
            optimizer = train_model.optimizer()
            optimizer.minimize(train_loss)
            train_dataloader = train_model.dataloader()

    valid_prog = fluid.Program()
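    # the validation program is built under the same startup program, so it reuses
    # the parameters created for training instead of allocating its own copy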
    with fluid.program_guard(valid_prog, startup):
        with fluid.unique_name.guard():
            valid_model.build_input(use_dataloader=True)
            valid_model.build_model()
            valid_feeds = valid_model.feeds()
            valid_fetch_list = valid_model.fetches()
            valid_dataloader = valid_model.dataloader()
            for item in valid_fetch_list:
                item.persistable = True

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
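    # run the startup program once to initialize parameters before optionally
    # overwriting them with resume checkpoints or pretrained weights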
    exe.run(startup)

    if args.resume:
        # if a resume checkpoint is given, load its weights directly
        assert os.path.exists(args.resume), \
                "Given resume weight dir {} does not exist.".format(args.resume)

        fluid.io.load_persistables(
            exe, '', main_program=train_prog, filename=args.resume)
    else:
        # if not in resume mode, load pretrain weights
        if args.pretrain:
            assert os.path.exists(args.pretrain), \
                    "Given pretrain weight dir {} does not exist.".format(args.pretrain)
        pretrain = args.pretrain or train_model.get_pretrain_weights()
        if pretrain:
            train_model.load_pretrain_params(exe, pretrain, train_prog, place)

    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_inplace = True
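    # enable_sequential_execution makes ops run in the order they appear in the program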
    if args.model_name in ['CTCN']:
        build_strategy.enable_sequential_execution = True

    exec_strategy = fluid.ExecutionStrategy()

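    # compile both programs for data-parallel execution; the valid program shares
    # variables with the train program via share_vars_from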
    compiled_train_prog = fluid.compiler.CompiledProgram(
        train_prog).with_data_parallel(
            loss_name=train_loss.name,
            build_strategy=build_strategy,
            exec_strategy=exec_strategy)
    compiled_valid_prog = fluid.compiler.CompiledProgram(
        valid_prog).with_data_parallel(
            share_vars_from=compiled_train_prog,
            build_strategy=build_strategy,
            exec_strategy=exec_strategy)

    # get reader
    bs_denominator = 1
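    # the batch size in the config is the global batch size; it is divided by
    # bs_denominator (the number of visible GPUs) below to get the per-device batch size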
    if args.use_gpu:
        # check number of GPUs
        gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
        if gpus == "":
            pass
        else:
            gpus = gpus.split(",")
            num_gpus = len(gpus)
            assert num_gpus == train_config.TRAIN.num_gpus, \
                   "num_gpus({}) set by CUDA_VISIBLE_DEVICES " \
                   "should be the same as that " \
                   "set in {}({})".format(
                   num_gpus, args.config, train_config.TRAIN.num_gpus)
        bs_denominator = train_config.TRAIN.num_gpus

    train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
                                        bs_denominator)
    valid_config.VALID.batch_size = int(valid_config.VALID.batch_size /
                                        bs_denominator)
    train_reader = get_reader(args.model_name.upper(), 'train', train_config)
    valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)

    # get metrics
    train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
    valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)

    epochs = args.epoch or train_model.epoch_num()

    exe_places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
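    # bind the python readers to the dataloaders; batches are distributed across exe_places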
    train_dataloader.set_sample_list_generator(
        train_reader, places=exe_places)
    valid_dataloader.set_sample_list_generator(
        valid_reader, places=exe_places)

    train_with_dataloader(
        exe,
        train_prog,
        compiled_train_prog,  #train_exe,
        train_dataloader,
        train_fetch_list,
        train_metrics,
        epochs=epochs,
        log_interval=args.log_interval,
        valid_interval=args.valid_interval,
        save_dir=args.save_dir,
        save_model_name=args.model_name,
        fix_random_seed=args.fix_random_seed,
        compiled_test_prog=compiled_valid_prog,  #test_exe=valid_exe,
        test_dataloader=valid_dataloader,
        test_fetch_list=valid_fetch_list,
        test_metrics=valid_metrics)


if __name__ == "__main__":
    args = parse_args()
    # check whether the installed paddle is compiled with GPU
    check_cuda(args.use_gpu)
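    # check whether the installed paddle version is compatible with this script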
    check_version()
    logger.info(args)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    train(args)