diff --git a/paddle/operators/positive_negative_pair_op.cc b/paddle/operators/positive_negative_pair_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..afbb63cc6053e2968750c163f9ac194d0f5b93e0 --- /dev/null +++ b/paddle/operators/positive_negative_pair_op.cc @@ -0,0 +1,177 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/positive_negative_pair_op.h" + +namespace paddle { +namespace operators { + +class PositiveNegativePairOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("Score"), + "Input(Score) of PositiveNegativePairOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("Label"), + "Input(Label) of PositiveNegativePairOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("QueryID"), + "Input(QueryID) of PositiveNegativePairOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("PositivePair"), + "Output(PositivePair) of PositiveNegativePairOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("NegativePair"), + "Output(NegativePair) of PositiveNegativePairOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("NeutralPair"), + "Output(NeutralPair) of PositiveNegativePairOp should not be null."); + auto scalar_dim = framework::make_ddim({1}); + if (ctx->HasInput("AccumulatePositivePair") || + ctx->HasInput("AccumulateNegativePair") || + ctx->HasInput("AccumulateNeutralPair")) { + PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") && + ctx->HasInput("AccumulateNegativePair") && + ctx->HasInput("AccumulateNeutralPair"), + "All optional inputs(AccumulatePositivePair, " + "AccumulateNegativePair, AccumulateNeutralPair) of " + "PositiveNegativePairOp are required if one of them is " + "specified."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulatePositivePair"), scalar_dim, + "Shape of AccumulatePositivePair should be {1}."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNegativePair"), scalar_dim, + "Shape of AccumulateNegativePair should be {1}."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNeutralPair"), scalar_dim, + "Shape of AccumulateNeutralPair should be {1}."); + } + + auto score_dim = ctx->GetInputDim("Score"); + auto label_dim = ctx->GetInputDim("Label"); + auto query_dim = ctx->GetInputDim("QueryID"); + PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor."); + PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor."); + PADDLE_ENFORCE_EQ( + label_dim[0], score_dim[0], + "Tensor Score and Label should have the same height (batch size)."); + PADDLE_ENFORCE_EQ(label_dim[1], 1, + "The width of Label should be 1, i.e. each item should " + "have a scalar label."); + PADDLE_ENFORCE(query_dim == label_dim, + "QueryID should have the same shape as Label."); + if (ctx->HasInput("Weight")) { + PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim, + "Weight should have the same shape as Label."); + } + int column = ctx->Attrs().Get("column"); + auto depth = score_dim[1]; + PADDLE_ENFORCE(column < depth && column >= -depth, + "Attribute column should be in the range of [-%l, %l)", + depth, depth); + + ctx->SetOutputDim("PositivePair", scalar_dim); + ctx->SetOutputDim("NegativePair", scalar_dim); + ctx->SetOutputDim("NeutralPair", scalar_dim); + } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return framework::ToDataType(ctx.Input("Score")->type()); + } +}; + +class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PositiveNegativePairOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Score", + "(Tensor, float) Model Score on an item (with " + "respect to QueryID). It's a 2-D tensor with shape [batch_size, " + "depth], where the column specified by the attribute \"column\" " + "is used as item score."); + AddInput("Label", + "(Tensor, float) Label of an item (with repsect to " + "QueryId). It's a 2-D tensor with shape [batch_size, 1]."); + AddInput("QueryID", + "(Tensor, int64) Query ID that indicates the context. Its shape " + "should be the same as Label."); + AddInput( + "AccumulatePositivePair", + "(float) Optional. The accumulated number of positive pairs over a " + "stream of data. If provided, the output PositivePair will be " + "initialized with this number rather than 0. it won't be modified " + "in place.") + .AsDispensable(); + AddInput( + "AccumulateNegativePair", + "(float) Optional. The accumulated number of negative pairs over a " + "stream of data. If provided, the output NegativePair will be " + "initialized with this number rather than 0. it won't be modified " + "in place.") + .AsDispensable(); + AddInput("AccumulateNeutralPair", + "(float) Optional. The accumulated number of neutral pairs over a " + "stream of data. If provided, the output NeutralPair will be " + "initialized with this number rather than 0. it won't be modified " + "in place.") + .AsDispensable(); + AddInput("Weight", + "(float) Optional. Weight of current item. If specified, its " + "shape should be the same as Label, and the meaning of the output " + "changes from numbers of pairs to the total sum of pairs' " + "weights. Weight of a pair of items is the average of their " + "weights.") + .AsDispensable(); + AddOutput("PositivePair", + "(float) Number of positive pairs, i.e. the pairs of " + "items that are ranked correctly."); + AddOutput("NegativePair", + "(float) Number of negative pairs, i.e. the pairs of " + "items that are ranked incorrectly."); + AddOutput("NeutralPair", + "(float) Number of neutral pairs, i.e. the pairs of items " + "that have the same score.") + .AsDispensable(); + AddAttr( + "column", + "(int, default -1) The column position of Score used to rank items in " + "descending order. It must be in the range of [-rank(Score), " + "rank(Score)). " + "If `dim < 0`, the dim to reduce is `rank + dim`. " + "Noting that reducing on the first dim will make the LoD info lost.") + .SetDefault(0); + AddComment(R"DOC( + PositiveNegativePairOp can be used to evaluate Learning To Rank(LTR) + model performance. + Within some context, e.g. the "query", a LTR model generates scores + for a list of items, which gives a partial order of the items. + PositiveNegativePairOp takes a list of reference rank order + (Input("Label")) and the model generated scores (Input(Score)) as + inputs and counts the pairs that ranked correctly and incorrectly. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(positive_negative_pair, + ops::PositiveNegativePairOp, + ops::PositiveNegativePairOpMaker); +REGISTER_OP_CPU_KERNEL( + positive_negative_pair, + ops::PositiveNegativePairKernel, + ops::PositiveNegativePairKernel); diff --git a/paddle/operators/positive_negative_pair_op.h b/paddle/operators/positive_negative_pair_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2efd3777e04c17b27c07bccde524de5785af35fe --- /dev/null +++ b/paddle/operators/positive_negative_pair_op.h @@ -0,0 +1,114 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/utils/Logging.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class PositiveNegativePairKernel : public framework::OpKernel { + public: + struct PredictionResult { + PredictionResult(T score, T label, T weight) + : score(score), label(label), weight(weight) {} + T score; + T label; + T weight; + }; + + void Compute(const framework::ExecutionContext& context) const override { + auto score_t = context.Input("Score"); + auto label_t = context.Input("Label"); + auto query_t = context.Input("QueryID"); + auto acc_positive_t = context.Input("AccumulatePositivePair"); + auto acc_negative_t = context.Input("AccumulateNegativePair"); + auto acc_neutral_t = context.Input("AccumulateNeutralPair"); + auto positive_t = context.Output("PositivePair"); + auto negative_t = context.Output("NegativePair"); + auto neutral_t = context.Output("NeutralPair"); + auto weight_t = context.Input("Weight"); + + auto score = score_t->data(); + auto label = label_t->data(); + auto query = query_t->data(); + const T* weight = nullptr; + if (weight_t != nullptr) { + weight = weight_t->data(); + } + T* positive = positive_t->mutable_data(context.GetPlace()); + T* negative = negative_t->mutable_data(context.GetPlace()); + T* neutral = neutral_t->mutable_data(context.GetPlace()); + + auto score_dim = score_t->dims(); + auto batch_size = score_dim[0]; + auto width = score_dim[1]; + auto column = context.Attr("column"); + if (column < 0) { + column += width; + } + + // construct document instances for each query: Query => List[, ...] + std::unordered_map> predictions; + for (auto i = 0; i < batch_size; ++i) { + if (predictions.find(query[i]) == predictions.end()) { + predictions.emplace( + std::make_pair(query[i], std::vector())); + } + predictions[query[i]].emplace_back(score[i * width + column], label[i], + weight_t != nullptr ? weight[i] : 1.0); + } + + // for each query, accumulate pair counts + T pos = 0, neg = 0, neu = 0; + if (acc_positive_t != nullptr && acc_negative_t != nullptr && + acc_neutral_t != nullptr) { + pos = acc_positive_t->data()[0]; + neg = acc_negative_t->data()[0]; + neu = acc_neutral_t->data()[0]; + } + auto evaluate_one_list = [&pos, &neg, + &neu](std::vector vec) { + for (auto ite1 = vec.begin(); ite1 != vec.end(); ++ite1) { + for (auto ite2 = ite1 + 1; ite2 != vec.end(); ++ite2) { + if (ite1->label == ite2->label) { // labels are equal, ignore. + continue; + } + T w = (ite1->weight + ite2->weight) * 0.5; + if (ite1->score == ite2->score) { + neu += w; + } + (ite1->score - ite2->score) * (ite1->label - ite2->label) > 0.0 + ? pos += w + : neg += w; + } + } + }; + for (auto prediction : predictions) { + evaluate_one_list(prediction.second); + } + *positive = pos; + *negative = neg; + *neutral = neu; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_positive_negative_pair_op.py b/python/paddle/v2/framework/tests/test_positive_negative_pair_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f6a6c428a26dece01fe2958991edd3edf3a8266e --- /dev/null +++ b/python/paddle/v2/framework/tests/test_positive_negative_pair_op.py @@ -0,0 +1,106 @@ +import unittest +import itertools +import numpy as np +from op_test import OpTest + + +def py_pnpair_op(score, label, query, column=-1, weight=None): + # group by query id + predictions = {} + batch_size = label.shape[0] + if weight is None: + weight = np.ones(shape=(batch_size, 1)).astype('float32') + for s, l, q, w in zip(score, label, query, weight): + s, l, q, w = s[column], l[0], q[0], w[0] + if q not in predictions: + predictions[q] = [] + predictions[q].append((s, l, w)) + + # accumulate statistics + pos, neg, neu = 0, 0, 0 + for _, ranks in predictions.items(): + for e1, e2 in itertools.combinations(ranks, 2): + s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2] + w = (w1 + w2) * 0.5 + if l1 == l2: + continue + if s1 == s2: + neu += w + elif (s1 - s2) * (l1 - l2) > 0: + pos += w + else: + neg += w + + return np.array(pos).astype('float32'), np.array(neg).astype( + 'float32'), np.array(neu).astype('float32') + + +class TestPositiveNegativePairOp(OpTest): + def setUp(self): + self.op_type = 'positive_negative_pair' + batch_size = 20 + max_query_id = 5 + score = np.random.normal(size=(batch_size, 1)).astype('float32') + label = np.random.normal(size=(batch_size, 1)).astype('float32') + query = np.array( + [np.random.randint(max_query_id) for i in range(batch_size)]) + query = np.reshape(query, newshape=(batch_size, 1)).astype('int64') + + pos, neg, neu = py_pnpair_op(score, label, query) + self.inputs = {'Score': score, 'Label': label, 'QueryID': query} + self.attrs = {'column': -1} + self.outputs = { + 'PositivePair': pos, + 'NegativePair': neg, + 'NeutralPair': neu + } + + def test_check_output(self): + self.check_output() + + +class TestPositiveNegativePairOpAccumulateWeight(OpTest): + def setUp(self): + self.op_type = 'positive_negative_pair' + batch_size = 20 + max_query_id = 5 + max_random_num = 2 << 15 + score_dim = 2 + score = np.random.normal(size=(batch_size, 2)).astype('float32') + label = np.random.normal(size=(batch_size, 1)).astype('float32') + weight = np.random.normal(size=(batch_size, 1)).astype('float32') + query = np.array( + [np.random.randint(max_query_id) for i in range(batch_size)]) + query = np.reshape(query, newshape=(batch_size, 1)).astype('int64') + acc_pos = np.reshape( + np.random.randint(max_random_num), newshape=(1)).astype('float32') + acc_neg = np.reshape( + np.random.randint(max_random_num), newshape=(1)).astype('float32') + acc_neu = np.reshape( + np.random.randint(max_random_num), newshape=(1)).astype('float32') + column = np.random.randint(score_dim) + + pos, neg, neu = py_pnpair_op( + score, label, query, column=column, weight=weight) + self.inputs = { + 'Score': score, + 'Label': label, + 'QueryID': query, + 'AccumulatePositivePair': acc_pos, + 'AccumulateNegativePair': acc_neg, + 'AccumulateNeutralPair': acc_neu, + 'Weight': weight + } + self.attrs = {'column': column} + self.outputs = { + 'PositivePair': pos + acc_pos, + 'NegativePair': neg + acc_neg, + 'NeutralPair': neu + acc_neu + } + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main()