train.py 8.6 KB
Newer Older
D
dengkaipeng 已提交
1
#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import argparse
19
import ast
20 21 22 23
import logging
import numpy as np
import paddle.fluid as fluid

24
from utils.train_utils import train_with_pyreader
25
import models
26 27
from utils.config_utils import *
from reader import get_reader
28
from metrics import get_metrics
29
from utils.utility import check_cuda
30

D
dengkaipeng 已提交
31
# Log line layout: level, source file and line number, then the message.
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
# Remove any handlers already attached to the root logger so that
# basicConfig below installs a fresh stdout handler at INFO level.
logging.root.handlers = []
logging.basicConfig(
    format=FORMAT,
    stream=sys.stdout,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options of the Paddle Video train script.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, exactly as before, so
            existing ``parse_args()`` callers are unaffected.

    Returns:
        argparse.Namespace with the parsed options. Several options default
        to None; their help texts state that None means "use the setting from
        the config file".
    """
    parser = argparse.ArgumentParser("Paddle Video train script")
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to train.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='training batch size. None to use config file setting.')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=None,
        help='learning rate use for training. None to use config file setting.')
    parser.add_argument(
        '--pretrain',
        type=str,
        default=None,
        help='path to pretrain weights. None to use default weights path in  ~/.paddle/weights.'
    )
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='path to resume training based on previous checkpoints. '
        'None for not resuming any checkpoints.')
    # ast.literal_eval lets users pass Python literals such as True/False.
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')
    parser.add_argument(
        '--no_memory_optimize',
        action='store_true',
        default=False,
        help='whether to use memory optimize in train')
    parser.add_argument(
        '--epoch',
        type=int,
        default=None,
        help='epoch number, 0 for read from config file')
    parser.add_argument(
        '--valid_interval',
        type=int,
        default=1,
        help='validation epoch interval, 0 for no validation.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default=os.path.join('data', 'checkpoints'),
        help='directory name to save train snapshoot')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='mini-batch interval to log.')
    parser.add_argument(
        '--fix_random_seed',
        type=ast.literal_eval,
        default=False,
        help='If set True, enable continuous evaluation job.')
    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args


D
dengkaipeng 已提交
110 111 112 113 114
def train(args):
    """Build the train/valid programs for ``args.model_name`` and train it.

    Steps:
      1. parse ``args.config`` and merge the command-line options into its
         'train' and 'valid' sections;
      2. build the training program (inputs, model, fetches, optimizer) and
         the validation program, both sharing one startup program;
      3. run the startup program, then load either the ``--resume``
         checkpoint or pretrain weights;
      4. compile both programs for data-parallel execution, the validation
         program sharing variables with the training one;
      5. split the configured batch sizes across GPUs, build readers and
         metrics, and hand everything to ``train_with_pyreader``.

    Args:
        args: argparse.Namespace as produced by ``parse_args``.

    Raises:
        AssertionError: (unless Python runs with -O) if ``--resume`` or
            ``--pretrain`` points at a non-existent path, or if the number of
            devices listed in CUDA_VISIBLE_DEVICES differs from
            ``TRAIN.num_gpus`` in the config file.
        ValueError: if a configured batch size is smaller than the number of
            GPUs it is divided across (the per-device size would become 0).
    """
    # parse config and merge command-line overrides into each section
    config = parse_config(args.config)
    train_config = merge_configs(config, 'train', vars(args))
    valid_config = merge_configs(config, 'valid', vars(args))
    print_configs(train_config, 'Train')
    train_model = models.get_model(args.model_name, train_config, mode='train')
    valid_model = models.get_model(args.model_name, valid_config, mode='valid')

    # build model
    startup = fluid.Program()
    train_prog = fluid.Program()
    if args.fix_random_seed:
        # fixed seeds for repeatable runs (--fix_random_seed is described as
        # enabling the continuous evaluation job)
        startup.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup):
        with fluid.unique_name.guard():
            train_model.build_input(use_pyreader=True)
            train_model.build_model()
            # for the input, has the form [data1, data2,..., label], so train_feeds[-1] is label
            train_feeds = train_model.feeds()
            train_fetch_list = train_model.fetches()
            # the first fetch is the loss handed to the optimizer
            train_loss = train_fetch_list[0]
            for item in train_fetch_list:
                item.persistable = True
            optimizer = train_model.optimizer()
            optimizer.minimize(train_loss)
            train_pyreader = train_model.pyreader()

    # the validation program reuses the same startup program
    valid_prog = fluid.Program()
    with fluid.program_guard(valid_prog, startup):
        with fluid.unique_name.guard():
            valid_model.build_input(use_pyreader=True)
            valid_model.build_model()
            valid_feeds = valid_model.feeds()
            valid_fetch_list = valid_model.fetches()
            valid_pyreader = valid_model.pyreader()
            for item in valid_fetch_list:
                item.persistable = True

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    if args.resume:
        # if resume weights is given, load resume weights directly
        assert os.path.exists(args.resume), \
                "Given resume weight dir {} not exist.".format(args.resume)
        fluid.io.load_persistables(
            exe, '', main_program=train_prog, filename=args.resume)
    else:
        # if not in resume mode, load pretrain weights: an explicit
        # --pretrain path wins, otherwise the model's default weights (if any)
        if args.pretrain:
            assert os.path.exists(args.pretrain), \
                    "Given pretrain weight dir {} not exist.".format(args.pretrain)
        pretrain = args.pretrain or train_model.get_pretrain_weights()
        if pretrain:
            train_model.load_pretrain_params(exe, pretrain, train_prog, place)

    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_inplace = True
    if args.model_name in ['CTCN']:
        build_strategy.enable_sequential_execution = True

    compiled_train_prog = fluid.compiler.CompiledProgram(
        train_prog).with_data_parallel(
            loss_name=train_loss.name, build_strategy=build_strategy)
    # the validation program shares variables with the training program
    compiled_valid_prog = fluid.compiler.CompiledProgram(
        valid_prog).with_data_parallel(
            share_vars_from=compiled_train_prog, build_strategy=build_strategy)

    # get reader
    bs_denominator = 1
    if args.use_gpu:
        # check number of GPUs; empty entries (e.g. a trailing comma in
        # "0,1,") are not counted as devices
        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "")
        gpus = [g for g in visible_devices.split(",") if g.strip()]
        if gpus:
            num_gpus = len(gpus)
            assert num_gpus == train_config.TRAIN.num_gpus, \
                   "num_gpus({}) set by CUDA_VISIBLE_DEVICES " \
                   "should be the same as that " \
                   "set in {}({})".format(
                   num_gpus, args.config, train_config.TRAIN.num_gpus)
        bs_denominator = train_config.TRAIN.num_gpus

    # the configured batch sizes are totals; divide them across the GPUs
    total_train_bs = train_config.TRAIN.batch_size
    total_valid_bs = valid_config.VALID.batch_size
    train_config.TRAIN.batch_size = int(total_train_bs / bs_denominator)
    valid_config.VALID.batch_size = int(total_valid_bs / bs_denominator)
    if train_config.TRAIN.batch_size < 1 or valid_config.VALID.batch_size < 1:
        raise ValueError(
            "batch size (TRAIN: {}, VALID: {}) must be at least the number "
            "of GPUs ({}) it is split across".format(
                total_train_bs, total_valid_bs, bs_denominator))
    train_reader = get_reader(args.model_name.upper(), 'train', train_config)
    valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)

    # get metrics
    train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
    valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)

    # --epoch of None or 0 falls back to the model's configured epoch count
    epochs = args.epoch or train_model.epoch_num()

    exe_places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
    train_pyreader.decorate_sample_list_generator(
        train_reader, places=exe_places)
    valid_pyreader.decorate_sample_list_generator(
        valid_reader, places=exe_places)

    train_with_pyreader(
        exe,
        train_prog,
        compiled_train_prog,
        train_pyreader,
        train_fetch_list,
        train_metrics,
        epochs=epochs,
        log_interval=args.log_interval,
        valid_interval=args.valid_interval,
        save_dir=args.save_dir,
        save_model_name=args.model_name,
        fix_random_seed=args.fix_random_seed,
        compiled_test_prog=compiled_valid_prog,
        test_pyreader=valid_pyreader,
        test_fetch_list=valid_fetch_list,
        test_metrics=valid_metrics)
235 236 237 238


if __name__ == "__main__":
    args = parse_args()
    # Verify the installed paddle build supports GPU when --use_gpu is set.
    check_cuda(args.use_gpu)
    logger.info(args)

    # Create the checkpoint directory up front if it is not there yet.
    save_dir = args.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    train(args)