#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import argparse
import ast
import logging
import os
import sys
import time

import numpy as np
import paddle.fluid as fluid

from utils.train_utils import train_with_pyreader
import models
from utils.config_utils import *
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda

# logging.basicConfig() does nothing when the root logger already has
# handlers, so drop any that were attached earlier (e.g. by imported
# libraries) to make the INFO level and format below take effect.
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options of the video training script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the behaviour of
            the original zero-argument call.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser("Paddle Video train script")
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to train.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='training batch size. None to use config file setting.')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=None,
        help='learning rate use for training. None to use config file setting.')
    parser.add_argument(
        '--pretrain',
        type=str,
        default=None,
        help='path to pretrain weights. None to use default weights path in  ~/.paddle/weights.'
    )
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='path to resume training based on previous checkpoints. '
        'None for not resuming any checkpoints.')
    # ast.literal_eval lets boolean flags be given as Python literals,
    # e.g. `--use_gpu False`.
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='default use gpu.')
    parser.add_argument(
        '--no_memory_optimize',
        action='store_true',
        default=False,
        help='whether to use memory optimize in train')
    parser.add_argument(
        '--epoch',
        type=int,
        default=None,
        help='epoch number, 0 for read from config file')
    parser.add_argument(
        '--valid_interval',
        type=int,
        default=1,
        help='validation epoch interval, 0 for no validation.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default=os.path.join('data', 'checkpoints'),
        help='directory name to save train snapshoot')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='mini-batch interval to log.')
    parser.add_argument(
        '--fix_random_seed',
        type=ast.literal_eval,
        default=False,
        help='If set True, enable continuous evaluation job.')
    args = parser.parse_args(argv)
    return args


def _gpu_batch_divisor(use_gpu, cfg_num_gpus, config_path):
    """Return the number of devices the configured batch sizes are split over.

    Args:
        use_gpu: whether training runs on GPU.
        cfg_num_gpus: the ``TRAIN.num_gpus`` value from the model config.
        config_path: path of the config file; used only in the error message.

    Returns:
        1 when ``use_gpu`` is false, otherwise ``cfg_num_gpus``.

    Raises:
        AssertionError: if CUDA_VISIBLE_DEVICES is non-empty and lists a
            different number of devices than ``cfg_num_gpus``.
    """
    if not use_gpu:
        return 1
    visible = os.getenv("CUDA_VISIBLE_DEVICES", "")
    # An unset or empty variable is not validated; the config value is used.
    if visible:
        # Skip empty entries so a trailing comma ("0,1,") is not counted.
        num_gpus = len([g for g in visible.split(",") if g.strip()])
        assert num_gpus == cfg_num_gpus, \
            "num_gpus({}) set by CUDA_VISIBLE_DEVICES should be the same as " \
            "that set in {}({})".format(num_gpus, config_path, cfg_num_gpus)
    return cfg_num_gpus


def train(args):
    """Build the train/valid programs for ``args.model_name`` and train it.

    Steps: parse and merge the config file, build the train and valid
    programs (sharing one startup program), restore weights (resume
    checkpoint or pretrain weights), compile both programs for data-parallel
    execution, wire up readers and metrics, then hand everything to
    ``train_with_pyreader``.

    Args:
        args: namespace produced by ``parse_args``.
    """
    # parse config
    config = parse_config(args.config)
    train_config = merge_configs(config, 'train', vars(args))
    valid_config = merge_configs(config, 'valid', vars(args))
    print_configs(train_config, 'Train')
    train_model = models.get_model(args.model_name, train_config, mode='train')
    valid_model = models.get_model(args.model_name, valid_config, mode='valid')

    # build model
    startup = fluid.Program()
    train_prog = fluid.Program()
    if args.fix_random_seed:
        startup.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup):
        with fluid.unique_name.guard():
            train_model.build_input(use_pyreader=True)
            train_model.build_model()
            train_fetch_list = train_model.fetches()
            # The first fetched variable is the loss the optimizer minimizes.
            train_loss = train_fetch_list[0]
            for item in train_fetch_list:
                item.persistable = True
            optimizer = train_model.optimizer()
            optimizer.minimize(train_loss)
            train_pyreader = train_model.pyreader()

    # The valid program reuses the same startup program as training.
    valid_prog = fluid.Program()
    with fluid.program_guard(valid_prog, startup):
        with fluid.unique_name.guard():
            valid_model.build_input(use_pyreader=True)
            valid_model.build_model()
            valid_fetch_list = valid_model.fetches()
            valid_pyreader = valid_model.pyreader()
            for item in valid_fetch_list:
                item.persistable = True

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    if args.resume:
        # if resume weights is given, load resume weights directly
        assert os.path.exists(args.resume), \
                "Given resume weight dir {} not exist.".format(args.resume)
        fluid.io.load_persistables(
            exe, '', main_program=train_prog, filename=args.resume)
    else:
        # if not in resume mode, load pretrain weights
        if args.pretrain:
            assert os.path.exists(args.pretrain), \
                    "Given pretrain weight dir {} not exist.".format(args.pretrain)
        pretrain = args.pretrain or train_model.get_pretrain_weights()
        if pretrain:
            train_model.load_pretrain_params(exe, pretrain, train_prog, place)

    build_strategy = fluid.BuildStrategy()
    # CTCN is run with sequential op execution enabled.
    if args.model_name in ['CTCN']:
        build_strategy.enable_sequential_execution = True

    compiled_train_prog = fluid.compiler.CompiledProgram(
        train_prog).with_data_parallel(
            loss_name=train_loss.name, build_strategy=build_strategy)
    # The valid program shares variables (parameters) with the train program.
    compiled_valid_prog = fluid.compiler.CompiledProgram(
        valid_prog).with_data_parallel(
            share_vars_from=compiled_train_prog, build_strategy=build_strategy)

    # get reader
    # Configured batch sizes are totals; divide them across the GPUs in use.
    bs_denominator = _gpu_batch_divisor(
        args.use_gpu, train_config.TRAIN.num_gpus, args.config)
    train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
                                        bs_denominator)
    valid_config.VALID.batch_size = int(valid_config.VALID.batch_size /
                                        bs_denominator)
    train_reader = get_reader(args.model_name.upper(), 'train', train_config)
    valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)

    # get metrics
    train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
    valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)

    # A falsy --epoch (None or 0) falls back to the model's configured value.
    epochs = args.epoch or train_model.epoch_num()

    exe_places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
    train_pyreader.decorate_sample_list_generator(
        train_reader, places=exe_places)
    valid_pyreader.decorate_sample_list_generator(
        valid_reader, places=exe_places)

    train_with_pyreader(
        exe,
        train_prog,
        compiled_train_prog,
        train_pyreader,
        train_fetch_list,
        train_metrics,
        epochs=epochs,
        log_interval=args.log_interval,
        valid_interval=args.valid_interval,
        save_dir=args.save_dir,
        save_model_name=args.model_name,
        fix_random_seed=args.fix_random_seed,
        compiled_test_prog=compiled_valid_prog,
        test_pyreader=valid_pyreader,
        test_fetch_list=valid_fetch_list,
        test_metrics=valid_metrics)
if __name__ == "__main__":
    args = parse_args()
    # check whether the installed paddle is compiled with GPU
    check_cuda(args.use_gpu)
    logger.info(args)

    # Create the checkpoint directory up front so training can save into it.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    train(args)