# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import random
import argparse
import contextlib
import time

import paddle
import paddle.fluid as fluid
from paddle.static import InputSpec as Input

from check import check_gpu, check_version
from cyclegan import Generator, Discriminator, GeneratorCombine, GLoss, DLoss
import data as data

# Iterations per epoch used to convert epoch-based LR boundaries into steps.
# Presumably the Cityscapes training-set size at batch_size=1 — verify against data.py.
step_per_epoch = 2974


def opt(parameters):
    """Build an Adam optimizer with a piecewise-constant LR decay.

    The learning rate starts at 2e-4 and is scaled down at epochs
    100/120/140/160/180 (converted to step counts via ``step_per_epoch``).

    Args:
        parameters: list of parameters the optimizer will update.

    Returns:
        A ``fluid.optimizer.Adam`` instance (beta1=0.5, as is conventional
        for GAN training) driving the decayed learning rate.
    """
    base_lr = 0.0002
    epoch_bounds = [100, 120, 140, 160, 180]
    lr_scales = [1., 0.8, 0.6, 0.4, 0.2, 0.1]
    # piecewise_decay expects len(values) == len(boundaries) + 1.
    step_bounds = [e * step_per_epoch for e in epoch_bounds]
    lr_values = [s * base_lr for s in lr_scales]
    return fluid.optimizer.Adam(
        learning_rate=fluid.layers.piecewise_decay(
            boundaries=step_bounds, values=lr_values),
        parameter_list=parameters,
        beta1=0.5)


def main():
    """Train CycleGAN: two generators (A<->B) and two discriminators.

    Reads all configuration from the module-level ``FLAGS`` namespace
    (device, dynamic/static mode, batch size, epochs, checkpoint paths).
    Saves a checkpoint of the combined generator model after every epoch.
    """
    place = paddle.set_device(FLAGS.device)
    # Dygraph (imperative) mode is opt-in via --dynamic; otherwise static graph.
    fluid.enable_dygraph(place) if FLAGS.dynamic else None

    # NCHW image spec; batch dimension left as None.
    im_shape = [None, 3, 256, 256]
    input_A = Input(im_shape, 'float32', 'input_A')
    input_B = Input(im_shape, 'float32', 'input_B')
    fake_A = Input(im_shape, 'float32', 'fake_A')
    fake_B = Input(im_shape, 'float32', 'fake_B')

    # Generators
    g_AB = Generator()
    g_BA = Generator()
    d_A = Discriminator()
    d_B = Discriminator()

    # Combined model computes the full generator loss (adversarial + cycle);
    # g_AB/g_BA are also wrapped individually so we can run inference-only
    # test_batch() to produce fakes for the discriminators.
    g = paddle.Model(
        GeneratorCombine(g_AB, g_BA, d_A, d_B), inputs=[input_A, input_B])
    g_AB = paddle.Model(g_AB, [input_A])
    g_BA = paddle.Model(g_BA, [input_B])

    # Discriminators
    # NOTE(review): d_A is wired to B-domain tensors and d_B to A-domain —
    # looks deliberate (matches the train_batch calls below), but confirm the
    # naming against cyclegan.py.
    d_A = paddle.Model(d_A, [input_B, fake_B])
    d_B = paddle.Model(d_B, [input_A, fake_A])

    # Separate optimizers: the two generators share one, each discriminator
    # gets its own (standard GAN alternating-update scheme).
    da_params = d_A.parameters()
    db_params = d_B.parameters()
    g_params = g_AB.parameters() + g_BA.parameters()

    da_optimizer = opt(da_params)
    db_optimizer = opt(db_params)
    g_optimizer = opt(g_params)

    # Inference-only wrappers: no optimizer/loss needed.
    g_AB.prepare()
    g_BA.prepare()

    g.prepare(g_optimizer, GLoss())
    d_A.prepare(da_optimizer, DLoss())
    d_B.prepare(db_optimizer, DLoss())

    if FLAGS.resume:
        g.load(FLAGS.resume)

    loader_A = paddle.io.DataLoader(
        data.DataA(),
        places=place,
        shuffle=True,
        return_list=True,
        batch_size=FLAGS.batch_size)
    loader_B = paddle.io.DataLoader(
        data.DataB(),
        places=place,
        shuffle=True,
        return_list=True,
        batch_size=FLAGS.batch_size)

    # History pools of previously generated fakes — stabilizes discriminator
    # training (the standard CycleGAN image-buffer trick, per data.ImagePool).
    A_pool = data.ImagePool()
    B_pool = data.ImagePool()

    for epoch in range(FLAGS.epoch):
        # zip() stops at the shorter loader, so an epoch covers
        # min(len(A), len(B)) batches.
        for i, (data_A, data_B) in enumerate(zip(loader_A, loader_B)):
            # Static-graph loaders yield an extra nesting level; unwrap it.
            data_A = data_A[0][0] if not FLAGS.dynamic else data_A[0]
            data_B = data_B[0][0] if not FLAGS.dynamic else data_B[0]
            start = time.time()

            # 1) Generate fakes without gradient tracking, 2) update the
            # generators on the combined loss.
            fake_B = g_AB.test_batch(data_A)[0]
            fake_A = g_BA.test_batch(data_B)[0]
            g_loss = g.train_batch([data_A, data_B])[0]
            # 3) Update each discriminator on a real batch plus a fake drawn
            # from the history pool (not necessarily this step's fake).
            fake_pb = B_pool.get(fake_B)
            da_loss = d_A.train_batch([data_B, fake_pb])[0]

            fake_pa = A_pool.get(fake_A)
            db_loss = d_B.train_batch([data_A, fake_pa])[0]

            t = time.time() - start
            if i % 20 == 0:
                print("epoch: {} | step: {:3d} | g_loss: {:.4f} | " \
                      "da_loss: {:.4f} | db_loss: {:.4f} | s/step {:.4f}".
                      format(epoch, i, g_loss[0], da_loss[0], db_loss[0], t))
        g.save('{}/{}'.format(FLAGS.checkpoint_path, epoch))


if __name__ == "__main__":
    # Command-line entry point: parse flags into the module-level FLAGS
    # namespace that main() reads, sanity-check the environment, then train.
    parser = argparse.ArgumentParser("CycleGAN Training on Cityscapes")
    parser.add_argument(
        "-d", "--dynamic", action='store_true', help="Enable dygraph mode")
    parser.add_argument(
        "-p",
        "--device",
        type=str,
        default='gpu',
        help="device to use, gpu or cpu")
    parser.add_argument(
        "-e", "--epoch", default=200, type=int, help="Epoch number")
    parser.add_argument(
        "-b", "--batch_size", default=1, type=int, help="batch size")
    parser.add_argument(
        "-o",
        "--checkpoint_path",
        type=str,
        default='checkpoint',
        help="path to save checkpoint")
    parser.add_argument(
        "-r",
        "--resume",
        default=None,
        type=str,
        help="checkpoint path to resume")
    FLAGS = parser.parse_args()
    print(FLAGS)
    # Fail fast if a GPU was requested but is unavailable, or the installed
    # paddle version is too old. (Fixed: FLAGS.device.lower() instead of the
    # unidiomatic unbound-method call str.lower(FLAGS.device).)
    check_gpu(FLAGS.device.lower() == 'gpu')
    check_version()
    main()