#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest


def nce(input, weight, bias, sample_weight, labels, num_classes,
        num_sample_class):
    """NumPy reference for the NCE (noise-contrastive estimation) forward pass.

    Each row ``i`` of ``input`` is scored against its true labels
    ``labels[i]`` followed by the fixed negative classes
    ``0 .. num_sample_class - 1``.

    Args:
        input: 2-D array; row ``i`` is dotted with ``weight[class_id]``.
        weight: per-class weight rows, indexed by class id.
        bias: per-class bias, or ``None`` to omit the bias term.
        sample_weight: per-row cost multiplier, or ``None`` for weight 1.
        labels: 2-D array of true class ids, one row per input row.
        num_classes: total number of classes ``N``.
        num_sample_class: number of negative classes ``k``.

    Returns:
        Tuple ``(cost, probs, sample_labels)``:
        ``cost`` has shape ``(batch, 1)``; ``probs`` (the sigmoid of the
        logits, not the raw logits) and ``sample_labels`` both have shape
        ``(batch, num_true + k)`` with the true labels first in each row.
        With ``b = k / N``, a true pair costs ``-log(p / (p + b))`` and a
        negative pair costs ``-log(b / (p + b))``.
    """
    n_rows = input.shape[0]
    n_true = labels.shape[1]
    width = n_true + num_sample_class

    # One entry per scored (row, class) pair: (row, class id, is_true, weight).
    entries = []
    for row in range(n_rows):
        row_weight = 1 if sample_weight is None else sample_weight[row]
        entries.extend((row, cls, True, row_weight) for cls in labels[row])
        entries.extend((row, cls, False, row_weight)
                       for cls in range(num_sample_class))
    class_ids = [entry[1] for entry in entries]

    # Logit = bias[class] (when given) + <input[row], weight[class]>.
    logits = np.zeros(len(entries)).astype(np.float32)
    for pos, (row, cls, _, _) in enumerate(entries):
        if bias is not None:
            logits[pos] = bias[cls]
        logits[pos] += np.dot(input[row], weight[cls])
    probs = 1.0 / (1.0 + np.exp(-logits))

    # Noise term k / N, evaluated in the same order as the operator test expects.
    noise = 1.0 / num_classes * num_sample_class
    cost = np.zeros(n_rows).astype(np.float32)
    for (row, _, is_true, row_weight), p in zip(entries, probs):
        target = p if is_true else noise
        cost[row] += -np.log(target / (p + noise)) * row_weight

    return (cost[:, np.newaxis],
            probs.reshape(n_rows, width),
            np.array(class_ids).reshape(n_rows, width))


class TestNCE(OpTest):
    """OpTest case for the ``nce`` operator, checked against the NumPy
    reference ``nce`` defined in this module."""

    def generate_data(self, dim, batch_size, num_classes, num_true_class,
                      num_neg_samples):
        """Draw random inputs and populate ``self.inputs`` / ``self.attrs``.

        Negatives are pinned to ``0 .. num_neg_samples - 1`` through
        ``custom_neg_classes``, matching the ``range(num_sample_class)``
        negatives used by the reference ``nce``.
        """
        # Draw order is significant for reproducibility with a fixed seed:
        # input, weight, bias, sample weight, labels.
        x = np.random.randn(batch_size, dim).astype(np.float32)
        w = np.random.randn(num_classes, dim).astype(np.float32)
        b = np.random.randn(num_classes).astype(np.float32)
        row_weight = np.random.randn(batch_size).astype(np.float32)
        true_labels = np.random.randint(0, num_classes,
                                        (batch_size, num_true_class))
        negatives = list(range(num_neg_samples))
        self.attrs = {
            'num_total_classes': num_classes,
            'num_neg_samples': num_neg_samples,
            'custom_neg_classes': negatives,
        }
        self.inputs = {
            'Input': x,
            'Label': true_labels,
            'Weight': w,
            'Bias': b,
            'SampleWeight': row_weight,
        }

    def set_data(self):
        """Default configuration: dim 5, batch 5, 4 classes, 1 true label,
        2 negative samples."""
        self.generate_data(dim=5, batch_size=5, num_classes=4,
                           num_true_class=1, num_neg_samples=2)

    def compute(self):
        """Fill ``self.outputs`` with the reference results for the inputs."""
        ins, attrs = self.inputs, self.attrs
        cost, sample_logits, sample_labels = nce(
            ins['Input'], ins['Weight'], ins['Bias'], ins['SampleWeight'],
            ins['Label'], attrs['num_total_classes'],
            attrs['num_neg_samples'])
        self.outputs = {
            'Cost': cost,
            'SampleLogits': sample_logits,
            'SampleLabels': sample_labels,
        }

    def setUp(self):
        # Data must exist before the reference outputs can be computed.
        self.op_type = 'nce'
        self.set_data()
        self.compute()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        """Check gradients of ``Cost`` w.r.t. Input, Weight and Bias
        (Label and SampleWeight are not checked)."""
        checked_inputs = ["Input", "Weight", "Bias"]
        self.check_grad(checked_inputs, "Cost", max_relative_error=0.02)


class TestNCECase1(TestNCE):
    """Larger NCE configuration with two true labels per example."""

    def set_data(self):
        self.generate_data(dim=10, batch_size=20, num_classes=10,
                           num_true_class=2, num_neg_samples=5)


if __name__ == '__main__':
    # Discover and run the test cases defined in this module when executed
    # directly as a script.
    unittest.main()