#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import argparse
import logging
import numpy as np
import paddle.fluid as fluid

from tools.train_utils import train_with_pyreader, train_without_pyreader
import models
25 26 27
from config import *
from datareader import get_reader
from metrics import get_metrics
28

D
dengkaipeng 已提交
29
# Layout of every log line: level, source file and line number, then message.
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
# Drop any root handlers installed while importing other modules:
# logging.basicConfig() does nothing when the root logger already has
# handlers, so clearing them lets the configuration below take effect.
logging.root.handlers = []
logging.basicConfig(stream=sys.stdout, format=FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


def _str2bool(value):
    """Convert a command-line string such as 'True' / 'false' / '1' to bool.

    argparse's ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string (including 'False'), so an explicit parser is needed.
    The empty string maps to False, matching what ``bool('')`` gave before.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got {!r}'.format(value))


def parse_args():
    """Parse the command-line options of the training script.

    Options left at ``None`` (or ``0`` for ``--epoch_num``) defer to the
    model's config file.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser("Paddle Video train script")
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to train.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='training batch size. None to use config file setting.')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=None,
        help='learning rate used for training. None to use config file setting.')
    parser.add_argument(
        '--pretrain',
        type=str,
        default=None,
        help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.'
    )
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='path to resume training based on previous checkpoints. '
        'None for not resuming any checkpoints.')
    # type=bool would turn "--use_gpu False" into True; parse it explicitly.
    parser.add_argument(
        '--use_gpu',
        type=_str2bool,
        default=True,
        help='whether to use GPU (true/false). Default: True.')
    parser.add_argument(
        '--no_use_pyreader',
        action='store_true',
        default=False,
        help='disable pyreader and feed data through DataFeeder instead')
    parser.add_argument(
        '--no_memory_optimize',
        action='store_true',
        default=False,
        help='disable memory optimization of the train program')
    parser.add_argument(
        '--epoch_num',
        type=int,
        default=0,
        help='epoch number, 0 for read from config file')
    parser.add_argument(
        '--valid_interval',
        type=int,
        default=1,
        help='validation epoch interval, 0 for no validation.')
    parser.add_argument(
        '--save_dir',
        type=str,
        default='checkpoints',
        help='directory name to save train snapshots')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        help='mini-batch interval to log.')
    args = parser.parse_args()
    return args


D
dengkaipeng 已提交
105 106 107 108 109
def _check_weights_dir(path, kind):
    """Fail fast when a user-supplied weights directory does not exist.

    An explicit ``raise`` is used rather than ``assert`` so the check still
    runs under ``python -O``. Without it, a bad ``--resume`` path would make
    every ``if_exist`` predicate False and load no variables at all.

    Args:
        path: directory given on the command line.
        kind: 'resume' or 'pretrain'; used only in the error message.

    Raises:
        ValueError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise ValueError("Given {} weight dir {} not exist.".format(kind,
                                                                    path))


def _split_batch_size(total, num_devices):
    """Return the per-device share of a global batch size.

    Args:
        total: global batch size from the merged config.
        num_devices: number of devices the batch is split across.

    Returns:
        ``total // num_devices`` as an int.

    Raises:
        ValueError: if the split would leave less than one sample per device.
    """
    per_device = int(total // num_devices)
    if per_device < 1:
        raise ValueError(
            "batch size {} is too small to split across {} devices".format(
                total, num_devices))
    if per_device * num_devices != total:
        # The effective global batch shrinks; make that visible.
        logger.warning(
            "batch size %s is not divisible by %s devices; using %s per device",
            total, num_devices, per_device)
    return per_device


def train(args):
    """Build the train/valid programs, restore weights and run training.

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        ValueError: if ``--resume``/``--pretrain`` points to a missing
            directory, or the batch size cannot be split across devices.
    """
    # Parse the config file and overlay command-line overrides per phase.
    config = parse_config(args.config)
    train_config = merge_configs(config, 'train', vars(args))
    valid_config = merge_configs(config, 'valid', vars(args))
    logger.info("############### train config ###############")
    print_configs(train_config)

    train_model = models.get_model(args.model_name, train_config, mode='train')
    valid_model = models.get_model(args.model_name, valid_config, mode='valid')

    use_pyreader = not args.no_use_pyreader

    # Build the train program. Both programs share one startup program, which
    # is run once below to initialize all parameters.
    startup = fluid.Program()
    train_prog = fluid.Program()
    with fluid.program_guard(train_prog, startup):
        # NOTE(review): unique_name.guard presumably makes the valid program
        # generate the same parameter names as the train program, which
        # share_vars_from relies on -- confirm against fluid docs.
        with fluid.unique_name.guard():
            train_model.build_input(use_pyreader)
            train_model.build_model()
            # Inputs have the form [data1, data2, ..., label], so
            # train_feeds[-1] is the label.
            train_feeds = train_model.feeds()
            # Outputs, loss and label are fetched every step, so they are
            # marked persistable (presumably to keep memory_optimize from
            # reusing their buffers).
            train_feeds[-1].persistable = True
            train_outputs = train_model.outputs()
            for output in train_outputs:
                output.persistable = True
            train_loss = train_model.loss()
            train_loss.persistable = True
            optimizer = train_model.optimizer()
            optimizer.minimize(train_loss)
            train_pyreader = train_model.pyreader()

    if not args.no_memory_optimize:
        fluid.memory_optimize(train_prog)

    # Build the validation program against the same startup program.
    valid_prog = fluid.Program()
    with fluid.program_guard(valid_prog, startup):
        with fluid.unique_name.guard():
            valid_model.build_input(use_pyreader)
            valid_model.build_model()
            valid_feeds = valid_model.feeds()
            valid_outputs = valid_model.outputs()
            valid_loss = valid_model.loss()
            valid_pyreader = valid_model.pyreader()

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    if args.resume:
        # Resume mode: restore every train-program variable that has a
        # matching file in the resume directory.
        _check_weights_dir(args.resume, 'resume')

        def if_exist(var):
            return os.path.exists(os.path.join(args.resume, var.name))

        fluid.io.load_vars(
            exe, args.resume, predicate=if_exist, main_program=train_prog)
    else:
        # Not resuming: load pretrain weights, preferring an explicit
        # --pretrain path over the model's default weights.
        if args.pretrain:
            _check_weights_dir(args.pretrain, 'pretrain')
        pretrain = args.pretrain or train_model.get_pretrain_weights()
        if pretrain:
            train_model.load_pretrain_params(exe, pretrain, train_prog, place)

    train_exe = fluid.ParallelExecutor(
        use_cuda=args.use_gpu,
        loss_name=train_loss.name,
        main_program=train_prog)
    # The validation executor reuses the train executor's parameters.
    valid_exe = fluid.ParallelExecutor(
        use_cuda=args.use_gpu,
        share_vars_from=train_exe,
        main_program=valid_prog)

    # With pyreader on GPU the configured batch is split across num_gpus
    # devices; otherwise it is used as-is.
    bs_denominator = 1
    if use_pyreader and args.use_gpu:
        bs_denominator = train_config.TRAIN.num_gpus
    train_config.TRAIN.batch_size = _split_batch_size(
        train_config.TRAIN.batch_size, bs_denominator)
    valid_config.VALID.batch_size = _split_batch_size(
        valid_config.VALID.batch_size, bs_denominator)

    # Readers and metrics are looked up by the upper-cased model name.
    train_reader = get_reader(args.model_name.upper(), 'train', train_config)
    valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)

    train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
    valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)

    # Fetch order: loss, every model output, then the label (last feed).
    train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
                                            ] + [train_feeds[-1].name]
    valid_fetch_list = [valid_loss.name] + [x.name for x in valid_outputs
                                            ] + [valid_feeds[-1].name]

    # --epoch_num 0 defers to the model/config setting.
    epochs = args.epoch_num or train_model.epoch_num()

    if args.no_use_pyreader:
        train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds)
        valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds)
        train_without_pyreader(
            exe,
            train_prog,
            train_exe,
            train_reader,
            train_feeder,
            train_fetch_list,
            train_metrics,
            epochs=epochs,
            log_interval=args.log_interval,
            valid_interval=args.valid_interval,
            save_dir=args.save_dir,
            save_model_name=args.model_name,
            test_exe=valid_exe,
            test_reader=valid_reader,
            test_feeder=valid_feeder,
            test_fetch_list=valid_fetch_list,
            test_metrics=valid_metrics)
    else:
        train_pyreader.decorate_paddle_reader(train_reader)
        valid_pyreader.decorate_paddle_reader(valid_reader)
        train_with_pyreader(
            exe,
            train_prog,
            train_exe,
            train_pyreader,
            train_fetch_list,
            train_metrics,
            epochs=epochs,
            log_interval=args.log_interval,
            valid_interval=args.valid_interval,
            save_dir=args.save_dir,
            save_model_name=args.model_name,
            test_exe=valid_exe,
            test_pyreader=valid_pyreader,
            test_fetch_list=valid_fetch_list,
            test_metrics=valid_metrics)
243 244 245 246 247


if __name__ == "__main__":
    cli_args = parse_args()
    logger.info(cli_args)

    # Create --save_dir before training; train() passes it on as save_dir.
    checkpoint_dir = cli_args.save_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    train(cli_args)