#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import unittest
import numpy as np
from op_test import OpTest


def nce(input, weight, bias, sample_weight, labels, num_classes,
        num_sample_class):
    """Compute the NCE cost with plain NumPy, as a reference for the op.

    For every example ``i`` the true classes ``labels[i]`` are scored first,
    followed by the fixed negative classes ``0 .. num_sample_class - 1``.
    A sample for class ``c`` gets
    ``o = sigmoid(bias[c] + dot(input[i], weight[c]))``. With
    ``b = num_sample_class / num_classes``, it contributes
    ``-log(o / (o + b))`` when ``c`` is a true class and
    ``-log(b / (o + b))`` otherwise. Either term is multiplied by
    ``sample_weight[i]``.

    Args:
        input: 2-D array whose row ``i`` holds example ``i``'s features.
        weight: array indexed by class id; ``weight[c]`` is dotted with a
            row of ``input``.
        bias: array indexed by class id, or ``None`` for no bias.
        sample_weight: per-example cost multiplier, or ``None`` for 1.
        labels: 2-D array of true class ids, one row per example.
        num_classes: total number of classes.
        num_sample_class: number of negative classes scored per example.

    Returns:
        Tuple ``(cost, sample_out, sample_labels)``. ``cost`` has shape
        ``(batch_size, 1)``. The other two have shape
        ``(batch_size, num_true_class + num_sample_class)`` and hold,
        respectively, the sigmoid output and the class id of every scored
        sample, in scoring order.
    """
    batch_size = input.shape[0]
    num_true_class = labels.shape[1]

    # One entry per scored sample: (example index, class id, is true class,
    # cost weight). True classes precede negatives within each example.
    samples = []
    for i in range(batch_size):
        w = 1 if sample_weight is None else sample_weight[i]
        samples.extend((i, label, True, w) for label in labels[i])
        samples.extend((i, c, False, w) for c in range(num_sample_class))
    sample_labels = [cls for _, cls, _, _ in samples]

    # Pre-activation: bias[c] (if any) plus the dot product. The bias is
    # written into the float32 buffer first, so it is rounded to float32
    # before the dot product is added.
    sample_out = np.zeros(len(samples), dtype=np.float32)
    for j, (row, cls, _, _) in enumerate(samples):
        if bias is not None:
            sample_out[j] = bias[cls]
        sample_out[j] += np.dot(input[row], weight[cls])

    # Sigmoid activation.
    sample_out = 1.0 / (1.0 + np.exp(-sample_out))

    # NCE cost. b equals num_sample_class / num_classes, which is consistent
    # with k negatives drawn from a uniform noise distribution.
    b = 1.0 / num_classes * num_sample_class
    out = np.zeros(batch_size, dtype=np.float32)
    for o, (row, _, is_true, w) in zip(sample_out, samples):
        cost = -np.log(o / (o + b)) if is_true else -np.log(b / (o + b))
        out[row] += cost * w

    shape = (batch_size, num_true_class + num_sample_class)
    return (out[:, np.newaxis], sample_out.reshape(shape),
            np.array(sample_labels).reshape(shape))


class TestNCE(OpTest):
    """Compare the 'nce' operator against the NumPy reference ``nce``."""

    def generate_data(self, dim, batch_size, num_classes, num_true_class,
                      num_neg_samples):
        """Fill ``self.inputs`` and ``self.attrs`` with one random case.

        Args:
            dim: feature size of each input row.
            batch_size: number of examples.
            num_classes: total number of classes.
            num_true_class: true labels per example.
            num_neg_samples: negative classes per example. Classes
                ``0 .. num_neg_samples - 1`` are passed as the custom
                negatives.
        """
        input_data = np.random.randn(batch_size, dim).astype(np.float32)
        weight = np.random.randn(num_classes, dim).astype(np.float32)
        bias = np.random.randn(num_classes).astype(np.float32)
        sample_weight = np.random.randn(batch_size).astype(np.float32)
        # randint's default integer dtype is platform dependent (int32 on
        # Windows before NumPy 2), so pin the labels to int64 explicitly.
        labels = np.random.randint(
            0, num_classes, (batch_size, num_true_class)).astype(np.int64)
        self.attrs = {
            'num_total_classes': num_classes,
            'num_neg_samples': num_neg_samples,
            # On Python 3 ``range`` is a lazy object rather than a list of
            # ints; materialize it so the attribute is a plain int list.
            'custom_neg_classes': list(range(num_neg_samples))
        }
        self.inputs = {
            'Input': input_data,
            'Label': labels,
            'Weight': weight,
            'Bias': bias,
            'SampleWeight': sample_weight
        }

    def set_data(self):
        """Default case: small batch, one true label per example."""
        self.generate_data(
            dim=5, batch_size=5, num_classes=4, num_true_class=1,
            num_neg_samples=2)

    def compute(self):
        """Set ``self.outputs`` from the NumPy reference implementation."""
        cost, sample_logits, sample_labels = nce(
            self.inputs['Input'], self.inputs['Weight'], self.inputs['Bias'],
            self.inputs['SampleWeight'], self.inputs['Label'],
            self.attrs['num_total_classes'], self.attrs['num_neg_samples'])
        self.outputs = {
            'Cost': cost,
            'SampleLogits': sample_logits,
            'SampleLabels': sample_labels
        }

    def setUp(self):
        """Select the op and build its inputs, attributes and outputs."""
        self.op_type = 'nce'
        self.set_data()
        self.compute()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradients of "Cost" w.r.t. Input, Weight and Bias; a relative
        # error of up to 2% is tolerated.
        self.check_grad(
            ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)


class TestNCECase1(TestNCE):
    """Larger case: two true labels and five negatives per example."""

    def set_data(self):
        self.generate_data(
            dim=10, batch_size=20, num_classes=10, num_true_class=2,
            num_neg_samples=5)


if __name__ == "__main__":
    # Run the test cases defined in this module when executed directly.
    unittest.main()