# test_reader.py (3.7 KB) -- committed by Peng Li
import unittest
import os
import itertools
import math
import logging

# Extend the import search path with the directory one level above this file
# and its data/evaluation sub-directory, so the imports below can resolve.
import sys
_here = os.path.dirname(os.path.abspath(__file__))
topdir = os.path.join(_here, "..")
sys.path.extend([topdir, os.path.join(topdir, "data", "evaluation")])

import reader
import utils

# Attach a stream handler to the project logger so that messages emitted while
# the tests run are printed with level, timestamp and source location.
formatter = logging.Formatter(
    # %(msecs)03d keeps the millisecond field zero-padded; without padding
    # 5 ms would be rendered as ".5" and read as half a second.
    "[%(levelname)s %(asctime)s.%(msecs)03d %(filename)s:%(lineno)d] %(message)s",
    # %H (24-hour clock) instead of %I: a 12-hour clock without an AM/PM
    # marker makes morning and afternoon timestamps indistinguishable.
    datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
utils.logger.addHandler(ch)

P
Peng Li 已提交
22

P
Peng Li 已提交
23 24 25
class Vocab(object):
    """Provides the word dictionary that the reader settings are built from."""

    @property
    def data(self):
        """Return the result of loading ``data/embedding/wordvecs.vcb``.

        The path is resolved relative to ``topdir`` and the file is read
        through ``utils.load_dict`` on every access (nothing is cached here).
        """
        word_dict_path = os.path.join(topdir, "data", "embedding",
                                      "wordvecs.vcb")
        return utils.load_dict(word_dict_path)


class NegativeSampleRatioTest(unittest.TestCase):
    """Check that the reader honours the ``negative_sample_ratio`` setting."""

    def check_ratio(self, negative_sample_ratio):
        """Assert that the observed negative-sample ratio matches the setting.

        A sample is counted as negative when its label sequence contains no
        label id 0 (presumably the id of the "B" tag -- confirm against the
        ``reader`` module). The check runs once for each value of
        ``keep_first_b``.

        Args:
            negative_sample_ratio: value passed to ``reader.Settings``; the
                observed fraction of negative samples must be within 0.01
                of it.
        """
        filename = os.path.join(topdir, "test", "trn_data.gz")
        max_samples = 5000

        for keep_first_b in [True, False]:
            settings = reader.Settings(
                vocab=Vocab().data,
                is_training=True,
                label_schema="BIO2",
                negative_sample_ratio=negative_sample_ratio,
                hit_ans_negative_sample_ratio=0.25,
                keep_first_b=keep_first_b)
            data_stream = reader.create_reader(filename, settings)

            # islice caps the number of samples read and, unlike the
            # izip/xrange pairing, exists on both Python 2 and Python 3.
            seen, negative_num = 0, 0
            for d in itertools.islice(data_stream(), max_samples):
                seen += 1
                if 0 not in d[reader.LABELS]:
                    negative_num += 1

            # Divide by the number of samples actually read: if the stream
            # ends early, dividing by the requested cap would understate
            # the ratio.
            self.assertGreater(seen, 0, "reader yielded no samples")
            ratio = negative_num / float(seen)
            self.assertLessEqual(math.fabs(ratio - negative_sample_ratio), 0.01)

    def runTest(self):
        """Run the ratio check for all-negative, mixed and no-negative settings."""
        for ratio in [1., 0.25, 0.]:
            self.check_ratio(ratio)
P
Peng Li 已提交
56

P
Peng Li 已提交
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106

class KeepFirstBTest(unittest.TestCase):
    """Check the effect of ``keep_first_b`` on the number of 0 labels per sample."""

    def runTest(self):
        """Count label-0 occurrences per sample for every setting combination.

        For each ``keep_first_b`` / ``label_schema`` pair, up to 1000 samples
        are read and the test asserts that:

        * some samples contain no label 0 at all;
        * with ``keep_first_b`` enabled, every sample that has a label 0 has
          exactly one;
        * with it disabled, at least one sample has more than one.

        Label id 0 is presumably the "B" tag -- confirm against ``reader``.
        """
        filename = os.path.join(topdir, "test", "trn_data.gz")
        max_samples = 1000

        for keep_first_b in [True, False]:
            for label_schema in ["BIO", "BIO2"]:
                settings = reader.Settings(
                    vocab=Vocab().data,
                    is_training=True,
                    label_schema=label_schema,
                    negative_sample_ratio=0.2,
                    hit_ans_negative_sample_ratio=0.25,
                    keep_first_b=keep_first_b)
                data_stream = reader.create_reader(filename, settings)

                # islice works on both Python 2 and Python 3, unlike
                # izip/xrange.
                seen, at_least_one, one = 0, 0, 0
                for d in itertools.islice(data_stream(), max_samples):
                    seen += 1
                    b_num = d[reader.LABELS].count(0)
                    if b_num >= 1:
                        at_least_one += 1
                    if b_num == 1:
                        one += 1

                # Compare against the samples actually read; comparing with
                # the requested cap would pass trivially on a short stream.
                self.assertLess(at_least_one, seen)
                if keep_first_b:
                    self.assertEqual(one, at_least_one)
                else:
                    self.assertLess(one, at_least_one)


class DictTest(unittest.TestCase):
    """Check that the reader produces a varied set of ids."""

    def runTest(self):
        """Assert that the first 1000 samples use more than 50 distinct ids.

        Distinct values are collected separately from the ``reader.Q_IDS``
        and ``reader.E_IDS`` fields (presumably question and evidence token
        ids -- confirm against ``reader``). A count this low would suggest
        the dictionary lookup is collapsing most tokens onto a few ids.
        """
        settings = reader.Settings(
            vocab=Vocab().data,
            is_training=True,
            label_schema="BIO2",
            negative_sample_ratio=0.2,
            hit_ans_negative_sample_ratio=0.25,
            keep_first_b=True)

        filename = os.path.join(topdir, "test", "trn_data.gz")
        data_stream = reader.create_reader(filename, settings)

        q_uniq_ids, e_uniq_ids = set(), set()
        # islice limits the read to 1000 samples and is available on both
        # Python 2 and Python 3, unlike izip/xrange.
        for d in itertools.islice(data_stream(), 1000):
            q_uniq_ids.update(d[reader.Q_IDS])
            e_uniq_ids.update(d[reader.E_IDS])

        self.assertGreater(len(q_uniq_ids), 50)
        self.assertGreater(len(e_uniq_ids), 50)
P
Peng Li 已提交
107

P
Peng Li 已提交
108 109 110

# When executed as a script, hand control to unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()