From 7c55e08c939da75c0caaabcaf309ad500c5769ca Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Mon, 6 Aug 2018 16:18:00 +0800
Subject: [PATCH] stash

---
 paddle/fluid/operators/cross_entropy_op.cc | 91 +++++++++++++---
 1 file changed, 54 insertions(+), 37 deletions(-)

diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index a3bec3da451..97d6a19311a 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -28,23 +28,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
+                      "Input(X) and Input(Label) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "Input(X) and Input(Label) should have the same shape "
+                      "except the last dimension.");
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "If Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "If Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
-                        "If Attr(softLabel) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
+                        "If Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
-    ctx->SetOutputDim("Y", {x_dims[0], 1});
+    auto out_dim_vec =
+        framework::vectorize(framework::slice_ddim(x_dims, 0, rank - 1));
+    out_dim_vec.push_back(1);
+
+    ctx->SetOutputDim("Y", framework::make_ddim(out_dim_vec));
     ctx->ShareLoD("X", /*->*/ "Y");
   }
@@ -74,24 +79,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(dy_dims[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) should be 1.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
+                      "Input(Y@Grad) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
+                      "Input(Label) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "The Input(X) and Input(Label) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(dy_dims, 0, rank - 1),
+                      "The Input(X) and Input(Y@Grad) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "When Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "When Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1,
-                        "When Attr(soft_label) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
+                        "When Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
@@ -113,18 +122,20 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
-             " where N is the batch size and D is the number of classes. "
-             "This input is a probability computed by the previous operator, "
-             "which is almost always the result of a softmax operator.");
-    AddInput("Label",
-             "(Tensor), the ground truth which is a 2-D tensor. When "
-             "soft_label is set to false, Label is a Tensor<int64> with shape "
-             "[N x 1]. When soft_label is set to true, Label is a "
-             "Tensor<float/double> with shape [N x D].");
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except for the last dimension. When soft_label is "
+        "set to false, the last dimension size is 1; when soft_label is set "
+        "to true, the last dimension size is equal to the number of classes.");
     AddOutput("Y",
-              "(Tensor, default Tensor<float>), a 2-D tensor with shape "
-              "[N x 1]. The cross entropy loss.");
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
     AddAttr<bool>("soft_label",
                   "(bool, default false), a flag indicating whether to "
                   "interpretate the given labels as soft labels.")
@@ -132,6 +143,12 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 CrossEntropy Operator.
 
+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the last dimension
+of the original tensors, and the first dimension (column length) is the
+product of all the other original dimensions. The cross entropy computation
+then takes place on each row of the flattened matrices.
+
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
--
GitLab
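
Reviewer note: the patch replaces the hard-coded 2-D shape checks with
rank-general ones. X and Label must agree on every dimension except the last,
and Y keeps X's shape with the last dimension replaced by 1. Below is a
minimal standalone C++ sketch of these shape rules, using std::vector<int64_t>
as a stand-in for framework::DDim; the helpers AllButLast and InferYShape are
hypothetical names for illustration, not Paddle API, and assert() stands in
for PADDLE_ENFORCE_EQ to keep the sketch dependency-free.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Dims = std::vector<int64_t>;

// Stand-in for framework::slice_ddim(d, 0, rank - 1): every dim but the last.
static Dims AllButLast(const Dims& d) {
  assert(!d.empty());  // rank must be at least 1
  return Dims(d.begin(), d.end() - 1);
}

// Mirrors the checks in the patched CrossEntropyOp::InferShape.
static Dims InferYShape(const Dims& x, const Dims& label, bool soft_label) {
  assert(x.size() == label.size());            // same rank
  assert(AllButLast(x) == AllButLast(label));  // same shape except last dim
  if (soft_label) {
    assert(x.back() == label.back());  // soft label: per-class probabilities
  } else {
    assert(label.back() == 1);         // hard label: a single class index
  }
  Dims y = AllButLast(x);  // keep all leading dims of X ...
  y.push_back(1);          // ... and set the last dim to 1, as SetOutputDim does
  return y;
}

int main() {
  Dims x = {4, 7, 10};     // e.g. [batch, seq_len, num_classes]
  Dims label = {4, 7, 1};  // hard labels
  for (int64_t d : InferYShape(x, label, /*soft_label=*/false)) {
    std::cout << d << ' ';  // prints: 4 7 1
  }
  std::cout << '\n';
  return 0;
}

The same rule underlies the new CrossEntropyOpMaker comment: a [4, 7, 10]
input is viewed as a 28 x 10 matrix (28 = 4 * 7 rows, one per example), and
the loss is computed row by row.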