// positive_negative_pair_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/positive_negative_pair_op.h"

namespace paddle {
namespace operators {

// Operator definition for positive_negative_pair: validates inputs, outputs
// and attributes and infers output shapes. The pair counting itself is done
// by PositiveNegativePairKernel (see positive_negative_pair_op.h).
class PositiveNegativePairOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Shape/contract checks performed here:
  //  * Inputs Score, Label, QueryID and outputs PositivePair, NegativePair,
  //    NeutralPair must all be bound.
  //  * The optional accumulators (AccumulatePositivePair,
  //    AccumulateNegativePair, AccumulateNeutralPair) must be given all
  //    together or not at all, and each must have shape {1}.
  //  * Score must be 2-D ([batch_size, depth]); Label must be
  //    [batch_size, 1]; QueryID and the optional Weight must have exactly
  //    Label's shape.
  //  * Attribute `column` must lie in [-depth, depth).
  // Every output is set to shape {1}.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(
        ctx->HasInput("Score"),
        "Input(Score) of PositiveNegativePairOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasInput("Label"),
        "Input(Label) of PositiveNegativePairOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasInput("QueryID"),
        "Input(QueryID) of PositiveNegativePairOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("PositivePair"),
        "Output(PositivePair) of PositiveNegativePairOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("NegativePair"),
        "Output(NegativePair) of PositiveNegativePairOp should not be null.");
    // NOTE(review): NeutralPair is registered AsDispensable() in
    // PositiveNegativePairOpMaker, yet it is required here; one of the two
    // should be changed so they agree.
    PADDLE_ENFORCE(
        ctx->HasOutput("NeutralPair"),
        "Output(NeutralPair) of PositiveNegativePairOp should not be null.");

    auto scalar_dim = framework::make_ddim({1});
    // The accumulators are all-or-nothing: a partial set is rejected.
    if (ctx->HasInput("AccumulatePositivePair") ||
        ctx->HasInput("AccumulateNegativePair") ||
        ctx->HasInput("AccumulateNeutralPair")) {
      PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") &&
                         ctx->HasInput("AccumulateNegativePair") &&
                         ctx->HasInput("AccumulateNeutralPair"),
                     "All optional inputs(AccumulatePositivePair, "
                     "AccumulateNegativePair, AccumulateNeutralPair) of "
                     "PositiveNegativePairOp are required if one of them is "
                     "specified.");
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
                        "Shape of AccumulatePositivePair should be {1}.");
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNegativePair"), scalar_dim,
                        "Shape of AccumulateNegativePair should be {1}.");
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNeutralPair"), scalar_dim,
                        "Shape of AccumulateNeutralPair should be {1}.");
    }

    auto score_dim = ctx->GetInputDim("Score");
    auto label_dim = ctx->GetInputDim("Label");
    auto query_dim = ctx->GetInputDim("QueryID");
    // Rank checks come first so the [0]/[1] indexing below is safe.
    PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(
        label_dim[0], score_dim[0],
        "Tensor Score and Label should have the same height (batch size).");
    PADDLE_ENFORCE_EQ(label_dim[1], 1,
                      "The width of Label should be 1, i.e. each item should "
                      "have a scalar label.");
    PADDLE_ENFORCE(query_dim == label_dim,
                   "QueryID should have the same shape as Label.");
    if (ctx->HasInput("Weight")) {
      PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
                     "Weight should have the same shape as Label.");
    }

    int column = ctx->Attrs().Get<int>("column");
    auto depth = score_dim[1];
    // `%d` formats the (int64) bounds; the former "%l" is only a printf
    // length modifier with no conversion, which produced a malformed message.
    PADDLE_ENFORCE(column < depth && column >= -depth,
                   "Attribute column should be in the range of [-%d, %d)",
                   depth, depth);

    ctx->SetOutputDim("PositivePair", scalar_dim);
    ctx->SetOutputDim("NegativePair", scalar_dim);
    ctx->SetOutputDim("NeutralPair", scalar_dim);
  }

 protected:
  // The kernel's data type is taken from Input(Score); the place comes from
  // the execution context's device.
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
        ctx.device_context());
  }
};

// Declares the inputs, outputs, attribute and documentation of the
// positive_negative_pair operator.
class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  PositiveNegativePairOpMaker(framework::OpProto *proto,
                              framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Score",
             "(Tensor, float) Model Score on an item (with "
             "respect to QueryID). It's a 2-D tensor with shape [batch_size, "
             "depth], where the column specified by the attribute \"column\" "
             "is used as item score.");
    AddInput("Label",
             "(Tensor, float) Label of an item (with respect to "
             "QueryID). It's a 2-D tensor with shape [batch_size, 1].");
    AddInput("QueryID",
             "(Tensor, int64) Query ID that indicates the context. Its shape "
             "should be the same as Label.");
    // The three accumulators must be supplied together (enforced in
    // PositiveNegativePairOp::InferShape).
    AddInput(
        "AccumulatePositivePair",
        "(float) Optional. The accumulated number of positive pairs over a "
        "stream of data. If provided, the output PositivePair will be "
        "initialized with this number rather than 0. It won't be modified "
        "in place.")
        .AsDispensable();
    AddInput(
        "AccumulateNegativePair",
        "(float) Optional. The accumulated number of negative pairs over a "
        "stream of data. If provided, the output NegativePair will be "
        "initialized with this number rather than 0. It won't be modified "
        "in place.")
        .AsDispensable();
    AddInput("AccumulateNeutralPair",
             "(float) Optional. The accumulated number of neutral pairs over a "
             "stream of data. If provided, the output NeutralPair will be "
             "initialized with this number rather than 0. It won't be modified "
             "in place.")
        .AsDispensable();
    AddInput("Weight",
             "(float) Optional. Weight of current item. If specified, its "
             "shape should be the same as Label, and the meaning of the output "
             "changes from numbers of pairs to the total sum of pairs' "
             "weights. Weight of a pair of items is the average of their "
             "weights.")
        .AsDispensable();
    AddOutput("PositivePair",
              "(float) Number of positive pairs, i.e. the pairs of "
              "items that are ranked correctly.");
    AddOutput("NegativePair",
              "(float) Number of negative pairs, i.e. the pairs of "
              "items that are ranked incorrectly.");
    // NOTE(review): declared dispensable, but InferShape currently requires
    // Output(NeutralPair); keep the two in sync.
    AddOutput("NeutralPair",
              "(float) Number of neutral pairs, i.e. the pairs of items "
              "that have the same score.")
        .AsDispensable();
    AddAttr<int>(
        "column",
        "(int, default 0) The column position of Score used to rank items in "
        "descending order. It must be in the range of [-depth, depth), where "
        "depth is the number of columns (the second dimension) of Score. "
        "A negative value counts from the last column, i.e. the column used "
        "is `depth + column`.")
        .SetDefault(0);
    AddComment(R"DOC(
        PositiveNegativePairOp can be used to evaluate Learning To Rank (LTR)
        model performance.
        Within some context, e.g. the "query", a LTR model generates scores
        for a list of items, which gives a partial order of the items.
        PositiveNegativePairOp takes a list of reference rank orders
        (Input(Label)) and the model generated scores (Input(Score)) as
        inputs and counts the pairs that are ranked correctly and incorrectly.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Registered without a gradient operator.
REGISTER_OP_WITHOUT_GRADIENT(positive_negative_pair,
                             ops::PositiveNegativePairOp,
                             ops::PositiveNegativePairOpMaker);
// CPU kernels instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
    positive_negative_pair,
    ops::PositiveNegativePairKernel<paddle::platform::CPUPlace, float>,
    ops::PositiveNegativePairKernel<paddle::platform::CPUPlace, double>);