train.py 8.6 KB
Newer Older
D
dengkaipeng 已提交
1
#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import argparse
19
import ast
20 21 22 23
import logging
import numpy as np
import paddle.fluid as fluid

24
from utils.train_utils import train_with_pyreader
25
import models
26 27
from utils.config_utils import *
from reader import get_reader
28
from metrics import get_metrics
29
from utils.utility import check_cuda
30

# Discard any root handlers installed before this module was imported;
# logging.basicConfig() only configures the root logger when it has no
# handlers, so clearing them makes the configuration below take effect.
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    format=FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def parse_args():
    """Declare the command-line options of the training script and parse them.

    Options are listed as ``(flag, add_argument keyword arguments)`` pairs
    and registered on the parser in that order, so ``--help`` lists them in
    the same order. Options whose default is ``None`` fall back to the
    config file / model defaults inside ``train``.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    option_specs = [
        ('--model_name', dict(
            type=str,
            default='AttentionCluster',
            help='name of model to train.')),
        ('--config', dict(
            type=str,
            default='configs/attention_cluster.txt',
            help='path to config file of model')),
        ('--batch_size', dict(
            type=int,
            default=None,
            help='training batch size. None to use config file setting.')),
        ('--learning_rate', dict(
            type=float,
            default=None,
            help='learning rate use for training. None to use config file setting.')),
        ('--pretrain', dict(
            type=str,
            default=None,
            help='path to pretrain weights. None to use default weights path in  ~/.paddle/weights.')),
        ('--resume', dict(
            type=str,
            default=None,
            help='path to resume training based on previous checkpoints. '
            'None for not resuming any checkpoints.')),
        # ast.literal_eval turns the strings "True"/"False" into real bools.
        ('--use_gpu', dict(
            type=ast.literal_eval,
            default=True,
            help='default use gpu.')),
        ('--no_memory_optimize', dict(
            action='store_true',
            default=False,
            help='whether to use memory optimize in train')),
        ('--epoch', dict(
            type=int,
            default=None,
            help='epoch number, 0 for read from config file')),
        ('--valid_interval', dict(
            type=int,
            default=1,
            help='validation epoch interval, 0 for no validation.')),
        ('--save_dir', dict(
            type=str,
            default=os.path.join('data', 'checkpoints'),
            help='directory name to save train snapshoot')),
        ('--log_interval', dict(
            type=int,
            default=10,
            help='mini-batch interval to log.')),
        ('--fix_random_seed', dict(
            type=ast.literal_eval,
            default=False,
            help='If set True, enable continuous evaluation job.')),
    ]

    cli = argparse.ArgumentParser("Paddle Video train script")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def _load_initial_weights(args, exe, train_model, train_prog, place):
    """Initialise the parameters of ``train_prog`` before training starts.

    If ``args.resume`` is set, the persistable variables of ``train_prog``
    are loaded from that single file and pretrain weights are ignored.
    Otherwise pretrain weights come from ``args.pretrain`` when given,
    falling back to ``train_model.get_pretrain_weights()``. Nothing is
    loaded when both are empty.

    Raises:
        ValueError: if ``args.resume`` or ``args.pretrain`` is given but the
            path does not exist. These are explicit raises rather than
            ``assert`` so the check still runs under ``python -O``.
    """
    if args.resume:
        if not os.path.exists(args.resume):
            raise ValueError(
                "Given resume weight dir {} not exist.".format(args.resume))
        fluid.io.load_persistables(
            exe, '', main_program=train_prog, filename=args.resume)
        return

    # not in resume mode: load pretrain weights, if any
    if args.pretrain and not os.path.exists(args.pretrain):
        raise ValueError(
            "Given pretrain weight dir {} not exist.".format(args.pretrain))
    pretrain = args.pretrain or train_model.get_pretrain_weights()
    if pretrain:
        train_model.load_pretrain_params(exe, pretrain, train_prog, place)


def _batch_size_denominator(args, train_config):
    """Return the number the configured batch sizes are divided by.

    The result is 1 when ``args.use_gpu`` is false. Otherwise it is
    ``train_config.TRAIN.num_gpus``. When CUDA_VISIBLE_DEVICES is set to a
    non-empty value, its comma-separated device count must equal that
    number.

    Raises:
        ValueError: if CUDA_VISIBLE_DEVICES lists a different number of
            devices than ``TRAIN.num_gpus``.
    """
    if not args.use_gpu:
        return 1
    gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
    if gpus != "":
        num_gpus = len(gpus.split(","))
        if num_gpus != train_config.TRAIN.num_gpus:
            # Adjacent literals are joined verbatim, so each piece needs
            # its own trailing space to keep the words separated.
            raise ValueError(
                "num_gpus({}) set by CUDA_VISIBLE_DEVICES "
                "should be the same as that "
                "set in {}({})".format(num_gpus, args.config,
                                       train_config.TRAIN.num_gpus))
    return train_config.TRAIN.num_gpus


def train(args):
    """Build, initialise and run training (with validation) for one model.

    Steps:
      1. Parse ``args.config`` and merge it with the command-line options
         into separate 'train' and 'valid' configs.
      2. Build the train program (including the optimizer) and the valid
         program, both registering their initialisers in one shared
         startup program.
      3. Run the startup program, then load resume or pretrain weights.
      4. Compile both programs for data-parallel execution; the valid
         program shares variables with the compiled train program.
      5. Divide the configured batch sizes across GPUs, create readers and
         metrics, and hand everything to ``train_with_pyreader``.

    Args:
        args: the namespace returned by ``parse_args()``.

    Raises:
        ValueError: if a given resume/pretrain path does not exist, or if
            CUDA_VISIBLE_DEVICES disagrees with ``TRAIN.num_gpus``.
    """
    # parse config
    config = parse_config(args.config)
    train_config = merge_configs(config, 'train', vars(args))
    valid_config = merge_configs(config, 'valid', vars(args))
    print_configs(train_config, 'Train')
    train_model = models.get_model(args.model_name, train_config, mode='train')
    valid_model = models.get_model(args.model_name, valid_config, mode='valid')

    # build model
    startup = fluid.Program()
    train_prog = fluid.Program()
    if args.fix_random_seed:
        startup.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup):
        with fluid.unique_name.guard():
            train_model.build_input(use_pyreader=True)
            train_model.build_model()
            # feeds have the form [data1, data2, ..., label], so
            # train_feeds[-1] is the label
            train_feeds = train_model.feeds()
            train_fetch_list = train_model.fetches()
            # the first fetched variable is the loss that gets minimized
            train_loss = train_fetch_list[0]
            # NOTE(review): presumably marked persistable so in-place reuse
            # cannot recycle them before they are fetched - confirm
            for item in train_fetch_list:
                item.persistable = True
            optimizer = train_model.optimizer()
            optimizer.minimize(train_loss)
            train_pyreader = train_model.pyreader()

    valid_prog = fluid.Program()
    with fluid.program_guard(valid_prog, startup):
        with fluid.unique_name.guard():
            valid_model.build_input(use_pyreader=True)
            valid_model.build_model()
            valid_feeds = valid_model.feeds()
            valid_fetch_list = valid_model.fetches()
            valid_pyreader = valid_model.pyreader()
            for item in valid_fetch_list:
                item.persistable = True

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    _load_initial_weights(args, exe, train_model, train_prog, place)

    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_inplace = True
    if args.model_name in ['CTCN']:
        build_strategy.enable_sequential_execution = True
    # NOTE: args.no_memory_optimize is not consulted here;
    # build_strategy.memory_optimize is left at its default.

    compiled_train_prog = fluid.compiler.CompiledProgram(
        train_prog).with_data_parallel(
            loss_name=train_loss.name, build_strategy=build_strategy)
    compiled_valid_prog = fluid.compiler.CompiledProgram(
        valid_prog).with_data_parallel(
            share_vars_from=compiled_train_prog, build_strategy=build_strategy)

    # get reader: the configured batch sizes are divided by the device count
    bs_denominator = _batch_size_denominator(args, train_config)
    train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
                                        bs_denominator)
    valid_config.VALID.batch_size = int(valid_config.VALID.batch_size /
                                        bs_denominator)
    train_reader = get_reader(args.model_name.upper(), 'train', train_config)
    valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)

    # get metrics
    train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
    valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)

    # --epoch of None or 0 falls back to the model's configured epoch count
    epochs = args.epoch or train_model.epoch_num()

    exe_places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
    train_pyreader.decorate_sample_list_generator(
        train_reader, places=exe_places)
    valid_pyreader.decorate_sample_list_generator(
        valid_reader, places=exe_places)

    train_with_pyreader(
        exe,
        train_prog,
        compiled_train_prog,
        train_pyreader,
        train_fetch_list,
        train_metrics,
        epochs=epochs,
        log_interval=args.log_interval,
        valid_interval=args.valid_interval,
        save_dir=args.save_dir,
        save_model_name=args.model_name,
        fix_random_seed=args.fix_random_seed,
        compiled_test_prog=compiled_valid_prog,
        test_pyreader=valid_pyreader,
        test_fetch_list=valid_fetch_list,
        test_metrics=valid_metrics)


if __name__ == "__main__":
    args = parse_args()
    # check (via utils.utility.check_cuda) whether the installed Paddle
    # build is compatible with the requested --use_gpu setting
    check_cuda(args.use_gpu)
    logger.info(args)

    # create the checkpoint directory up front so training can save into it
    save_dir = args.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    train(args)