train_utils.py 5.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
import os
import time
import numpy as np
import paddle
import paddle.fluid as fluid
import logging
import shutil

logger = logging.getLogger(__name__)


def test_without_pyreader(test_exe,
                          test_reader,
                          test_feeder,
                          test_fetch_list,
                          test_metrics,
                          log_interval=0):
    test_metrics.reset()
    for test_iter, data in enumerate(test_reader()):
        test_outs = test_exe.run(test_fetch_list, feed=test_feeder.feed(data))
        loss = np.array(test_outs[0])
        pred = np.array(test_outs[1])
        label = np.array(test_outs[-1])
        test_metrics.accumulate(loss, pred, label)
        if log_interval > 0 and test_iter % log_interval == 0:
            test_metrics.calculate_and_log_out(loss, pred, label, \
S
SunGaofeng 已提交
27
                  info = '[TEST] test_iter {} '.format(test_iter))
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
    test_metrics.finalize_and_log_out("[TEST] Finish")


def test_with_pyreader(test_exe,
                       test_pyreader,
                       test_fetch_list,
                       test_metrics,
                       log_interval=0):
    if not test_pyreader:
        logger.error("[TEST] get pyreader failed.")
    test_pyreader.start()
    test_metrics.reset()
    test_iter = 0
    try:
        while True:
            test_outs = test_exe.run(fetch_list=test_fetch_list)
            loss = np.array(test_outs[0])
            pred = np.array(test_outs[1])
            label = np.array(test_outs[-1])
            test_metrics.accumulate(loss, pred, label)
            if log_interval > 0 and test_iter % log_interval == 0:
                test_metrics.calculate_and_log_out(loss, pred, label, \
S
SunGaofeng 已提交
50
                  info = '[TEST] test_iter {} '.format(test_iter))
51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
            test_iter += 1
    except fluid.core.EOFException:
        test_metrics.finalize_and_log_out("[TEST] Finish")
    finally:
        test_pyreader.reset()


def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder, \
                           train_fetch_list, train_metrics, epochs = 10, \
                           log_interval = 0, valid_interval = 0, save_dir = './', \
                           save_model_name = 'model', test_exe = None, test_reader = None, \
                           test_feeder = None, test_fetch_list = None, test_metrics = None):
    for epoch in range(epochs):
        epoch_periods = []
        for train_iter, data in enumerate(train_reader()):
            cur_time = time.time()
            train_outs = train_exe.run(train_fetch_list,
                                       feed=train_feeder.feed(data))
            period = time.time() - cur_time
            epoch_periods.append(period)
            loss = np.array(train_outs[0])
            pred = np.array(train_outs[1])
            label = np.array(train_outs[-1])
            if log_interval > 0 and (train_iter % log_interval == 0):
                # eval here
                train_metrics.calculate_and_log_out(loss, pred, label, \
                       info = '[TRAIN] Epoch {}, iter {} '.format(epoch, train_iter))
            train_iter += 1
        logger.info('[TRAIN] Epoch {} training finished, average time: {}'.
                    format(epoch, np.mean(epoch_periods)))
        save_model(exe, train_prog, save_dir, save_model_name,
                   "_epoch{}".format(epoch))
D
dengkaipeng 已提交
83
        if test_exe and valid_interval > 0 and (epoch + 1) % valid_interval == 0:
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
            test_without_pyreader(test_exe, test_reader, test_feeder,
                                  test_fetch_list, test_metrics, log_interval)



def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
                        train_fetch_list, train_metrics, epochs = 10, \
                        log_interval = 0, valid_interval = 0, \
                        save_dir = './', save_model_name = 'model', \
                        test_exe = None, test_pyreader = None, \
                        test_fetch_list = None, test_metrics = None):
    if not train_pyreader:
        logger.error("[TRAIN] get pyreader failed.")
    for epoch in range(epochs):
        train_pyreader.start()
        train_metrics.reset()
        try:
            train_iter = 0
            epoch_periods = []
            while True:
                cur_time = time.time()
                train_outs = train_exe.run(fetch_list=train_fetch_list)
                period = time.time() - cur_time
                epoch_periods.append(period)
                loss = np.array(train_outs[0])
                pred = np.array(train_outs[1])
                label = np.array(train_outs[-1])
                if log_interval > 0 and (train_iter % log_interval == 0):
                    # eval here
                    train_metrics.calculate_and_log_out(loss, pred, label, \
                                info = '[TRAIN] Epoch {}, iter {} '.format(epoch, train_iter))
                train_iter += 1
        except fluid.core.EOFException:
            # eval here
            logger.info('[TRAIN] Epoch {} training finished, average time: {}'.
                        format(epoch, np.mean(epoch_periods)))
            save_model(exe, train_prog, save_dir, save_model_name,
                       "_epoch{}".format(epoch))
D
dengkaipeng 已提交
122
            if test_exe and valid_interval > 0 and (epoch + 1) % valid_interval == 0:
123 124 125 126 127 128 129 130 131 132 133 134
                test_with_pyreader(test_exe, test_pyreader, test_fetch_list,
                                   test_metrics, log_interval)
        finally:
            epoch_period = []
            train_pyreader.reset()


def save_model(exe, program, save_dir, model_name, postfix=None):
    model_path = os.path.join(save_dir, model_name + postfix)
    if os.path.isdir(model_path):
        shutil.rmtree(model_path)
    fluid.io.save_persistables(exe, model_path, main_program=program)