from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
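
# CycleGAN-style training: two generators (A->B and B->A) and two
# discriminators are optimized alternately, with pools of previously
# generated fakes used for the discriminator updates.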
import os
import random
import sys
import paddle
import argparse
import functools
import time
import numpy as np
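# Note: scipy.misc.imsave was deprecated in SciPy 1.0 and removed in 1.2, so
# this script assumes an older SciPy (imageio.imwrite is the usual drop-in
# replacement).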
from scipy.misc import imsave
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import data_reader
from utility import add_arguments, print_arguments, ImagePool
from trainer import *

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   1,          "Minibatch size.")
add_arg('epoch',             int,   2,          "The number of epochs to be trained.")
add_arg('output',            str,   "./output_0", "The directory to which the model and test results are saved.")
add_arg('init_model',        str,   None,       "The directory of the init model to load.")
add_arg('save_checkpoints',  bool,  True,       "Whether to save checkpoints.")
add_arg('run_test',          bool,  True,       "Whether to run the test after each epoch.")
add_arg('use_gpu',           bool,  True,       "Whether to use GPU to train.")
add_arg('profile',           bool,  False,      "Whether to profile.")
add_arg('run_ce',            bool,  False,      "Whether to run in continuous evaluation (CE) mode.")
# yapf: enable
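# Example invocation (assuming add_arguments registers each option above as a
# corresponding --flag, e.g. --batch_size):
#   python train.py --batch_size 1 --epoch 2 --output ./output_0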


def train(args):

    max_images_num = data_reader.max_images_num()
    shuffle = True
    if args.run_ce:
        np.random.seed(10)
        fluid.default_startup_program().random_seed = 90
        max_images_num = 1
        shuffle = False
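    # data_shape is [-1] followed by the (C, H, W) returned by
    # data_reader.image_shape(); the leading -1 marks the batch dimension for
    # the fluid.layers.data placeholders below.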
    data_shape = [-1] + data_reader.image_shape()

    input_A = fluid.layers.data(
        name='input_A', shape=data_shape, dtype='float32')
    input_B = fluid.layers.data(
        name='input_B', shape=data_shape, dtype='float32')
    fake_pool_A = fluid.layers.data(
        name='fake_pool_A', shape=data_shape, dtype='float32')
    fake_pool_B = fluid.layers.data(
        name='fake_pool_B', shape=data_shape, dtype='float32')

    g_A_trainer = GATrainer(input_A, input_B)
    g_B_trainer = GBTrainer(input_A, input_B)
    d_A_trainer = DATrainer(input_A, fake_pool_A)
    d_B_trainer = DBTrainer(input_B, fake_pool_B)
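    # Each trainer above builds a separate Fluid program: generator updates
    # for the A->B and B->A directions and discriminator updates fed from the
    # fake-image pools below.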

    # prepare environment
    place = fluid.CPUPlace()
    if args.use_gpu:
        place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    A_pool = ImagePool()
    B_pool = ImagePool()
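    # ImagePool keeps a history of previously generated images; pool_image()
    # returns samples from this buffer so the discriminators are not trained
    # only on the most recent fakes.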

    A_reader = paddle.batch(
        data_reader.a_reader(shuffle=shuffle), args.batch_size)()
    B_reader = paddle.batch(
        data_reader.b_reader(shuffle=shuffle), args.batch_size)()
    if not args.run_ce:
        A_test_reader = data_reader.a_test_reader()
        B_test_reader = data_reader.b_test_reader()

    def test(epoch):
        out_path = args.output + "/test"
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        i = 0
        for data_A, data_B in zip(A_test_reader(), B_test_reader()):
            A_name = data_A[1]
            B_name = data_B[1]
            tensor_A = core.LoDTensor()
            tensor_B = core.LoDTensor()
            tensor_A.set(data_A[0], place)
            tensor_B.set(data_B[0], place)
            fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
                g_A_trainer.infer_program,
                fetch_list=[
                    g_A_trainer.fake_A, g_A_trainer.fake_B, g_A_trainer.cyc_A,
                    g_A_trainer.cyc_B
                ],
                feed={"input_A": tensor_A,
                      "input_B": tensor_B})
            fake_A_temp = np.squeeze(fake_A_temp[0]).transpose([1, 2, 0])
            fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
            cyc_A_temp = np.squeeze(cyc_A_temp[0]).transpose([1, 2, 0])
            cyc_B_temp = np.squeeze(cyc_B_temp[0]).transpose([1, 2, 0])
            input_A_temp = np.squeeze(data_A[0]).transpose([1, 2, 0])
            input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])
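            # Generator outputs are in [-1, 1]; rescale to [0, 255] uint8
            # before writing the images to disk.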

            imsave(out_path + "/fakeB_" + str(epoch) + "_" + A_name, (
                (fake_B_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/fakeA_" + str(epoch) + "_" + B_name, (
                (fake_A_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/cycA_" + str(epoch) + "_" + A_name, (
                (cyc_A_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/cycB_" + str(epoch) + "_" + B_name, (
                (cyc_B_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/inputA_" + str(epoch) + "_" + A_name, (
                (input_A_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/inputB_" + str(epoch) + "_" + B_name, (
                (input_B_temp + 1) * 127.5).astype(np.uint8))
            i += 1

    def checkpoints(epoch):
        out_path = args.output + "/checkpoints/" + str(epoch)
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        fluid.io.save_persistables(
            exe, out_path + "/g_a", main_program=g_A_trainer.program)
        fluid.io.save_persistables(
            exe, out_path + "/g_b", main_program=g_B_trainer.program)
        fluid.io.save_persistables(
            exe, out_path + "/d_a", main_program=d_A_trainer.program)
        fluid.io.save_persistables(
            exe, out_path + "/d_b", main_program=d_B_trainer.program)
        print("saved checkpoint to {}".format(out_path))
        sys.stdout.flush()

    def init_model():
        assert os.path.exists(
            args.init_model), "[%s] cannot be found." % args.init_model
        fluid.io.load_persistables(
            exe, args.init_model + "/g_a", main_program=g_A_trainer.program)
        fluid.io.load_persistables(
            exe, args.init_model + "/g_b", main_program=g_B_trainer.program)
        fluid.io.load_persistables(
            exe, args.init_model + "/d_a", main_program=d_A_trainer.program)
        fluid.io.load_persistables(
            exe, args.init_model + "/d_b", main_program=d_B_trainer.program)
        print("Load model from {}".format(args.init_model))

    if args.init_model:
        init_model()
    losses = [[], []]
    t_time = 0

    g_A_trainer_program = fluid.CompiledProgram(
        g_A_trainer.program).with_data_parallel(
            loss_name=g_A_trainer.g_loss_A.name)
    g_B_trainer_program = fluid.CompiledProgram(
        g_B_trainer.program).with_data_parallel(
            loss_name=g_B_trainer.g_loss_B.name)
    d_B_trainer_program = fluid.CompiledProgram(
        d_B_trainer.program).with_data_parallel(
            loss_name=d_B_trainer.d_loss_B.name)
    d_A_trainer_program = fluid.CompiledProgram(
        d_A_trainer.program).with_data_parallel(
            loss_name=d_A_trainer.d_loss_A.name)
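    # CompiledProgram(...).with_data_parallel() compiles each program for
    # (multi-device) data-parallel execution; on a single device it runs the
    # same graph through the parallel executor.
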
    for epoch in range(args.epoch):
        batch_id = 0
        for i in range(max_images_num):
            data_A = next(A_reader)
            data_B = next(B_reader)
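            # Wrap the numpy batches in LoDTensors on the target place so they
            # can be fed to the executors below.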
            tensor_A = core.LoDTensor()
            tensor_B = core.LoDTensor()
            tensor_A.set(data_A, place)
            tensor_B.set(data_B, place)
            s_time = time.time()
            # optimize the g_A network
            g_A_loss, fake_B_tmp = exe.run(
                g_A_trainer_program,
                fetch_list=[g_A_trainer.g_loss_A, g_A_trainer.fake_B],
                feed={"input_A": tensor_A,
                      "input_B": tensor_B})

            fake_pool_B = B_pool.pool_image(fake_B_tmp)

            # optimize the d_B network
            d_B_loss = exe.run(
                d_B_trainer_program,
                fetch_list=[d_B_trainer.d_loss_B],
                feed={"input_B": tensor_B,
                      "fake_pool_B": fake_pool_B})[0]

            # optimize the g_B network
            g_B_loss, fake_A_tmp = exe.run(
                g_B_trainer_program,
                fetch_list=[g_B_trainer.g_loss_B, g_B_trainer.fake_A],
                feed={"input_A": tensor_A,
                      "input_B": tensor_B})

            fake_pool_A = A_pool.pool_image(fake_A_tmp)

            # optimize the d_A network
            d_A_loss = exe.run(
                d_A_trainer_program,
                fetch_list=[d_A_trainer.d_loss_A],
                feed={"input_A": tensor_A,
                      "fake_pool_A": fake_pool_A})[0]
            batch_time = time.time() - s_time
            t_time += batch_time
            print(
                "epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; "
                "d_A_loss: {}; Batch_time_cost: {:.2f}".format(
                    epoch, batch_id, g_A_loss[0], d_B_loss[0], g_B_loss[0],
                    d_A_loss[0], batch_time))
            losses[0].append(g_A_loss[0])
            losses[1].append(d_A_loss[0])
            sys.stdout.flush()
            batch_id += 1

        if args.run_test and not args.run_ce:
            test(epoch)
        if args.save_checkpoints and not args.run_ce:
            checkpoints(epoch)
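    # In CE mode, print averaged KPI lines (generator/discriminator training
    # cost and per-epoch duration), presumably consumed by the continuous
    # evaluation pipeline.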
    if args.run_ce:
        print("kpis,g_train_cost,{}".format(np.mean(losses[0])))
        print("kpis,d_train_cost,{}".format(np.mean(losses[1])))
        print("kpis,duration,{}".format(t_time / args.epoch))


if __name__ == "__main__":
    args = parser.parse_args()
    print_arguments(args)
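    # Optionally wrap the whole training run in Paddle's CUDA or CPU profiler,
    # depending on the device flag.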
    if args.profile:
        if args.use_gpu:
            with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
                train(args)
        else:
            with profiler.profiler("CPU", sorted_key='total') as cpuprof:
                train(args)
    else:
        train(args)