# train.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import data_reader
import os
import random
import sys
import paddle
import argparse
import functools
import paddle.fluid as fluid
import numpy as np
from paddle.fluid import core
from trainer import *
from scipy.misc import imsave
import paddle.fluid.profiler as profiler
from utility import add_arguments, print_arguments, ImagePool

parser = argparse.ArgumentParser(description=__doc__)
# Each add_arg call registers one option on ``parser`` through
# utility.add_arguments, passing (name, type, default, help text).
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   1,            "Minibatch size.")
add_arg('epoch',             int,   2,            "The number of epochs to be trained.")
add_arg('output',            str,   "./output_1", "The directory the model and the test result to be saved to.")
add_arg('init_model',        str,   None,         "The init model file or directory.")
add_arg('save_checkpoints',  bool,  True,         "Whether to save checkpoints.")
add_arg('run_test',          bool,  True,         "Whether to run test.")
add_arg('use_gpu',           bool,  True,         "Whether to use GPU to train.")
add_arg('profile',           bool,  False,        "Whether to profile.")
# yapf: enable


def train(args):
    """Build the four generator/discriminator trainers and run training.

    Each iteration runs the programs in the order g_A -> d_B -> g_B -> d_A.
    A generated batch is passed through an ``ImagePool`` before it is fed to
    the matching discriminator. If ``args.init_model`` is set, persistables
    are loaded from it before training. After every epoch, test images are
    written when ``args.run_test`` is true. Persistables are checkpointed when
    ``args.save_checkpoints`` is true.

    Args:
        args: parsed command-line namespace from ``parser``. Fields read:
            batch_size, epoch, output, init_model, save_checkpoints,
            run_test, use_gpu.
    """
    data_shape = [-1] + data_reader.image_shape()
    max_images_num = data_reader.max_images_num()

    input_A = fluid.layers.data(
        name='input_A', shape=data_shape, dtype='float32')
    input_B = fluid.layers.data(
        name='input_B', shape=data_shape, dtype='float32')
    fake_pool_A = fluid.layers.data(
        name='fake_pool_A', shape=data_shape, dtype='float32')
    fake_pool_B = fluid.layers.data(
        name='fake_pool_B', shape=data_shape, dtype='float32')

    g_A_trainer = GATrainer(input_A, input_B)
    g_B_trainer = GBTrainer(input_A, input_B)
    d_A_trainer = DATrainer(input_A, fake_pool_A)
    d_B_trainer = DBTrainer(input_B, fake_pool_B)

    # Sub-directory used for each trainer's persistables, in save/load order.
    named_trainers = [
        ("g_a", g_A_trainer),
        ("g_b", g_B_trainer),
        ("d_a", d_A_trainer),
        ("d_b", d_B_trainer),
    ]

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    A_pool = ImagePool()
    B_pool = ImagePool()

    # NOTE(review): these generators are created once, and next() is called
    # on them max_images_num times in every epoch. That only works across
    # several epochs if data_reader's readers yield indefinitely; otherwise
    # StopIteration is raised in a later epoch. Confirm in data_reader.
    A_reader = paddle.batch(data_reader.a_reader(), args.batch_size)()
    B_reader = paddle.batch(data_reader.b_reader(), args.batch_size)()

    A_test_reader = data_reader.a_test_reader()
    B_test_reader = data_reader.b_test_reader()

    def to_tensor(data):
        """Return a new LoDTensor holding ``data`` on the training ``place``."""
        tensor = core.LoDTensor()
        tensor.set(data, place)
        return tensor

    def save_image(path, chw_image):
        """Save ``chw_image`` to ``path`` as a uint8 channels-last image.

        Singleton dimensions are squeezed away, axes are moved from
        (C, H, W) to (H, W, C), and values are mapped from [-1, 1] to
        [0, 255] via (x + 1) * 127.5.
        """
        hwc_image = np.squeeze(chw_image).transpose([1, 2, 0])
        imsave(path, ((hwc_image + 1) * 127.5).astype(np.uint8))

    def test(epoch):
        """Run g_A's inference program over the paired test sets.

        For every (A, B) test pair, writes the inputs, the translated
        ("fake") images and the reconstructed ("cyc") images into
        ``<output>/test``. File names include the epoch number and the
        source image name.
        """
        out_path = os.path.join(args.output, "test")
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        for data_A, data_B in zip(A_test_reader(), B_test_reader()):
            # Test samples are indexed as (image, file name).
            A_name = data_A[1]
            B_name = data_B[1]
            fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
                g_A_trainer.infer_program,
                fetch_list=[
                    g_A_trainer.fake_A, g_A_trainer.fake_B, g_A_trainer.cyc_A,
                    g_A_trainer.cyc_B
                ],
                feed={"input_A": to_tensor(data_A[0]),
                      "input_B": to_tensor(data_B[0])})
            # (file prefix, image, source name), in write order. Only index 0
            # of each fetched output is saved.
            outputs = [
                ("fakeB", fake_B_temp[0], A_name),
                ("fakeA", fake_A_temp[0], B_name),
                ("cycA", cyc_A_temp[0], A_name),
                ("cycB", cyc_B_temp[0], B_name),
                ("inputA", data_A[0], A_name),
                ("inputB", data_B[0], B_name),
            ]
            for prefix, image, name in outputs:
                file_name = prefix + "_" + str(epoch) + "_" + name
                save_image(os.path.join(out_path, file_name), image)

    def checkpoints(epoch):
        """Save every trainer's persistables under <output>/checkpoints/<epoch>."""
        out_path = os.path.join(args.output, "checkpoints", str(epoch))
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        for sub_dir, trainer in named_trainers:
            fluid.io.save_persistables(
                exe, os.path.join(out_path, sub_dir),
                main_program=trainer.program)
        print("saved checkpoint to {}".format(out_path))
        sys.stdout.flush()

    def init_model():
        """Load every trainer's persistables from the ``args.init_model`` directory."""
        # The message must reference args.init_model; the former
        # ``args.init_mode`` raised AttributeError instead of the assertion.
        assert os.path.exists(
            args.init_model), "[%s] cann't be found." % args.init_model
        for sub_dir, trainer in named_trainers:
            fluid.io.load_persistables(
                exe, os.path.join(args.init_model, sub_dir),
                main_program=trainer.program)
        print("Load model from {}".format(args.init_model))

    if args.init_model:
        init_model()

    for epoch in range(args.epoch):
        # NOTE(review): each iteration pulls one batch from each reader
        # (presumably args.batch_size samples via paddle.batch). With
        # batch_size > 1, an epoch therefore covers more than max_images_num
        # images; confirm this is intended.
        for batch_id in range(max_images_num):
            tensor_A = to_tensor(next(A_reader))
            tensor_B = to_tensor(next(B_reader))

            # optimize the g_A network
            g_A_loss, fake_B_tmp = exe.run(
                g_A_trainer.program,
                fetch_list=[g_A_trainer.g_loss_A, g_A_trainer.fake_B],
                feed={"input_A": tensor_A,
                      "input_B": tensor_B})

            # d_B is fed whatever B_pool returns for the freshly generated
            # batch, not necessarily that batch itself (see utility.ImagePool).
            pooled_fake_B = B_pool.pool_image(fake_B_tmp)

            # optimize the d_B network
            d_B_loss = exe.run(
                d_B_trainer.program,
                fetch_list=[d_B_trainer.d_loss_B],
                feed={"input_B": tensor_B,
                      "fake_pool_B": pooled_fake_B})

            # optimize the g_B network
            g_B_loss, fake_A_tmp = exe.run(
                g_B_trainer.program,
                fetch_list=[g_B_trainer.g_loss_B, g_B_trainer.fake_A],
                feed={"input_A": tensor_A,
                      "input_B": tensor_B})

            pooled_fake_A = A_pool.pool_image(fake_A_tmp)

            # optimize the d_A network
            d_A_loss = exe.run(
                d_A_trainer.program,
                fetch_list=[d_A_trainer.d_loss_A],
                feed={"input_A": tensor_A,
                      "fake_pool_A": pooled_fake_A})

            print("epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {};".format(
                epoch, batch_id, g_A_loss[0], d_B_loss[0], g_B_loss[0],
                d_A_loss[0]))
            sys.stdout.flush()

        if args.run_test:
            test(epoch)
        if args.save_checkpoints:
            checkpoints(epoch)


if __name__ == "__main__":
    args = parser.parse_args()
    print_arguments(args)
    # With --profile, the whole training run executes inside Paddle's CUDA
    # profiler (GPU) or CPU profiler. Otherwise train() runs directly.
    if not args.profile:
        train(args)
    elif args.use_gpu:
        with profiler.cuda_profiler("cuda_profiler.txt", 'csv'):
            train(args)
    else:
        with profiler.profiler("CPU", sorted_key='total'):
            train(args)