#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import sys
import time
import argparse
import logging
import numpy as np
import paddle.fluid as fluid

from tools.train_utils import train_with_pyreader, train_without_pyreader
import models
25 26 27
from config import *
from datareader import get_reader
from metrics import get_metrics
28

D
dengkaipeng 已提交
29
# Clear handlers already attached to the root logger first:
# logging.basicConfig() does nothing when the root logger has handlers,
# so without this reset the format/stream below might not take effect.
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
# Module-level logger used by the rest of this script.
logger = logging.getLogger(__name__)


def _str2bool(value):
    """Convert a command-line token such as 'true' / 'False' / '0' to bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False, matching the old bool('') result.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got {!r}'.format(value))


def parse_args():
    """Parse the command-line options of the training script.

    Reads ``sys.argv`` through ``argparse``. Hyphenated option names are
    exposed with underscores on the result (``--model-name`` becomes
    ``args.model_name``).

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser("Paddle Video train script")
    parser.add_argument(
        '--model-name',
        type=str,
        default='AttentionCluster',
        help='name of model to train.')
    parser.add_argument(
        '--config',
        type=str,
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=None,
        help='training batch size. None to use config file setting.')
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=None,
        help='learning rate use for training. None to use config file setting.')
    parser.add_argument(
        '--pretrain',
        type=str,
        default=None,
        help='path to pretrain weights. None to use default weights path in  ~/.paddle/weights.'
    )
    parser.add_argument(
        '--resume',
        type=str,
        default=None,
        help='path to resume training based on previous checkpoints. '
             'None for not resuming any checkpoints.'
    )
    # Parsed with _str2bool so that "--use-gpu False" really disables the GPU.
    parser.add_argument(
        '--use-gpu', type=_str2bool, default=True, help='default use gpu.')
    parser.add_argument(
        '--no-use-pyreader',
        action='store_true',
        default=False,
        help='whether to use pyreader')
    parser.add_argument(
        '--no-memory-optimize',
        action='store_true',
        default=False,
        help='whether to use memory optimize in train')
    parser.add_argument(
        '--epoch-num',
        type=int,
        default=0,
        help='epoch number, 0 for read from config file')
    parser.add_argument(
        '--valid-interval',
        type=int,
        default=1,
        help='validation epoch interval, 0 for no validation.')
    parser.add_argument(
        '--save-dir',
        type=str,
        default='checkpoints',
        help='directory name to save train snapshoot')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        help='mini-batch interval to log.')
    args = parser.parse_args()
    return args


D
dengkaipeng 已提交
106 107 108 109 110
def train(args):
    """Build the train/valid programs and run training.

    Steps, in order: parse the config file, build the train and valid
    programs, run the startup program, load resume or pretrain weights,
    create the parallel executors, readers and metrics, then hand over to
    ``train_with_pyreader`` or ``train_without_pyreader``.

    Args:
        args (argparse.Namespace): options as returned by ``parse_args``.

    Raises:
        AssertionError: if ``args.resume`` or ``args.pretrain`` is given but
            the path does not exist.
    """
    # parse config
    config = parse_config(args.config)
    train_config = merge_configs(config, 'train', vars(args))
    valid_config = merge_configs(config, 'valid', vars(args))
    train_model = models.get_model(args.model_name, train_config, mode='train')
    valid_model = models.get_model(args.model_name, valid_config, mode='valid')

    # build model
    startup = fluid.Program()
    train_prog = fluid.Program()
    with fluid.program_guard(train_prog, startup):
        with fluid.unique_name.guard():
            train_model.build_input(not args.no_use_pyreader)
            train_model.build_model()
            # outputs, loss, label should be fetched, so set persistable to
            # be true on each of them
            # for the input, has the form [data1, data2,..., label], so
            # train_feeds[-1] is label
            train_feeds = train_model.feeds()
            train_feeds[-1].persistable = True
            # for the output of classification model, has the form [pred]
            train_outputs = train_model.outputs()
            for output in train_outputs:
                output.persistable = True
            train_loss = train_model.loss()
            train_loss.persistable = True
            optimizer = train_model.optimizer()
            optimizer.minimize(train_loss)
            train_pyreader = train_model.pyreader()

    if not args.no_memory_optimize:
        fluid.memory_optimize(train_prog)

    # the valid program is built against the same startup program
    valid_prog = fluid.Program()
    with fluid.program_guard(valid_prog, startup):
        with fluid.unique_name.guard():
            valid_model.build_input(not args.no_use_pyreader)
            valid_model.build_model()
            valid_feeds = valid_model.feeds()
            valid_outputs = valid_model.outputs()
            valid_loss = valid_model.loss()
            valid_pyreader = valid_model.pyreader()

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup)

    if args.resume:
        # if resume weights is given, load resume weights directly
        assert os.path.exists(args.resume), \
                "Given resume weight dir {} not exist.".format(args.resume)

        def if_exist(var):
            # only load variables that have a file in the resume directory
            return os.path.exists(os.path.join(args.resume, var.name))

        fluid.io.load_vars(
            exe, args.resume, predicate=if_exist, main_program=train_prog)
    else:
        # if not in resume mode, load pretrain weights
        if args.pretrain:
            assert os.path.exists(args.pretrain), \
                    "Given pretrain weight dir {} not exist.".format(args.pretrain)
        pretrain = args.pretrain or train_model.get_pretrain_weights()
        if pretrain:
            train_model.load_pretrain_params(exe, pretrain, train_prog, place)

    train_exe = fluid.ParallelExecutor(
        use_cuda=args.use_gpu,
        loss_name=train_loss.name,
        main_program=train_prog)
    # the valid executor shares variables with the train executor
    valid_exe = fluid.ParallelExecutor(
        use_cuda=args.use_gpu,
        share_vars_from=train_exe,
        main_program=valid_prog)

    # get reader
    # With pyreader on GPU the configured batch sizes are divided by
    # TRAIN.num_gpus (presumably to obtain a per-device batch size -- confirm
    # against the reader implementation).
    # NOTE(review): the valid batch size is also divided by TRAIN.num_gpus.
    bs_denominator = 1
    if (not args.no_use_pyreader) and args.use_gpu:
        bs_denominator = train_config.TRAIN.num_gpus
    train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
                                        bs_denominator)
    valid_config.VALID.batch_size = int(valid_config.VALID.batch_size /
                                        bs_denominator)
    train_reader = get_reader(args.model_name.upper(), 'train', train_config)
    valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)

    # get metrics
    train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
    valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)

    # fetch order: loss, then every model output, then the label
    train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
                                            ] + [train_feeds[-1].name]
    valid_fetch_list = [valid_loss.name] + [x.name for x in valid_outputs
                                            ] + [valid_feeds[-1].name]

    # a command-line value of 0 falls back to the model's own epoch number
    epochs = args.epoch_num or train_model.epoch_num()

    if args.no_use_pyreader:
        train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds)
        valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds)
        train_without_pyreader(
            exe,
            train_prog,
            train_exe,
            train_reader,
            train_feeder,
            train_fetch_list,
            train_metrics,
            epochs=epochs,
            log_interval=args.log_interval,
            valid_interval=args.valid_interval,
            save_dir=args.save_dir,
            save_model_name=args.model_name,
            test_exe=valid_exe,
            test_reader=valid_reader,
            test_feeder=valid_feeder,
            test_fetch_list=valid_fetch_list,
            test_metrics=valid_metrics)
    else:
        train_pyreader.decorate_paddle_reader(train_reader)
        valid_pyreader.decorate_paddle_reader(valid_reader)
        train_with_pyreader(
            exe,
            train_prog,
            train_exe,
            train_pyreader,
            train_fetch_list,
            train_metrics,
            epochs=epochs,
            log_interval=args.log_interval,
            valid_interval=args.valid_interval,
            save_dir=args.save_dir,
            save_model_name=args.model_name,
            test_exe=valid_exe,
            test_pyreader=valid_pyreader,
            test_fetch_list=valid_fetch_list,
            test_metrics=valid_metrics)


if __name__ == "__main__":
    args = parse_args()
    logger.info(args)

    # Create the save directory up front, before training starts.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    train(args)