# test_nce.py (3.3 KB)
# Committed by wanghaoshuang.
import unittest
import numpy as np
from op_test import OpTest


def nce(input, weight, bias, sample_weight, labels, num_classes,
        num_sample_class):
    """NumPy reference for the forward pass of the ``nce`` operator.

    For every example ``i`` the samples are its true labels ``labels[i]``
    (in order) followed by the fixed negative classes
    ``0 .. num_sample_class - 1``.

    Args:
        input: 2-D array with one row per example; row ``i`` is dotted with
            ``weight[class_id]``.
        weight: per-class weight rows, indexed by class id.
        bias: per-class bias indexed by class id, or None to omit the bias.
        sample_weight: per-example multiplier for the cost, or None for 1.
        labels: 2-D array ``(batch_size, num_true_class)`` of true class ids.
        num_classes: total number of classes.
        num_sample_class: number of negative classes used per example.

    Returns:
        Tuple ``(cost, sample_logits, sample_labels)``:
        ``cost`` is a float32 array of shape ``(batch_size,)``;
        ``sample_logits`` holds the sigmoid of each sample's score and
        ``sample_labels`` each sample's class id, both with shape
        ``(batch_size, num_true_class + num_sample_class)``.
    """
    batch_size = input.shape[0]
    num_true_class = labels.shape[1]
    samples_per_row = num_true_class + num_sample_class

    # Each sample is (example row, class id, is_true_label, cost weight).
    samples = []
    for i in range(batch_size):
        w = 1 if sample_weight is None else sample_weight[i]
        samples.extend((i, label, True, w) for label in labels[i])
        samples.extend((i, num, False, w) for num in range(num_sample_class))
    sample_labels = [s[1] for s in samples]

    # forward: score = bias[class] (when given) + <input[row], weight[class]>
    sample_out = np.zeros(len(samples), dtype=np.float32)
    for k, (row, cls, _, _) in enumerate(samples):
        if bias is not None:
            sample_out[k] = bias[cls]
        sample_out[k] += np.dot(input[row], weight[cls])

    # forward activation (sigmoid)
    sample_out = 1.0 / (1.0 + np.exp(-sample_out))

    # forward cost. b = num_sample_class / num_classes (presumably k * q for a
    # uniform noise distribution q = 1 / num_classes). True samples cost
    # -log(o / (o + b)); negative samples cost -log(b / (o + b)).
    out = np.zeros(batch_size, dtype=np.float32)
    b = 1.0 / num_classes * num_sample_class
    for o, (row, _, is_true, w) in zip(sample_out, samples):
        cost = -np.log(o / (o + b)) if is_true else -np.log(b / (o + b))
        out[row] += cost * w

    shape = (batch_size, samples_per_row)
    return (out, sample_out.reshape(shape),
            np.array(sample_labels).reshape(shape))


class TestNCE(OpTest):
    """Compares the ``nce`` operator with the NumPy reference ``nce``."""

    def generate_data(self, dim, batch_size, num_classes, num_true_class,
                      num_sampled_classes):
        """Create random operator inputs and attributes.

        Sets ``self.inputs`` (Input, Label, Weight, Bias, SampleWeight) and
        ``self.attrs`` (num_classes, num_sampled_classes, sampled_labels).
        """
        input_data = np.random.randn(batch_size, dim).astype(np.float32)
        weight = np.random.randn(num_classes, dim).astype(np.float32)
        bias = np.random.randn(num_classes).astype(np.float32)
        sample_weight = np.random.randn(batch_size).astype(np.float32)
        # Pin the label dtype: np.random.randint's default integer type is
        # platform dependent (e.g. int32 on Windows with NumPy < 2).
        labels = np.random.randint(
            0, num_classes, (batch_size, num_true_class)).astype(np.int64)
        self.attrs = {
            'num_classes': num_classes,
            'num_sampled_classes': num_sampled_classes,
            # Must be a real list of ints: on Python 3 ``range`` is a lazy
            # object. These are the same negatives (0 .. n-1) the reference
            # ``nce`` uses.
            'sampled_labels': list(range(num_sampled_classes))
        }
        self.inputs = {
            'Input': input_data,
            'Label': labels,
            'Weight': weight,
            'Bias': bias,
            'SampleWeight': sample_weight
        }

    def set_data(self):
        """Small default case: one true label and two negatives per example."""
        self.generate_data(dim=5, batch_size=5, num_classes=4,
                           num_true_class=1, num_sampled_classes=2)

    def compute(self):
        """Run the reference implementation and store the expected outputs."""
        out = nce(self.inputs['Input'], self.inputs['Weight'],
                  self.inputs['Bias'], self.inputs['SampleWeight'],
                  self.inputs['Label'], self.attrs['num_classes'],
                  self.attrs['num_sampled_classes'])
        self.outputs = {
            'Cost': out[0],
            'SampleLogits': out[1],
            'SampleLabels': out[2]
        }

    def setUp(self):
        self.op_type = 'nce'
        self.set_data()
        self.compute()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradients of Cost w.r.t. Input, Weight and Bias (SampleWeight is
        # not checked here), with a 2% relative tolerance.
        self.check_grad(
            ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)


class TestNCECase1(TestNCE):
    """Larger variant: two true labels and five negatives per example."""

    def set_data(self):
        self.generate_data(dim=10, batch_size=20, num_classes=10,
                           num_true_class=2, num_sampled_classes=5)


if __name__ == '__main__':
    # Run every unittest.TestCase subclass defined in this module.
    unittest.main()