#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import itertools
import unittest

import numpy as np

from op_test import OpTest


def py_pnpair_op(score, label, query, column=-1, weight=None):
    """Reference implementation of the ``positive_negative_pair`` operator.

    Rows are grouped by query id. Within each group, every unordered pair of
    rows whose labels differ is classified by comparing score order with
    label order:

    * neutral pair:  the two scores are equal;
    * positive pair: scores and labels are ordered the same way;
    * negative pair: scores and labels are ordered in opposite ways.

    Each pair contributes the mean of its two row weights to its category.
    Pairs with equal labels are skipped.

    Args:
        score: 2-D array; ``score[i, column]`` is the score of row ``i``.
        label: 2-D array; ``label[i, 0]`` is the label of row ``i``.
        query: 2-D array; ``query[i, 0]`` is the query id of row ``i``.
        column: index of the score column to use (default: last column).
        weight: optional 2-D array; ``weight[i, 0]`` is the weight of row
            ``i``. When omitted every row has weight 1.

    Returns:
        Tuple ``(pos, neg, neu)`` of 0-d ``float32`` arrays holding the
        weighted positive, negative and neutral pair totals.
    """
    batch_size = label.shape[0]
    if weight is None:
        weight = np.ones(shape=(batch_size, 1)).astype('float32')

    # Group (score, label, weight) triples by query id: pairs are only formed
    # between rows that belong to the same query.
    predictions = collections.defaultdict(list)
    for row_score, row_label, row_query, row_weight in zip(
            score, label, query, weight):
        predictions[row_query[0]].append(
            (row_score[column], row_label[0], row_weight[0]))

    # Accumulate weighted pair statistics.
    pos, neg, neu = 0, 0, 0
    for ranks in predictions.values():
        for (s1, l1, w1), (s2, l2, w2) in itertools.combinations(ranks, 2):
            if l1 == l2:
                # Equal labels carry no ordering information.
                continue
            w = (w1 + w2) * 0.5
            if s1 == s2:
                neu += w
            elif (s1 - s2) * (l1 - l2) > 0:
                pos += w
            else:
                neg += w

    return (np.array(pos).astype('float32'),
            np.array(neg).astype('float32'),
            np.array(neu).astype('float32'))


class TestPositiveNegativePairOp(OpTest):
    """Check ``positive_negative_pair`` without weights or accumulators.

    Scores have a single column, labels are drawn from a normal distribution,
    and query ids are drawn from ``[0, max_query_id)``. Expected outputs are
    computed with ``py_pnpair_op``.
    """

    def setUp(self):
        self.op_type = 'positive_negative_pair'
        batch_size = 20
        max_query_id = 5
        score = np.random.normal(size=(batch_size, 1)).astype('float32')
        label = np.random.normal(size=(batch_size, 1)).astype('float32')
        # Draw all query ids at once, already in (batch_size, 1) shape; this
        # avoids np.reshape's ``newshape`` keyword (deprecated in NumPy 2.1).
        query = np.random.randint(
            max_query_id, size=(batch_size, 1)).astype('int64')

        pos, neg, neu = py_pnpair_op(score, label, query)
        self.inputs = {'Score': score, 'Label': label, 'QueryID': query}
        self.attrs = {'column': -1}
        self.outputs = {
            'PositivePair': pos,
            'NegativePair': neg,
            'NeutralPair': neu
        }

    def test_check_output(self):
        self.check_output()


class TestPositiveNegativePairOpAccumulateWeight(OpTest):
    """Check ``positive_negative_pair`` with weights, a random score column
    and pre-existing accumulator inputs.

    The expected outputs are the pair totals of the current batch (from
    ``py_pnpair_op``) added to the ``Accumulate*Pair`` inputs.
    """

    def setUp(self):
        self.op_type = 'positive_negative_pair'
        batch_size = 20
        max_query_id = 5
        max_random_num = 2 << 15
        score_dim = 2
        # Tie the score width to score_dim so that ``column`` (drawn below
        # from [0, score_dim)) always indexes a valid score column.
        score = np.random.normal(
            size=(batch_size, score_dim)).astype('float32')
        label = np.random.normal(size=(batch_size, 1)).astype('float32')
        weight = np.random.normal(size=(batch_size, 1)).astype('float32')
        # Draw query ids directly in (batch_size, 1) shape instead of using
        # np.reshape's ``newshape`` keyword (deprecated in NumPy 2.1).
        query = np.random.randint(
            max_query_id, size=(batch_size, 1)).astype('int64')
        # Accumulators: shape-(1,) float32 arrays holding a random integer
        # in [0, max_random_num).
        acc_pos = np.random.randint(max_random_num, size=(1, )).astype(
            'float32')
        acc_neg = np.random.randint(max_random_num, size=(1, )).astype(
            'float32')
        acc_neu = np.random.randint(max_random_num, size=(1, )).astype(
            'float32')
        column = np.random.randint(score_dim)

        pos, neg, neu = py_pnpair_op(
            score, label, query, column=column, weight=weight)
        self.inputs = {
            'Score': score,
            'Label': label,
            'QueryID': query,
            'AccumulatePositivePair': acc_pos,
            'AccumulateNegativePair': acc_neg,
            'AccumulateNeutralPair': acc_neu,
            'Weight': weight
        }
        self.attrs = {'column': column}
        self.outputs = {
            'PositivePair': pos + acc_pos,
            'NegativePair': neg + acc_neg,
            'NeutralPair': neu + acc_neu
        }

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    # Run the test cases above when this file is executed as a script.
    unittest.main()