# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
import paddle
import argparse
import functools
import time
import numpy as np
# scipy.misc.imsave was removed in SciPy >= 1.2; imageio.imwrite is a
# compatible replacement.
try:
    from scipy.misc import imsave
except ImportError:
    from imageio import imwrite as imsave
import paddle.fluid as fluid
import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *
from paddle.fluid.dygraph.base import to_variable
import six
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--ce", action="store_true", help="run in continuous evaluation (CE) mode")
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   1,          "Minibatch size.")
add_arg('epoch',             int,   200,        "The number of epochs to train.")
add_arg('output',            str,   "./output_0", "The directory where the model and test results are saved.")
add_arg('init_model',        str,   None,       "The init model file or directory.")
add_arg('save_checkpoints',  bool,  True,       "Whether to save checkpoints.")
# yapf: enable
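# Illustrative usage (values are examples, not requirements):
#   python train.py --batch_size 1 --epoch 200 --output ./output_0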

# Cycle-consistency loss weights and the identity-mapping weight (the
# CycleGAN paper defaults).
lambda_A = 10.0
lambda_B = 10.0
lambda_identity = 0.5
step_per_epoch = 2974  # iterations per epoch (one image pair per step)
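
# optimizer_setting below applies a stepped decay: lr holds at 2e-4 for the
# first 100 epochs, then falls to 0.8x, 0.6x, 0.4x, 0.2x and 0.1x of the base
# rate at 20-epoch intervals, a piecewise approximation of the linear decay
# used in the original CycleGAN recipe.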

def optimizer_setting(parameters):
    lr = 0.0002
    optimizer = fluid.optimizer.Adam(
        learning_rate=fluid.layers.piecewise_decay(
            boundaries=[
                100 * step_per_epoch, 120 * step_per_epoch,
                140 * step_per_epoch, 160 * step_per_epoch,
                180 * step_per_epoch
            ],
            values=[lr, lr * 0.8, lr * 0.6, lr * 0.4, lr * 0.2, lr * 0.1]),
        parameter_list=parameters,
        beta1=0.5)
    return optimizer


def train(args):
    with fluid.dygraph.guard():
        max_images_num = data_reader.max_images_num()
        shuffle = True
        data_shape = [-1] + data_reader.image_shape()
        print(data_shape)
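
        # CE (continuous evaluation) mode pins every random seed and disables
        # shuffling so repeated runs produce comparable KPI numbers.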
        if args.ce:
            print("ce mode")
            seed = 33
            random.seed(seed)
            np.random.seed(seed)
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed
            shuffle = False

        A_pool = ImagePool()
        B_pool = ImagePool()
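        # The pools keep a buffer of previously generated images; the
        # discriminators are trained on samples drawn from this history rather
        # than only the newest fakes, which helps stabilize GAN training.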
        A_reader = paddle.batch(
            data_reader.a_reader(shuffle=shuffle), args.batch_size)()
        B_reader = paddle.batch(
            data_reader.b_reader(shuffle=shuffle), args.batch_size)()
        A_test_reader = data_reader.a_test_reader()
        B_test_reader = data_reader.b_test_reader()

        cycle_gan = Cycle_Gan(input_channel=data_shape[1], istrain=True)

        losses = [[], []]
        t_time = 0
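        # losses[0] and losses[1] accumulate the per-batch g_A and d_A loss
        # values; t_time accumulates the total batch time.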

        vars_G = (cycle_gan.build_generator_resnet_9blocks_a.parameters() +
                  cycle_gan.build_generator_resnet_9blocks_b.parameters())
        vars_da = cycle_gan.build_gen_discriminator_a.parameters()
        vars_db = cycle_gan.build_gen_discriminator_b.parameters()
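
        # Each sub-network (the two generators jointly, D_A, and D_B) gets its
        # own Adam instance over its own parameter list, matching the
        # alternating updates in the training loop below.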

        optimizer1 = optimizer_setting(vars_G)
        optimizer2 = optimizer_setting(vars_da)
        optimizer3 = optimizer_setting(vars_db)

        for epoch in range(args.epoch):
            batch_id = 0
            for i in range(max_images_num):

                data_A = next(A_reader)
                data_B = next(B_reader)

                s_time = time.time()
                # Only the first sample of each batch is used and the shape is
                # hard-coded, so training effectively assumes --batch_size 1
                # and 3 x 256 x 256 inputs.
                data_A = np.array(
                    [data_A[0].reshape(3, 256, 256)]).astype("float32")
                data_B = np.array(
                    [data_B[0].reshape(3, 256, 256)]).astype("float32")
                data_A = to_variable(data_A)
                data_B = to_variable(data_B)

                # optimize the g_A network
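                # The three boolean flags select which forward branch runs:
                # (generator, discriminator A, discriminator B).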
                fake_A, fake_B, cyc_A, cyc_B, g_A_loss, g_B_loss, idt_loss_A, idt_loss_B, cyc_A_loss, cyc_B_loss, g_loss = cycle_gan(
                    data_A, data_B, True, False, False)

                g_loss_out = g_loss.numpy()

                g_loss.backward()

                optimizer1.minimize(g_loss)
                cycle_gan.clear_gradients()

                fake_pool_B = B_pool.pool_image(fake_B).numpy()
                fake_pool_B = np.array(
                    [fake_pool_B[0].reshape(3, 256, 256)]).astype("float32")
                fake_pool_B = to_variable(fake_pool_B)

                fake_pool_A = A_pool.pool_image(fake_A).numpy()
                fake_pool_A = np.array(
                    [fake_pool_A[0].reshape(3, 256, 256)]).astype("float32")
                fake_pool_A = to_variable(fake_pool_A)

                # optimize the d_A network
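                # Least-squares GAN loss: the discriminator should score real
                # images as 1 and pooled fakes as 0, hence
                # (D(fake)^2 + (D(real) - 1)^2) / 2.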
                rec_B, fake_pool_rec_B = cycle_gan(data_B, fake_pool_B, False,
                                                   True, False)
                d_loss_A = (fluid.layers.square(fake_pool_rec_B) +
                            fluid.layers.square(rec_B - 1)) / 2.0
                d_loss_A = fluid.layers.reduce_mean(d_loss_A)

                d_loss_A.backward()
                optimizer2.minimize(d_loss_A)
                cycle_gan.clear_gradients()

                # optimize the d_B network

                rec_A, fake_pool_rec_A = cycle_gan(data_A, fake_pool_A, False,
                                                   False, True)
                d_loss_B = (fluid.layers.square(fake_pool_rec_A) +
                            fluid.layers.square(rec_A - 1)) / 2.0
                d_loss_B = fluid.layers.reduce_mean(d_loss_B)

                d_loss_B.backward()
                optimizer3.minimize(d_loss_B)

                cycle_gan.clear_gradients()

                batch_time = time.time() - s_time
                t_time += batch_time
                print(
                    "epoch{}; batch{}; g_loss: {}; d_A_loss: {}; d_B_loss: {};\n"
                    " g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {};"
                    " g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {};"
                    " Batch_time_cost: {}".format(
                        epoch, batch_id, g_loss_out[0],
                        d_loss_A.numpy()[0],
                        d_loss_B.numpy()[0],
                        g_A_loss.numpy()[0],
                        cyc_A_loss.numpy()[0],
                        idt_loss_A.numpy()[0],
                        g_B_loss.numpy()[0],
                        cyc_B_loss.numpy()[0],
                        idt_loss_B.numpy()[0], batch_time))
                with open('logging_train.txt', 'a') as log_file:
                    now = time.strftime("%c")
                    log_file.write(
                        "time: {}; epoch{}; batch{}; d_A_loss: {}; g_A_loss: {};"
                        " g_A_cyc_loss: {}; g_A_idt_loss: {}; d_B_loss: {};"
                        " g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {};"
                        " Batch_time_cost: {}\n".format(
                            now, epoch, batch_id,
                            d_loss_A.numpy()[0],
                            g_A_loss.numpy()[0],
                            cyc_A_loss.numpy()[0],
                            idt_loss_A.numpy()[0],
                            d_loss_B.numpy()[0],
                            g_B_loss.numpy()[0],
                            cyc_B_loss.numpy()[0],
                            idt_loss_B.numpy()[0], batch_time))
                losses[0].append(g_A_loss.numpy()[0])
                losses[1].append(d_loss_A.numpy()[0])
                sys.stdout.flush()
                batch_id += 1
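                # In CE mode, emit the KPI lines after 500 batches and cut the
                # epoch short so evaluation jobs finish quickly.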
                if args.ce and batch_id == 500:
                    print("kpis\tg_loss\t%0.3f" % g_loss_out[0])
                    print("kpis\tg_A_loss\t%0.3f" % g_A_loss.numpy()[0])
                    print("kpis\tg_B_loss\t%0.3f" % g_B_loss.numpy()[0])
                    print("kpis\td_A_loss\t%0.3f" % d_loss_A.numpy()[0])
                    print("kpis\td_B_loss\t%0.3f" % d_loss_B.numpy()[0])
                    break

            if args.save_checkpoints:
                fluid.save_dygraph(
                    cycle_gan.state_dict(),
                    args.output + "/checkpoints/{}".format(epoch))


if __name__ == "__main__":
    args = parser.parse_args()
    print_arguments(args)
    train(args)