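"""Dygraph (imperative-mode) CycleGAN training script for PaddlePaddle fluid."""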
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
import paddle
import argparse
import functools
import time
import numpy as np
from scipy.misc import imsave
import paddle.fluid as fluid
import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *
from paddle.fluid.dygraph.base import to_variable
import six
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   1,          "Minibatch size.")
add_arg('epoch',             int,   200,        "The number of epochs to train for.")
add_arg('output',            str,   "./output_0", "The directory where the model and test results are saved.")
add_arg('init_model',        str,   None,       "The init model file or directory.")
add_arg('save_checkpoints',  bool,  True,       "Whether to save checkpoints.")
# yapf: enable
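# Example invocation (hypothetical paths; the dataset itself is located by
# data_reader):
#   python train.py --batch_size 1 --epoch 200 --output ./output_0 \
#       --save_checkpoints True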

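# CycleGAN loss weights for the cycle-consistency and identity terms (not
# referenced in this file) and the number of iterations per epoch used by the
# learning-rate schedule in optimizer_setting().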
lambda_A = 10.0
lambda_B = 10.0
lambda_identity = 0.5
step_per_epoch = 2974

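# Adam with a piecewise-constant learning-rate schedule: piecewise_decay uses
# values[i] until the global step reaches boundaries[i] (five boundaries need
# the six values), so the base lr of 2e-4 holds for the first 100 epochs and is
# then scaled to 0.8x, 0.6x, 0.4x, 0.2x and 0.1x at epochs 100/120/140/160/180.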
def optimizer_setting():
    lr = 0.0002
    optimizer = fluid.optimizer.Adam(
        learning_rate=fluid.layers.piecewise_decay(
            boundaries=[
                100 * step_per_epoch, 120 * step_per_epoch,
                140 * step_per_epoch, 160 * step_per_epoch,
                180 * step_per_epoch
            ],
            values=[
                lr, lr * 0.8, lr * 0.6, lr * 0.4, lr * 0.2, lr * 0.1
            ]),
        beta1=0.5)
    return optimizer


def train(args):
    with fluid.dygraph.guard():
        max_images_num = data_reader.max_images_num()
        shuffle = True
        data_shape = [-1] + data_reader.image_shape()
        print(data_shape)

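        # Image pools keep a history of previously generated images so the
        # discriminators see a mix of current and past fakes (the buffer trick
        # from the CycleGAN paper), which helps stabilise their training.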
        A_pool = ImagePool()
        B_pool = ImagePool()

        A_reader = paddle.batch(
            data_reader.a_reader(shuffle=shuffle), args.batch_size)()
        B_reader = paddle.batch(
            data_reader.b_reader(shuffle=shuffle), args.batch_size)()
        A_test_reader = data_reader.a_test_reader()
        B_test_reader = data_reader.b_test_reader()

        cycle_gan = Cycle_Gan("cycle_gan", istrain=True)

        losses = [[], []]
        t_time = 0
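        # Three optimizer instances are created so that each minimize() call
        # below updates only its own parameter subset: the generators, D_A,
        # or D_B (selected by parameter-name prefix).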
        optimizer1 = optimizer_setting()
        optimizer2 = optimizer_setting()
        optimizer3 = optimizer_setting()

        for epoch in range(args.epoch):
            batch_id = 0
            for i in range(max_images_num):

                data_A = next(A_reader)
                data_B = next(B_reader)

                s_time = time.time()
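                # Only the first sample of each batch is used and reshaped to a
                # single NCHW image of shape (1, 3, 256, 256), so the effective
                # batch size is 1 regardless of --batch_size.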
                data_A = np.array([data_A[0].reshape(3,256,256)]).astype("float32")
                data_B = np.array([data_B[0].reshape(3,256,256)]).astype("float32")
                data_A = to_variable(data_A)
                data_B = to_variable(data_B)

                # optimize the generator networks (G_A and G_B are updated
                # together through the combined g_loss)
                fake_A, fake_B, cyc_A, cyc_B, g_A_loss, g_B_loss, idt_loss_A, \
                    idt_loss_B, cyc_A_loss, cyc_B_loss, g_loss = cycle_gan(
                        data_A, data_B, True, False, False)

                g_loss_out = g_loss.numpy()

                g_loss.backward()
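                # Collect only generator parameters by name prefix (the prefix
                # matches both resnet generators) so this update leaves the
                # discriminators untouched.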
                vars_G = []
                for param in cycle_gan.parameters():
                    if param.name[:52]=="cycle_gan/Cycle_Gan_0/build_generator_resnet_9blocks":    
                        vars_G.append(param)

                optimizer1.minimize(g_loss,parameter_list=vars_G)                
                cycle_gan.clear_gradients()


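                # Pool the generated fakes and round-trip them through numpy so
                # the discriminator updates below do not backpropagate into the
                # generators.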
                fake_pool_B = B_pool.pool_image(fake_B).numpy()
                fake_pool_B = np.array([fake_pool_B[0].reshape(3,256,256)]).astype("float32")
                fake_pool_B = to_variable(fake_pool_B)

                fake_pool_A = A_pool.pool_image(fake_A).numpy()
                fake_pool_A = np.array([fake_pool_A[0].reshape(3,256,256)]).astype("float32")
                fake_pool_A = to_variable(fake_pool_A)

                # optimize the d_A network
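                # Least-squares GAN loss: the discriminator is trained to score
                # real images as 1 and pooled fakes as 0.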
                rec_B, fake_pool_rec_B = cycle_gan(
                    data_B, fake_pool_B, False, True, False)
                d_loss_A = (fluid.layers.square(fake_pool_rec_B) +
                            fluid.layers.square(rec_B - 1)) / 2.0
                d_loss_A = fluid.layers.reduce_mean(d_loss_A)

                d_loss_A.backward()
                vars_da = []
                for param in cycle_gan.parameters():
                    if param.name[:47]=="cycle_gan/Cycle_Gan_0/build_gen_discriminator_0":
                        vars_da.append(param)
                optimizer2.minimize(d_loss_A,parameter_list=vars_da)
                cycle_gan.clear_gradients()

                # optimize the d_B network

                rec_A, fake_pool_rec_A = cycle_gan(
                    data_A, fake_pool_A, False, False, True)
                d_loss_B = (fluid.layers.square(fake_pool_rec_A) +
                            fluid.layers.square(rec_A - 1)) / 2.0
                d_loss_B = fluid.layers.reduce_mean(d_loss_B)

                d_loss_B.backward()
                vars_db = []
                for param in cycle_gan.parameters():
                    if param.name[:47]=="cycle_gan/Cycle_Gan_0/build_gen_discriminator_1":
                        vars_db.append(param)
                optimizer3.minimize(d_loss_B,parameter_list=vars_db)

                cycle_gan.clear_gradients()

                batch_time = time.time() - s_time
                t_time += batch_time
                print(
                    "epoch{}; batch{}; g_loss: {}; d_A_loss: {}; d_B_loss: {};\n"
                    " g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {};"
                    " g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {};"
                    " Batch_time_cost: {:.2f}".format(
                        epoch, batch_id, g_loss_out[0], d_loss_A.numpy()[0],
                        d_loss_B.numpy()[0], g_A_loss.numpy()[0],
                        cyc_A_loss.numpy()[0], idt_loss_A.numpy()[0],
                        g_B_loss.numpy()[0], cyc_B_loss.numpy()[0],
                        idt_loss_B.numpy()[0], batch_time))
                with open('logging_train.txt', 'a') as log_file:
                    now = time.strftime("%c")
                    log_file.write(
                        "time: {}; epoch{}; batch{}; d_A_loss: {}; "
                        "g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {}; "
                        "d_B_loss: {}; g_B_loss: {}; g_B_cyc_loss: {}; "
                        "g_B_idt_loss: {}; Batch_time_cost: {:.2f}\n".format(
                            now, epoch, batch_id, d_loss_A.numpy()[0],
                            g_A_loss.numpy()[0], cyc_A_loss.numpy()[0],
                            idt_loss_A.numpy()[0], d_loss_B.numpy()[0],
                            g_B_loss.numpy()[0], cyc_B_loss.numpy()[0],
                            idt_loss_B.numpy()[0], batch_time))
                losses[0].append(g_A_loss.numpy()[0])
                losses[1].append(d_loss_A.numpy()[0])
                sys.stdout.flush()
                batch_id += 1

            if args.save_checkpoints:
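                # Persist the current parameters to <output>/checkpoints/<epoch>.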
                fluid.dygraph.save_persistables(
                    cycle_gan.state_dict(),
                    args.output + "/checkpoints/{}".format(epoch))

if __name__ == "__main__":
    args = parser.parse_args()
    print_arguments(args)
    train(args)