Commit c9a65382 authored by frankwhzhang

fix label_pos, add test_layers.py, test=develop

Parent a672b291
@@ -23,18 +23,17 @@ class BprLossOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("LabelPos"),
"Input(LabelPos) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto label_pos_dims = ctx->GetInputDim("LabelPos");
auto label_dims = ctx->GetInputDim("Label");
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_pos_dims.size(),
"Input(X) and Input(LabelPos) shall have the same rank.");
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_pos_dims, 0, rank - 1),
"Input(X) and Input(LabelPos) shall have the same shape "
framework::slice_ddim(label_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
auto y_dims = x_dims;
@@ -60,25 +59,23 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("LabelPos"),
"Input(LabelPos) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto label_pos_dims = ctx->GetInputDim("LabelPos");
auto label_dims = ctx->GetInputDim("Label");
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
"Input(Y@Grad) and Input(X) should have the same rank.");
PADDLE_ENFORCE_EQ(
label_pos_dims.size(), rank,
"Input(LabelPos) and Input(X) should have the same rank.");
PADDLE_ENFORCE_EQ(label_dims.size(), rank,
"Input(Label) and Input(X) should have the same rank.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_pos_dims, 0, rank - 1),
"The Input(X) and Input(LabelPos) should have the same "
framework::slice_ddim(label_dims, 0, rank - 1),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(dy_dims, 0, rank - 1),
@@ -86,8 +83,8 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
"shape except the last dimension.");
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
"The last dimension of Input(Y@Grad) should be 1.");
PADDLE_ENFORCE_EQ(label_pos_dims[rank - 1], 1,
" the last dimension of Input(LabelPos) should be 1.");
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
" the last dimension of Input(Label) should be 1.");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X"));
}
@@ -111,7 +108,7 @@ class BprLossOpMaker : public framework::OpProtoAndCheckerMaker {
"size is equal to the number of classes. This input is a "
"real number.");
AddInput(
"LabelPos",
"Label",
"(Tensor), the tensor which represents the ground truth. It has the "
"same shape with 'X' except the last dimension. the last dimension "
"size is 1.");
@@ -122,7 +119,7 @@ class BprLossOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Bayesian Personalized Ranking Loss Operator.
This operator belongs to pairwise ranking loss. LabelPos is the desired item.
This operator belongs to pairwise ranking loss. Label is the desired item.
The loss at a given point in one session is defined as:
$Y[i] = -\frac{1}{N_{i}-1} * \sum_{0\le j<N_{i},~ j\neq Label[i]}\log(\sigma(X[i, Label[i]]-X[i, j]))$
......
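To make the DOC formula above easy to cross-check against the reference computation in TestBprLossOp1 further down, here is a small standalone numpy sketch (an editorial illustration, not part of this commit; the helper name bpr_loss_reference is hypothetical):

import numpy as np

def bpr_loss_reference(x, label):
    # x: [N, D] logits; label: [N, 1] int64 indices of the desired (positive) item.
    # Per-sample loss: Y[i] = -1/(D-1) * sum over j != label[i] of log(sigmoid(x[i, label[i]] - x[i, j]))
    batch_size, class_num = x.shape
    y = np.zeros((batch_size, 1), dtype=x.dtype)
    for i in range(batch_size):
        pos = x[i, label[i, 0]]
        s = 0.0
        for j in range(class_num):
            if j == label[i, 0]:
                continue
            # log(sigmoid(pos - x[i, j])) written as -log(1 + exp(x[i, j] - pos)),
            # the same form the unit test below uses.
            s += -np.log(1.0 + np.exp(x[i, j] - pos))
        y[i, 0] = -s / (class_num - 1)
    return y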
@@ -41,17 +41,17 @@ class BprLossOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* label_pos = ctx.Input<Tensor>("LabelPos");
auto* label = ctx.Input<Tensor>("Label");
auto* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
int rank = x->dims().size();
Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
Tensor labels_Pos_2d = framework::ReshapeToMatrix(*label_pos, rank - 1);
Tensor labels_2d = framework::ReshapeToMatrix(*label, rank - 1);
Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
const framework::Tensor* logits = &x_2d;
const framework::Tensor* labels_pos = &labels_Pos_2d;
const framework::Tensor* labels = &labels_2d;
framework::Tensor* out = &y_2d;
const int step_size = logits->dims()[0];
@@ -59,9 +59,9 @@ class BprLossOpKernel : public framework::OpKernel<T> {
const T* logits_data = logits->data<T>();
T* loss_data = out->data<T>();
const int64_t* label_pos_data = labels_pos->data<int64_t>();
const int64_t* label_data = labels->data<int64_t>();
for (int i = 0; i < step_size; ++i) {
int lbl_pos = label_pos_data[i];
int lbl_pos = label_data[i];
PADDLE_ENFORCE_GE(lbl_pos, 0);
PADDLE_ENFORCE_LT(lbl_pos, class_num);
int index_pos = i * class_num + lbl_pos;
@@ -84,7 +84,7 @@ class BprLossGradientOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* label_pos = ctx.Input<Tensor>("LabelPos");
auto* label = ctx.Input<Tensor>("Label");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
const int step_size = x->dims()[0];
@@ -92,16 +92,16 @@ class BprLossGradientOpKernel : public framework::OpKernel<T> {
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const T* dy_data = dy->data<T>();
const T* x_data = x->data<T>();
const int64_t* label_pos_data = label_pos->data<int64_t>();
const int64_t* label_data = label->data<int64_t>();
for (size_t sample_id = 0; sample_id < step_size; sample_id++) {
for (size_t x_offset = sample_id * num_classes;
x_offset < (sample_id + 1) * num_classes; x_offset++) {
dx_data[x_offset] = static_cast<T>(0);
}
auto p_index = sample_id * num_classes + label_pos_data[sample_id];
auto p_index = sample_id * num_classes + label_data[sample_id];
for (size_t ni = 0; ni < num_classes; ni++) {
if (label_pos_data[sample_id] == ni) continue;
if (label_data[sample_id] == ni) continue;
auto n_index = sample_id * num_classes + ni;
auto grad_ = -dy_data[sample_id] /
((num_classes - 1) *
......
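As an editorial aid for reading the gradient kernel above (its tail is elided in this view), the analytic gradient of the per-sample loss $Y[i]$ defined earlier is, for every non-desired item $j \neq Label[i]$:

$\frac{\partial Y[i]}{\partial X[i, j]} = \frac{1}{(N_{i}-1)\,(1+e^{X[i, Label[i]]-X[i, j]})}$

while $\frac{\partial Y[i]}{\partial X[i, Label[i]]}$ is minus the sum of these terms over all $j \neq Label[i]$. The partially visible expression grad_ = -dy_data[sample_id] / ((num_classes - 1) * ... is the upstream gradient dy multiplied by such a per-pair term; the remaining factor and the scatter into dx fall in the elided lines.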
@@ -1349,21 +1349,30 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
return out
def bpr_loss(input, label_pos):
def bpr_loss(input, label):
"""
Bayesian Personalized Ranking Loss Operator.
This operator belongs to pairwise ranking loss. LabelPos is the desired item.
This operator belongs to pairwise ranking loss. Label is the desired item.
The loss at a given point in one session is defined as:
$Y[i] = -\frac{1}{N_{i}-1} * \sum_{0\le j<N_{i},~ j\neq Label[i]}\log(\sigma(X[i, Label[i]]-X[i, j]))$
For more details, see the paper "Session-based Recommendations with Recurrent
Neural Networks" (https://arxiv.org/abs/1511.06939).
Args:
input (Variable|list): a 2-D tensor with shape [N x D], where N is the
batch size and D is the number of classes.
This input is not probability but logits.
label (Variable|list): the ground truth which is a 2-D tensor. `label`
is a tensor<int64> with shape [N x 1].
Returns:
A 2-D tensor with shape [N x 1], the bpr loss.
Examples:
.. code-block:: python
cost = fluid.layers.bpr_loss(input=predict, label_pos=label)
cost = fluid.layers.bpr_loss(input=predict, label=label)
"""
helper = LayerHelper('bpr_loss', **locals())
@@ -1371,7 +1380,7 @@ def bpr_loss(input, label_pos):
helper.append_op(
type='bpr_loss',
inputs={'X': [input],
'LabelPos': [label_pos]},
'Label': [label]},
outputs={'Y': [out]})
return out
......
@@ -28,18 +28,17 @@ class TestBprLossOp1(OpTest):
batch_size = 40
class_num = 5
X = randomize_probability(batch_size, class_num, dtype='float64')
label_pos = np.random.randint(
0, class_num, (batch_size, 1), dtype="int64")
label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64")
bpr_loss_result = []
for i in range(batch_size):
sum = 0.0
for j in range(class_num):
if j == label_pos[i][0]:
if j == label[i][0]:
continue
sum += (-np.log(1.0 + np.exp(X[i][j] - X[i][label_pos[i][0]])))
sum += (-np.log(1.0 + np.exp(X[i][j] - X[i][label[i][0]])))
bpr_loss_result.append(-sum / (class_num - 1))
bpr_loss = np.asmatrix([[x] for x in bpr_loss_result], dtype="float64")
self.inputs = {"X": X, "LabelPos": label_pos}
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": bpr_loss}
def test_check_output(self):
......
@@ -846,6 +846,15 @@ class TestBook(unittest.TestCase):
out = layers.cross_entropy(x, label, False, 4)
self.assertIsNotNone(out)
def test_bpr_loss(self):
program = Program()
with program_guard(program):
x = layers.data(name="x", shape=[30, 10], dtype="float32")
label = layers.data(name="label", shape=[30, 1], dtype="int64")
out = layers.bpr_loss(x, label)
self.assertIsNotNone(out)
print(str(program))
def test_expand(self):
program = Program()
with program_guard(program):
......
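For reviewers who want to see the renamed argument end to end, the minimal sketch below is an editorial illustration, not part of the commit; it only assumes fluid APIs already exercised in the tests above plus fluid.layers.mean, and the variable names are placeholders:

import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    # Logits over D = 10 candidate items; the batch dimension is implicit.
    predict = fluid.layers.data(name="predict", shape=[10], dtype="float32")
    # Index of the desired (positive) item per sample, int64 as the op expects.
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    # [N x 1] per-sample BPR loss, then a scalar objective for an optimizer.
    cost = fluid.layers.bpr_loss(input=predict, label=label)
    avg_cost = fluid.layers.mean(cost)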