提交 e68a217f 编写于 作者: Z zhouxiao-coder

Add optional inputs and outputs to enable updating;Add weight to match original implementation

上级 9b0f0928
......@@ -26,8 +26,8 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
ctx->HasInput("Label"),
"Input(Label) of PositiveNegativePairOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("QueryId"),
"Input(QueryId) of PositiveNegativePairOp should not be null.");
ctx->HasInput("QueryID"),
"Input(QueryID) of PositiveNegativePairOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("PositivePair"),
"Output(PositivePair) of PositiveNegativePairOp should not be null.");
......@@ -37,21 +37,51 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(
ctx->HasOutput("NeutralPair"),
"Output(NeutralPair) of PositiveNegativePairOp should not be null.");
auto scalar_dim = framework::make_ddim({1});
if (ctx->HasInput("AccumulatePositivePair") ||
ctx->HasInput("AccumulateNegativePair") ||
ctx->HasInput("AccumulateNeutralPair")) {
PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") &&
ctx->HasInput("AccumulateNegativePair") &&
ctx->HasInput("AccumulateNeutralPair"),
"All optional inputs(AccumulatePositivePair, "
"AccumulateNegativePair, AccumulateNeutralPair) of "
"PositiveNegativePairOp are required if one of them is "
"specified.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
"Shape of AccumulatePositivePair should be {1}.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNegativePair"), scalar_dim,
"Shape of AccumulateNegativePair should be {1}.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNeutralPair"), scalar_dim,
"Shape of AccumulateNeutralPair should be {1}.");
}
auto score_dim = ctx->GetInputDim("Score");
auto label_dim = ctx->GetInputDim("Label");
auto query_dim = ctx->GetInputDim("QueryId");
PADDLE_ENFORCE(score_dim == label_dim,
"Shape of Score must be the same as Label's shape.");
PADDLE_ENFORCE(query_dim == label_dim,
"Shape of QueryId must be the same as Label's shape.");
auto query_dim = ctx->GetInputDim("QueryID");
PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
label_dim[0], score_dim[0],
"Tensor Score and Label should have the same height (batch size).");
PADDLE_ENFORCE_EQ(label_dim[1], 1,
"The width of Label should be 1, i.e. each item should "
"have a scalar label.");
PADDLE_ENFORCE(query_dim == label_dim,
"Shape of QueryId must be the same as Label's shape.");
"QueryID should have the same shape as Label.");
if (ctx->HasInput("Weight")) {
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
"Weight should have the same shape as Label.");
}
int column = ctx->Attrs().Get<int>("column");
auto depth = score_dim[1];
PADDLE_ENFORCE(column < depth && column >= -depth,
"Attribute column should be in the range of [-%l, %l)",
depth, depth);
ctx->SetOutputDim("PositivePair", {1});
ctx->SetOutputDim("NegativePair", {1});
ctx->SetOutputDim("NeutralPair", {1});
ctx->SetOutputDim("PositivePair", scalar_dim);
ctx->SetOutputDim("NegativePair", scalar_dim);
ctx->SetOutputDim("NeutralPair", scalar_dim);
}
protected:
......@@ -67,27 +97,62 @@ class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Score",
"(Tensor, float) Output score of the network on <query, document> "
"pair.");
"(Tensor, float) Model Score on an item (with "
"respect to QueryID). It's a 2-D tensor with shape [batch_size, "
"depth], where the column specified by the attribute \"column\" "
"is used as item score.");
AddInput("Label",
"(Tensor, float or int) Label of current <query, document> pair.");
AddInput("QueryId",
"(Tensor, int) query id of current <query, document> pair.");
"(Tensor, float) Label of an item (with repsect to "
"QueryId). It's a 2-D tensor with shape [batch_size, 1].");
AddInput("QueryID",
"(Tensor, int) Query ID that indicates the context. Its shape "
"should be the same as Label.");
AddInput(
"AccumulatePositivePair",
"(float) Optional. The accumulated number of positive pairs over a "
"stream of data. If provided, the output PositivePair will be "
"initialized with this number rather than 0. it won't be modified "
"in place.")
.AsDispensable();
AddInput(
"AccumulateNegativePair",
"(float) Optional. The accumulated number of negative pairs over a "
"stream of data. If provided, the output NegativePair will be "
"initialized with this number rather than 0. it won't be modified "
"in place.")
.AsDispensable();
AddInput("AccumulateNeutralPair",
"(float) Optional. The accumulated number of neutral pairs over a "
"stream of data. If provided, the output NeutralPair will be "
"initialized with this number rather than 0. it won't be modified "
"in place.")
.AsDispensable();
AddInput("Weight",
"(float) Optional. Weight of current item. If specified, its "
"shape should be the same as Label.")
.AsDispensable();
AddOutput("PositivePair",
"(float) Number of positive ranking pairs, i.e. the pairs of "
"documents that are ranked correctly");
"(float) Number of positive pairs, i.e. the pairs of "
"items that are ranked correctly.");
AddOutput("NegativePair",
"(float) Number of negative ranking pairs, i.e. the pairs of "
"documents that are ranked incorrectly");
"(float) Number of negative pairs, i.e. the pairs of "
"items that are ranked incorrectly.");
AddOutput("NeutralPair",
"(float) Number of neutral ranking pairs. A pair of document "
"(doc#1, doc#2) is classified as \"neutral\" if their scores are "
"the same.");
"(float) Number of neutral pairs, i.e. the pairs of items "
"that have the same score.")
.AsDispensable();
AddAttr<int>(
"column",
"(int, default -1) The column position of Score used to rank items in "
"descending order. It must be in the range of [-rank(Score), "
"rank(Score)). "
"If `dim < 0`, the dim to reduce is `rank + dim`. "
"Noting that reducing on the first dim will make the LoD info lost.")
.SetDefault(0);
AddComment(R"DOC(
PositiveNegativePairOp can be used to evaluate Learning To Rank(LTR) model performance. Its outputs are usually
further summarized as positive-negative-ratio: PositivePair/NegativePair.
Its 3 inputs can be viewd as a series of 3 tuples: (predicition score, golden label, query id).
For each unique query id, a list of <score, label> are collected and positive/negative pairs are accumulated to its output.
PositiveNegativePairOp can be used to evaluate Learning To Rank(LTR) model performance.
Within some context, e.g. the "query", a LTR model generates scores for a list of items, which gives a partial order of the items.
PositiveNegativePairOp takes a list of reference rank order (Input("Label")) and the model generated scores (Input(Score)) as inputs and counts the pairs that ranked correctly and incorrectly.
)DOC");
}
};
......@@ -101,4 +166,5 @@ REGISTER_OP_WITHOUT_GRADIENT(positive_negative_pair,
ops::PositiveNegativePairOpMaker);
REGISTER_OP_CPU_KERNEL(
positive_negative_pair,
ops::PositiveNegativePairKernel<paddle::platform::CPUPlace, float>);
ops::PositiveNegativePairKernel<paddle::platform::CPUPlace, float>,
ops::PositiveNegativePairKernel<paddle::platform::CPUPlace, double>);
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/utils/Logging.h"
namespace paddle {
namespace operators {
......@@ -24,64 +25,86 @@ using LoDTensor = framework::LoDTensor;
template <typename Place, typename T>
class PositiveNegativePairKernel : public framework::OpKernel<T> {
public:
struct PredictionResult {
PredictionResult(T score, T label, T weight)
: score(score), label(label), weight(weight) {}
T score;
T label;
T weight;
};
void Compute(const framework::ExecutionContext& context) const override {
auto score_t = context.Input<Tensor>("Score");
auto label_t = context.Input<Tensor>("Label");
auto query_t = context.Input<Tensor>("QueryId");
auto query_t = context.Input<Tensor>("QueryID");
auto acc_positive_t = context.Input<Tensor>("AccumulatePositivePair");
auto acc_negative_t = context.Input<Tensor>("AccumulateNegativePair");
auto acc_neutral_t = context.Input<Tensor>("AccumulateNeutralPair");
auto positive_t = context.Output<Tensor>("PositivePair");
auto negative_t = context.Output<Tensor>("NegativePair");
auto neutral_t = context.Output<Tensor>("NeutralPair");
auto weight_t = context.Input<Tensor>("Weight");
auto score = score_t->data<float>();
auto label = label_t->data<float>();
auto score = score_t->data<T>();
auto label = label_t->data<T>();
auto query = query_t->data<int32_t>();
const T* weight = nullptr;
auto has_weight = weight_t != nullptr;
if (has_weight) {
weight = weight_t->data<T>();
}
T* positive = positive_t->mutable_data<T>(context.GetPlace());
T* negative = negative_t->mutable_data<T>(context.GetPlace());
T* neutral = neutral_t->mutable_data<T>(context.GetPlace());
auto score_dim = score_t->dims();
PADDLE_ENFORCE_GE(score_dim.size(), 1L,
"Rank of Score must be at least 1.");
PADDLE_ENFORCE_LE(score_dim.size(), 2L,
"Rank of Score must be less or equal to 2.");
auto batch_size = score_dim[0];
auto width = score_dim.size() > 1 ? score_dim[1] : 1;
auto width = score_dim[1];
auto column = context.Attr<int32_t>("column");
if (column < 0) {
column += width;
}
// construct document instances for each query: Query => List[<score#0,
// label#0>, ...]
std::unordered_map<int, std::vector<std::pair<float, float>>> predictions;
std::unordered_map<int32_t, std::vector<PredictionResult>> predictions;
for (auto i = 0; i < batch_size; ++i) {
if (predictions.find(query[i]) == predictions.end()) {
predictions.emplace(
std::make_pair(query[i], std::vector<std::pair<float, float>>()));
std::make_pair(query[i], std::vector<PredictionResult>()));
}
predictions[query[i]].push_back(
std::make_pair(score[i * width + width - 1], label[i]));
predictions[query[i]].push_back(PredictionResult(
score[i * width + column], label[i], has_weight ? weight[i] : 1.0));
}
// for each query, accumulate pair counts
T pos = 0, neg = 0, neu = 0;
if (acc_positive_t != nullptr && acc_negative_t != nullptr &&
acc_neutral_t != nullptr) {
pos = acc_positive_t->data<T>()[0];
neg = acc_negative_t->data<T>()[0];
neu = acc_neutral_t->data<T>()[0];
}
auto evaluate_one_list = [&pos, &neg,
&neu](std::vector<std::pair<float, float>> vec) {
&neu](std::vector<PredictionResult> vec) {
for (auto ite1 = vec.begin(); ite1 != vec.end(); ++ite1) {
for (auto ite2 = ite1 + 1; ite2 != vec.end(); ++ite2) {
if (ite1->second == ite2->second) { // labels are equal, ignore.
if (ite1->label == ite2->label) { // labels are equal, ignore.
continue;
}
if (ite1->first == ite2->first) {
++neu;
T w = (ite1->weight + ite2->weight) * 0.5;
if (ite1->score == ite2->score) {
neu += w;
}
(ite1->first - ite2->first) * (ite1->second - ite2->second) > 0.0
? pos++
: neg++;
(ite1->score - ite2->score) * (ite1->label - ite2->label) > 0.0
? pos += w
: neg += w;
}
}
};
for (auto prediction : predictions) {
evaluate_one_list(prediction.second);
}
*positive = pos;
*negative = neg;
*neutral = neu;
......
......@@ -4,30 +4,36 @@ import numpy as np
from op_test import OpTest
def py_pnpair_op(score, label, query):
def py_pnpair_op(score, label, query, column=-1, weight=None):
# group by query id
predictions = {}
for s, l, q in zip(score, label, query):
if type(s) is list:
s = s[-1]
q = q[0]
batch_size = label.shape[0]
print "batch_size=", batch_size
if weight is None:
weight = np.ones(shape=(batch_size, 1)).astype('float32')
for s, l, q, w in zip(score, label, query, weight):
# s = s[column]
# q = q[0]
# w = w[0]
s, l, q, w = s[column], l[0], q[0], w[0]
if q not in predictions:
predictions[q] = []
predictions[q].append((s, l))
predictions[q].append((s, l, w))
# accumulate statistics
pos, neg, neu = 0, 0, 0
for _, ranks in predictions.items():
for e1, e2 in itertools.combinations(ranks, 2):
s1, s2, l1, l2 = e1[0][0], e2[0][0], e1[1][0], e2[1][0]
s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
w = (w1 + w2) * 0.5
if l1 == l2:
continue
if s1 == s2:
neu += 1
neu += w
elif (s1 - s2) * (l1 - l2) > 0:
pos += 1
pos += w
else:
neg += 1
neg += w
return np.array(pos).astype('float32'), np.array(neg).astype(
'float32'), np.array(neu).astype('float32')
......@@ -45,8 +51,8 @@ class TestPositiveNegativePairOp(OpTest):
query = np.reshape(query, newshape=(batch_size, 1)).astype('int32')
pos, neg, neu = py_pnpair_op(score, label, query)
self.inputs = {}
self.inputs = {'Score': score, 'Label': label, 'QueryId': query}
self.inputs = {'Score': score, 'Label': label, 'QueryID': query}
self.attrs = {'column': -1}
self.outputs = {
'PositivePair': pos,
'NegativePair': neg,
......@@ -57,5 +63,86 @@ class TestPositiveNegativePairOp(OpTest):
self.check_output()
class TestPositiveNegativePairOpAccumulate(OpTest):
def setUp(self):
self.op_type = 'positive_negative_pair'
batch_size = 20
max_query_id = 5
max_random_num = 2 << 15
score = np.random.normal(size=(batch_size, 2)).astype('float32')
label = np.random.normal(size=(batch_size, 1)).astype('float32')
query = np.array(
[np.random.randint(max_query_id) for i in range(batch_size)])
query = np.reshape(query, newshape=(batch_size, 1)).astype('int32')
acc_pos = np.reshape(
np.random.randint(max_random_num), newshape=(1)).astype('float32')
acc_neg = np.reshape(
np.random.randint(max_random_num), newshape=(1)).astype('float32')
acc_neu = np.reshape(
np.random.randint(max_random_num), newshape=(1)).astype('float32')
column = 0
pos, neg, neu = py_pnpair_op(score, label, query, column=column)
self.inputs = {
'Score': score,
'Label': label,
'QueryID': query,
'AccumulatePositivePair': acc_pos,
'AccumulateNegativePair': acc_neg,
'AccumulateNeutralPair': acc_neu,
}
self.attrs = {'column': column}
self.outputs = {
'PositivePair': pos + acc_pos,
'NegativePair': neg + acc_neg,
'NeutralPair': neu + acc_neu
}
def test_check_output(self):
self.check_output()
class TestPositiveNegativePairOpAccumulateWeight(OpTest):
def setUp(self):
self.op_type = 'positive_negative_pair'
batch_size = 20
max_query_id = 5
max_random_num = 2 << 15
score = np.random.normal(size=(batch_size, 2)).astype('float32')
label = np.random.normal(size=(batch_size, 1)).astype('float32')
weight = np.random.normal(size=(batch_size, 1)).astype('float32')
query = np.array(
[np.random.randint(max_query_id) for i in range(batch_size)])
query = np.reshape(query, newshape=(batch_size, 1)).astype('int32')
acc_pos = np.reshape(
np.random.randint(max_random_num), newshape=(1)).astype('float32')
acc_neg = np.reshape(
np.random.randint(max_random_num), newshape=(1)).astype('float32')
acc_neu = np.reshape(
np.random.randint(max_random_num), newshape=(1)).astype('float32')
column = 0
pos, neg, neu = py_pnpair_op(
score, label, query, column=column, weight=weight)
self.inputs = {
'Score': score,
'Label': label,
'QueryID': query,
'AccumulatePositivePair': acc_pos,
'AccumulateNegativePair': acc_neg,
'AccumulateNeutralPair': acc_neu,
'Weight': weight
}
self.attrs = {'column': column}
self.outputs = {
'PositivePair': pos + acc_pos,
'NegativePair': neg + acc_neg,
'NeutralPair': neu + acc_neu
}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册