#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import unittest

import numpy as np
from eager_op_test import OpTest


def py_pnpair_op(score, label, query, column=-1, weight=None):
    """NumPy reference for the positive/negative/neutral pair statistics.

    Samples are grouped by query id. Within each group every unordered pair
    of samples whose labels differ is classified by comparing the score
    order with the label order, and the mean of the two sample weights is
    added to the matching counter.

    Args:
        score: 2-D array; ``score[i][column]`` is the score of sample ``i``.
        label: 2-D array; ``label[i][0]`` is the label of sample ``i`` and
            ``label.shape[0]`` is taken as the batch size.
        query: 2-D array; ``query[i][0]`` is the query id of sample ``i``.
        column: index of the score column to use (default ``-1``, the last).
        weight: optional 2-D array; ``weight[i][0]`` is the weight of sample
            ``i``. When ``None`` every sample gets weight 1.

    Returns:
        A tuple ``(pos, neg, neu)`` of float32 arrays of shape ``(1,)``:
        the accumulated weight of correctly ordered pairs, wrongly ordered
        pairs, and pairs with equal scores, respectively.
    """
    batch_size = label.shape[0]
    if weight is None:
        weight = np.ones(shape=(batch_size, 1)).astype('float32')

    # group by query id
    predictions = {}
    for s_row, l_row, q_row, w_row in zip(score, label, query, weight):
        predictions.setdefault(q_row[0], []).append(
            (s_row[column], l_row[0], w_row[0])
        )

    # accumulate statistics
    pos, neg, neu = 0, 0, 0
    for ranks in predictions.values():
        for (s1, l1, w1), (s2, l2, w2) in itertools.combinations(ranks, 2):
            # Pairs with identical labels carry no ordering information.
            if l1 == l2:
                continue
            w = (w1 + w2) * 0.5
            if s1 == s2:
                neu += w
            elif (s1 - s2) * (l1 - l2) > 0:
                # Score order agrees with label order.
                pos += w
            else:
                neg += w

    return (
        np.array([pos]).astype('float32'),
        np.array([neg]).astype('float32'),
        np.array([neu]).astype('float32'),
    )


class TestPositiveNegativePairOp(OpTest):
    """Test ``positive_negative_pair`` with the default column and no weights.

    Expected outputs are produced by ``py_pnpair_op`` on random inputs.
    """

    def setUp(self):
        """Build random Score/Label/QueryID inputs and the reference outputs."""
        self.op_type = 'positive_negative_pair'
        batch_size = 20
        max_query_id = 5
        score = np.random.normal(size=(batch_size, 1)).astype('float32')
        label = np.random.normal(size=(batch_size, 1)).astype('float32')
        # Query ids are drawn from [0, max_query_id), one per sample.
        query = np.random.randint(
            max_query_id, size=(batch_size, 1)
        ).astype('int64')

        pos, neg, neu = py_pnpair_op(score, label, query)
        self.inputs = {'Score': score, 'Label': label, 'QueryID': query}
        self.attrs = {'column': -1}
        self.outputs = {
            'PositivePair': pos,
            'NegativePair': neg,
            'NeutralPair': neu,
        }

    def test_check_output(self):
        """Run the op output check with dygraph checking disabled."""
        # NOTE(yjjiang11): This op will be deprecated.
        self.check_output(check_dygraph=False)


class TestPositiveNegativePairOpAccumulateWeight(OpTest):
    """Test ``positive_negative_pair`` with weights and accumulator inputs.

    Uses a multi-column score with a randomly chosen ``column`` attribute,
    per-sample weights, and random initial accumulator values. Expected
    outputs are the ``py_pnpair_op`` results plus the accumulators.
    """

    def setUp(self):
        """Build random inputs, accumulators and the reference outputs."""
        self.op_type = 'positive_negative_pair'
        batch_size = 20
        max_query_id = 5
        max_random_num = 2 << 15
        score_dim = 2
        # The score width comes from score_dim, so the random `column`
        # drawn below is always a valid column index.
        score = np.random.normal(size=(batch_size, score_dim)).astype(
            'float32'
        )
        label = np.random.normal(size=(batch_size, 1)).astype('float32')
        weight = np.random.normal(size=(batch_size, 1)).astype('float32')
        query = np.random.randint(
            max_query_id, size=(batch_size, 1)
        ).astype('int64')
        # Initial accumulator values: one-element float32 arrays.
        acc_pos = np.random.randint(max_random_num, size=(1,)).astype(
            'float32'
        )
        acc_neg = np.random.randint(max_random_num, size=(1,)).astype(
            'float32'
        )
        acc_neu = np.random.randint(max_random_num, size=(1,)).astype(
            'float32'
        )
        column = np.random.randint(score_dim)

        pos, neg, neu = py_pnpair_op(
            score, label, query, column=column, weight=weight
        )
        self.inputs = {
            'Score': score,
            'Label': label,
            'QueryID': query,
            'AccumulatePositivePair': acc_pos,
            'AccumulateNegativePair': acc_neg,
            'AccumulateNeutralPair': acc_neu,
            'Weight': weight,
        }
        self.attrs = {'column': column}
        self.outputs = {
            'PositivePair': pos + acc_pos,
            'NegativePair': neg + acc_neg,
            'NeutralPair': neu + acc_neu,
        }

    def test_check_output(self):
        """Run the op output check with dygraph checking disabled."""
        self.check_output(check_dygraph=False)


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()