train.py 9.9 KB
Newer Older
X
xiaoting 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15 16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
D
Dun 已提交
17
import os
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38


def set_paddle_flags(flags):
    """Export PaddlePaddle flags as environment variables unless already set.

    Values already present in ``os.environ`` (e.g. exported by the user in
    the shell) win; only missing keys receive ``str(value)`` from ``flags``.

    Args:
        flags: mapping of environment-variable name to default value.
    """
    for name, default in flags.items():
        os.environ.setdefault(name, str(default))


# NOTE(paddle-dev): every flag below must be exported before `import paddle`;
# flags set after that import do not take effect. Values already present in
# the environment are left untouched by set_paddle_flags.
set_paddle_flags(dict(
    # 0 enables eager tensor garbage collection (GC).
    FLAGS_eager_delete_tensor_gb=0,
    # The next two entries repeat Paddle's defaults (both 1) and could be
    # omitted; they are listed here to make the configuration explicit.
    FLAGS_memory_fraction_of_eager_deletion=1,
    FLAGS_fast_eager_deletion_mode=1,
    # Fraction of GPU memory Paddle is allowed to use.
    FLAGS_fraction_of_gpu_memory_to_use=0.98,
))
D
Dun 已提交
39 40 41

import paddle
import paddle.fluid as fluid
42
from paddle.fluid import profiler
D
Dun 已提交
43 44 45 46 47
import numpy as np
import argparse
from reader import CityscapeDataset
import reader
import models
C
ccmeteorljh 已提交
48
import time
D
Dun 已提交
49 50 51
import contextlib
import paddle.fluid.profiler as profiler
import utility
D
Dun 已提交
52

D
Dun 已提交
53 54 55 56
parser = argparse.ArgumentParser()


def add_arg(*spec):
    """Register one command-line option on ``parser``.

    Thin wrapper around ``utility.add_arguments`` that always targets the
    module-level ``parser``; ``spec`` is (name, type, default, help).
    """
    return utility.add_arguments(*spec, argparser=parser)


# yapf: disable
add_arg('batch_size',           int,    4,      "The number of images in each batch during training.")
add_arg('train_crop_size',      int,    769,    "Image crop size during training.")
add_arg('base_lr',              float,  0.001,  "The base learning rate for model training.")
add_arg('total_step',           int,    500000, "Number of the training step.")
add_arg('init_weights_path',    str,    None,   "Path of the initial weights in paddlepaddle format.")
add_arg('save_weights_path',    str,    None,   "Path of the saved weights during training.")
add_arg('dataset_path',         str,    None,   "Cityscape dataset path.")
add_arg('parallel',             bool,   True,   "using ParallelExecutor.")
add_arg('use_gpu',              bool,   True,   "Whether use GPU or CPU.")
add_arg('num_classes',          int,    19,     "Number of classes.")
add_arg('load_logit_layer',     bool,   True,   "Load last logit fc layer or not. If you are training with different number of classes, you should set to False.")
add_arg('memory_optimize',      bool,   True,   "Using memory optimizer.")
add_arg('norm_type',            str,    'bn',   "Normalization type, should be 'bn' or 'gn'.")
add_arg('profile',              bool,    False, "Enable profiler.")
add_arg('use_py_reader',        bool,    True,  "Use py reader.")
add_arg('use_multiprocessing',  bool,    False, "Use multiprocessing.")
add_arg("num_workers",          int,     8,     "The number of python processes used to read and preprocess data.")
# NOTE: args for profiler, used for benchmark
add_arg("profiler_path",        str,     '/tmp/profile_file2',  "the profiler output file path. (used for benchmark)")
parser.add_argument(
    '--enable_ce',
    action='store_true',
    help='If set, run the task with continuous evaluation logs. Users can ignore this argument.')
# yapf: enable

@contextlib.contextmanager
def profile_context(profile=True):
    """Optionally run the enclosed code under a Paddle profiler session.

    Args:
        profile: when truthy, the body runs inside
            ``profiler.profiler('All', 'total', args.profiler_path)``;
            when falsy, the body runs with no profiling at all.
    """
    if not profile:
        yield
        return
    with profiler.profiler('All', 'total', args.profiler_path):
        yield
D
Dun 已提交
89 90

def load_model():
    """Load initial weights from ``args.init_weights_path`` into program ``tp``.

    - If the path is a directory it is passed as ``dirname``. Either every
      parameter is loaded, or, when ``args.load_logit_layer`` is False,
      every parameter whose name does not contain ``'logit'``. The latter
      is for training with a different number of classes.
    - Otherwise the path is passed as ``filename`` with an empty
      ``dirname``, and all parameters of ``tp`` are loaded from it.

    Uses the module-level ``args``, ``exe`` and ``tp``.
    """
    weights_path = args.init_weights_path
    if not os.path.isdir(weights_path):
        fluid.io.load_params(
            exe, dirname="", filename=weights_path, main_program=tp)
        return

    if args.load_logit_layer:
        fluid.io.load_params(exe, dirname=weights_path, main_program=tp)
        return

    # Skip the final classifier ("logit") parameters.
    params_without_logit = [
        var for var in tp.list_vars()
        if isinstance(var, fluid.framework.Parameter)
        and 'logit' not in var.name
    ]
    fluid.io.load_vars(exe, dirname=weights_path, vars=params_without_logit)

D
Dun 已提交
109 110 111


def save_model():
    """Save all parameters of program ``tp`` under ``args.save_weights_path``.

    The path is passed to ``fluid.io.save_params`` as ``dirname``, so it
    must not be an existing regular file.

    Raises:
        ValueError: if ``args.save_weights_path`` is unset (None) or names
            an existing regular file.
    """
    save_path = args.save_weights_path
    # Explicit checks instead of ``assert``. Asserts vanish under
    # ``python -O``, and with the default None path ``os.path.isfile(None)``
    # raised an opaque TypeError.
    if save_path is None:
        raise ValueError(
            "save_weights_path is not set; pass --save_weights_path to "
            "save the model.")
    if os.path.isfile(save_path):
        raise ValueError(
            "save_weights_path %r is an existing file; a directory path is "
            "expected." % (save_path,))
    fluid.io.save_params(exe, dirname=save_path, main_program=tp)
D
Dun 已提交
115 116 117


def loss(logit, label):
    """Build the per-pixel cross-entropy loss and a validity mask.

    Reads the module-level ``num_classes``.

    Args:
        logit: network output. It is transposed with perm ``[0, 2, 3, 1]``
            and flattened to ``[-1, num_classes]``, so it is presumably laid
            out as (N, C, H, W); TODO confirm.
        label: integer label map. It is flattened to ``[-1, 1]`` and cast
            to int64.

    Returns:
        ``(pixel_loss, valid_mask)``.

        - ``pixel_loss`` is the cross-entropy of the softmax probabilities
          against the labels, with label 255 ignored.
        - ``valid_mask`` is a float32 ``[-1, 1]`` tensor. It is 1.0 where
          ``label < num_classes`` and 0.0 elsewhere.
    """
    layers = fluid.layers

    # Validity mask: label < num_classes, as float32.
    label_as_float = label.astype('float32')
    class_limit = layers.assign(np.array([num_classes], 'float32'))
    valid_mask = layers.less_than(
        label_as_float, class_limit, force_cpu=False).astype('float32')

    # Flatten to one row per pixel.
    flat_logit = layers.reshape(
        layers.transpose(logit, [0, 2, 3, 1]), [-1, num_classes])
    flat_label = layers.cast(layers.reshape(label, [-1, 1]), 'int64')
    valid_mask = layers.reshape(valid_mask, [-1, 1])

    # NOTE(review): explicit softmax + cross_entropy; a fused
    # softmax_with_cross_entropy may be more numerically stable.
    probs = layers.softmax(flat_logit, use_cudnn=False)
    pixel_loss = layers.cross_entropy(probs, flat_label, ignore_index=255)

    # Neither the mask nor the labels should receive gradients.
    valid_mask.stop_gradient = True
    flat_label.stop_gradient = True
    return pixel_loss, valid_mask


# Parse and echo the command line, then validate the GPU request.
args = parser.parse_args()
utility.print_arguments(args)
utility.check_gpu(args.use_gpu)

# Reset the models module's global state and configure it before any layer
# is created.
models.clean()
models.bn_momentum, models.dropout_keep_prop = 0.9997, 0.9
models.label_number = args.num_classes
models.default_norm_type = args.norm_type
deeplabv3p = models.deeplabv3p

# sp: startup program (run once by the executor, also given to the
# optimizer); tp: main training program.
sp, tp = fluid.Program(), fluid.Program()

# Fixed random seeds for continuous-evaluation (CE) runs only.
if args.enable_ce:
    SEED = 102
    sp.random_seed = tp.random_seed = SEED

D
Dun 已提交
154 155 156 157 158
# Training hyper-parameters and reader configuration from the command line.
crop_size = args.train_crop_size
batch_size = args.batch_size
image_shape = [crop_size, crop_size]
reader.default_config['crop_size'] = crop_size
reader.default_config['shuffle'] = True
num_classes = args.num_classes  # also read by loss()
weight_decay = 0.00004

base_lr = args.base_lr
total_step = args.total_step

# Build the network in the main program ``tp``; initialisers go to ``sp``.
# Layer creation order is kept unchanged. Auto-generated parameter names,
# and therefore checkpoint compatibility, presumably depend on it.
with fluid.program_guard(tp, sp):
    if args.use_py_reader:
        # The py_reader tensors are shaped for a per-device batch.
        batch_size_each = batch_size // utility.get_device_count()
        py_reader = fluid.layers.py_reader(
            capacity=64,
            shapes=[[batch_size_each, 3] + image_shape,
                    [batch_size_each] + image_shape],
            dtypes=['float32', 'int32'])
        img, label = fluid.layers.read_file(py_reader)
    else:
        img = fluid.layers.data(
            name='img', shape=[3] + image_shape, dtype='float32')
        label = fluid.layers.data(
            name='label', shape=image_shape, dtype='int32')
    logit = deeplabv3p(img)
    # NOTE(review): ``pred`` is not used elsewhere in this script as shown.
    pred = fluid.layers.argmax(logit, axis=1).astype('int32')
    # Bind the result to ``pixel_loss`` rather than ``loss`` so the
    # module-level ``loss()`` function is not shadowed by a Variable.
    pixel_loss, mask = loss(logit, label)
    lr = fluid.layers.polynomial_decay(
        base_lr, total_step, end_learning_rate=0, power=0.9)
    # ``area`` is the fraction of valid (label < num_classes) pixels,
    # clamped below at 0.1. Dividing the mean loss by it averages over
    # valid pixels only, without exploding on nearly-empty masks.
    area = fluid.layers.elementwise_max(
        fluid.layers.reduce_mean(mask),
        fluid.layers.assign(np.array([0.1], dtype=np.float32)))
    loss_mean = fluid.layers.reduce_mean(pixel_loss) / area
    loss_mean.persistable = True

    opt = fluid.optimizer.Momentum(
        lr,
        momentum=0.9,
        regularization=fluid.regularizer.L2DecayRegularizer(
            regularization_coeff=weight_decay))
    optimize_ops, params_grads = opt.minimize(loss_mean, startup_program=sp)
    # The IR memory optimizer has issues; work around them by making every
    # gradient persistable.
    for _, grad in params_grads:
        grad.persistable = True


# Execution strategy: one thread per device; the scope is dropped every 100
# iterations.
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = utility.get_device_count()
exec_strategy.num_iteration_per_drop_scope = 100

# Graph passes enabled by --memory_optimize.
build_strategy = fluid.BuildStrategy()
if args.memory_optimize:
    build_strategy.fuse_relu_depthwise_conv = True
    build_strategy.enable_inplace = True

# Run on GPU 0 when requested, otherwise on the CPU.
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# Initialise parameters by running the startup program once.
exe.run(sp)

# Optionally replace the fresh initialisation with saved weights.
if args.init_weights_path:
    print("load from:", args.init_weights_path)
    load_model()

D
Dun 已提交
217
dataset = reader.CityscapeDataset(args.dataset_path, 'train')

# Data-parallel compilation with the configured strategies when --parallel
# is set; otherwise a plain compiled program.
if args.parallel:
    binary = fluid.compiler.CompiledProgram(tp).with_data_parallel(
        loss_name=loss_mean.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)
else:
    binary = fluid.compiler.CompiledProgram(tp)

if args.use_py_reader:
    # The py_reader tensors were shaped with batch_size // device_count, so
    # the global batch must split evenly across devices. This is an explicit
    # check rather than ``assert``, so it still runs under ``python -O``.
    if batch_size % utility.get_device_count() != 0:
        raise ValueError(
            "batch_size (%d) must be divisible by the device count (%d) "
            "when use_py_reader is enabled."
            % (batch_size, utility.get_device_count()))

    def data_gen():
        """Yield (image, label) pairs of per-device batches for py_reader.

        Produces ``total_step * device_count`` batches, each of size
        ``batch_size // device_count``.
        """
        device_count = utility.get_device_count()
        batches = dataset.get_batch_generator(
            batch_size // device_count,
            total_step * device_count,
            use_multiprocessing=args.use_multiprocessing,
            num_workers=args.num_workers)
        for batch in batches:
            yield batch[0], batch[1]

    py_reader.decorate_tensor_provider(data_gen)
    py_reader.start()
else:
    # NOTE(review): unlike the py_reader path, this always uses
    # multiprocessing and ignores --use_multiprocessing. Confirm whether
    # that is intentional.
    batches = dataset.get_batch_generator(
        batch_size, total_step,
        use_multiprocessing=True,
        num_workers=args.num_workers)
Z
add ce  
zhengya01 已提交
240 241 242 243
# Accumulators for the CE report after training. Note that ``epoch_idx``
# actually counts completed steps.
total_time = 0.0
epoch_idx = 0
train_loss = 0

with profile_context(args.profile):
    for i in range(total_step):
        epoch_idx += 1
        begin_time = time.time()
        if args.use_py_reader:
            # Input comes from the already-started py_reader.
            step_loss, = exe.run(binary, fetch_list=[loss_mean])
        else:
            imgs, labels, names = next(batches)
            step_loss, = exe.run(
                binary,
                feed={'img': imgs, 'label': labels},
                fetch_list=[loss_mean])
        train_loss = np.mean(step_loss)
        end_time = time.time()
        step_time = end_time - begin_time
        total_time += step_time

        # Checkpoint every 100 steps, starting with step 0.
        if i % 100 == 0:
            print("Model is saved to", args.save_weights_path)
            save_model()
        print("step {:d}, loss: {:.6f}, step_time_cost: {:.3f} s".format(
            i, train_loss, step_time))
D
Dun 已提交
264 265 266

print("Training done. Model is saved to", args.save_weights_path)
save_model()

# Continuous-evaluation KPIs: mean step time and last training loss.
if args.enable_ce:
    gpu_num = utility.get_device_count()
    # Guard against ZeroDivisionError when no step ran (total_step <= 0);
    # total_time is then 0.0, so the reported mean is 0.0.
    print("kpis\teach_pass_duration_card%s\t%s" %
          (gpu_num, total_time / max(epoch_idx, 1)))
    print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, train_loss))