eval.py
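# Evaluation script for the Transformer-XL language model (MemTransformerLM).
# Loads a trained checkpoint and reports loss plus bits-per-character
# (enwik8/text8) or perplexity (wt103/lm1b) on the validation and/or test split.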
import os
import time
import yaml
import logging
import argparse
import numpy as np
from pprint import pprint
from attrdict import AttrDict

import paddle

from reader import get_lm_vocab, get_lm_data_loader
from mem_transformer import MemTransformerLM

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="./configs/enwik8.yaml",
        type=str,
        help="Path of the config file. ")
    args = parser.parse_args()
    return args


def do_eval(args):
    assert args.ext_len >= 0, 'Extended context length must be no less than 0'

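    # Runs the model over the loader while threading the recurrence memory
    # (eval_mems) through successive segments, and returns the
    # length-weighted average token loss.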
    def _evaluate(loader):
        total_len, total_loss = 0, 0.

        eval_mems = tuple()
        for i, (src, target, seq_len) in enumerate(loader):
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            ret = mem_transformer(src, target, *eval_mems)
            loss, eval_mems = ret[0], ret[1:]
            # Convert to Python scalars so the running totals stay scalar.
            seq_len = seq_len.numpy().item()
            eval_cur_loss = seq_len * loss.numpy().item()
            total_loss += eval_cur_loss
            total_len += seq_len
        return total_loss / total_len

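    # Formats the metric: character-level datasets report bits-per-character
    # (loss / ln 2), word-level datasets report perplexity (exp(loss)).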
    def _logger(loss):
        if args.dataset in ['enwik8', 'text8']:
            logger_info = "loss: %f, bpc: %f" % \
                          (loss, loss / np.log(2))
        else:
            logger_info = "loss: %f, ppl: %.2f" % \
                          (loss, np.exp(loss))
        return logger_info

    if not args.use_gpu:
        paddle.set_device("cpu")

    vocab = get_lm_vocab(args)
    eval_loader = get_lm_data_loader(args, vocab, "valid")
    test_loader = get_lm_data_loader(args, vocab, "test")

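    # Adaptive softmax: the cutoffs split the vocabulary into frequency
    # buckets; tie_projs marks, per bucket, whether the output projection is
    # tied to the corresponding embedding projection.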
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(
        args.ntokens,
        args.n_layer,
        args.n_head,
        args.d_model,
        args.d_head,
        args.d_inner_hid,
        args.dropout,
        args.attn_dropout,
        tie_weight=args.tie_weight,
        d_embed=args.d_model,
        div_val=args.div_val,
        tie_projs=tie_projs,
        normalize_before=args.normalize_before,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax)

    assert args.init_from_params, (
        "Please set init_from_params to load the trained model parameters.")

    model_dict = paddle.load(
        os.path.join(args.init_from_params, "mem_transformer.pdparams"))
    mem_transformer.load_dict(model_dict)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".
        format(args.eval_batch_size, args.tgt_len, args.ext_len, args.mem_len,
               args.clamp_len))

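    # Switch the model to the evaluation context lengths from the config,
    # which may differ from the lengths used during training.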
    mem_transformer.reset_length(args.tgt_len, args.ext_len, args.mem_len)

    test_loss = None
    valid_loss = None
    if args.mode == 'all':
        test_loss = _evaluate(test_loader)
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'valid':
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'test':
        test_loss = _evaluate(test_loader)

    logger_info = ''
    if valid_loss is not None:
        logger_info = logger_info + _logger(valid_loss) + " | "
    if test_loss is not None:
        logger_info = logger_info + _logger(test_loss) + " | "
    logger.info(logger_info)


if __name__ == "__main__":
    ARGS = parse_args()
    yaml_file = ARGS.config
    with open(yaml_file, 'rt') as f:
        args = AttrDict(yaml.safe_load(f))
        pprint(args)

    do_eval(args)
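
# Typical invocation (assuming the default config layout in this repo):
#     python eval.py --config ./configs/enwik8.yaml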