# eval.py — evaluate a trained MemTransformerLM language model.
import os
import time
import yaml
import logging
import argparse
import numpy as np
from pprint import pprint
from attrdict import AttrDict

import paddle

from reader import get_lm_vocab, get_lm_data_loader
from mem_transformer import MemTransformerLM

# Module-level logging: lines look like "2021-01-01 12:00:00-INFO: message".
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse command-line options for the evaluation script.

    Returns:
        argparse.Namespace: holds a single ``config`` attribute with the
        path to the YAML configuration file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="./configs/enwik8.yaml",
        help="Path of the config file. ")
    return parser.parse_args()


def do_eval(args):
    """Evaluate a trained MemTransformerLM on the valid and/or test split.

    Builds the model from the config, loads parameters from
    ``args.init_from_params``, then reports the average per-token loss
    (plus bpc for character-level datasets, ppl otherwise).

    Args:
        args: configuration object (AttrDict loaded from the YAML config)
            carrying model hyper-parameters, data paths, ``mode``
            ('all' / 'valid' / 'test') and evaluation limits.
    """
    assert args.ext_len >= 0, 'Extended context length must be no less than 0'

    def _evaluate(loader):
        # Average per-token loss over the split; Transformer-XL memories
        # (eval_mems) are threaded through successive batches.
        total_len, total_loss = 0, 0.

        eval_mems = tuple()
        for i, (src, target, seq_len) in enumerate(loader):
            # max_eval_steps <= 0 means "evaluate the whole split".
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            ret = mem_transformer(src, target, *eval_mems)
            loss, eval_mems = ret[0], ret[1:]
            seq_len = seq_len.numpy()
            # Weight each batch's mean loss by its token count.
            eval_cur_loss = seq_len * loss.numpy()
            total_loss += eval_cur_loss
            total_len += seq_len
        return total_loss / total_len

    def _logger(loss):
        # enwik8/text8 are character-level benchmarks: report
        # bits-per-character; word-level datasets report perplexity.
        if args.dataset in ['enwik8', 'text8']:
            logger_info = "loss: %f, bpc: %f" % \
                          (loss, loss / np.log(2))
        else:
            logger_info = "loss: %f, ppl: %.2f" % \
                          (loss, np.exp(loss))
        return logger_info

    vocab = get_lm_vocab(args)
    eval_loader = get_lm_data_loader(args, vocab, "valid")
    test_loader = get_lm_data_loader(args, vocab, "test")

    # Adaptive softmax cutoffs / projection tying are dataset-specific.
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(
        args.ntokens,
        args.n_layer,
        args.n_head,
        args.d_model,
        args.d_head,
        args.d_inner_hid,
        args.dropout,
        args.attn_dropout,
        tie_weight=args.tie_weight,
        d_embed=args.d_model,
        div_val=args.div_val,
        tie_projs=tie_projs,
        normalize_before=args.normalize_before,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax)

    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")

    model_dict = paddle.load(
        os.path.join(args.init_from_params, "mem_transformer.pdparams"))
    mem_transformer.load_dict(model_dict)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".
        format(args.batch_size, args.tgt_len, args.ext_len, args.mem_len,
               args.clamp_len))

    # Evaluation may use different sequence/memory lengths than training.
    mem_transformer.reset_length(args.tgt_len, args.ext_len, args.mem_len)

    test_loss = None
    valid_loss = None
    if args.mode == 'all':
        test_loss = _evaluate(test_loader)
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'valid':
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'test':
        test_loss = _evaluate(test_loader)

    logger_info = ''
    if valid_loss is not None:
        logger_info = logger_info + _logger(valid_loss) + " | "
    if test_loss is not None:
        logger_info = logger_info + _logger(test_loss) + " | "
    logger.info(logger_info)


if __name__ == "__main__":
    # Load the YAML config named on the command line, echo it, then evaluate.
    ARGS = parse_args()
    with open(ARGS.config, 'rt') as f:
        args = AttrDict(yaml.safe_load(f))
    pprint(args)

    do_eval(args)