diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc
index c7286dacf01659f3af0927a71856e5a6496cb877..56bb9142dabe0d5546e321e675a5acba7bf4d306 100644
--- a/paddle/fluid/framework/tensor.cc
+++ b/paddle/fluid/framework/tensor.cc
@@ -112,5 +112,6 @@ Tensor& Tensor::Resize(const DDim& dims) {
 const DDim& Tensor::dims() const { return dims_; }
 
 int64_t Tensor::numel() const { return product(dims_); }
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h
index 7f678f869aac4616c8bca440d0431f765da41dd6..b7b62eef23ec351686378c913d18fc72308fd7b2 100644
--- a/paddle/fluid/framework/tensor_impl.h
+++ b/paddle/fluid/framework/tensor_impl.h
@@ -59,6 +59,14 @@ inline T* Tensor::mutable_data(platform::Place place) {
 }
 
 inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+  int rank = src.dims().size();
+  PADDLE_ENFORCE_GE(
+      rank, 2,
+      "'ReshapeToMatrix()' is only used to flatten high-rank "
+      "tensors to matrices. It cannot be used to reshape vectors.");
+  if (rank == 2) {
+    return src;
+  }
   Tensor res;
   res.ShareDataWith(src);
   res.Resize(flatten_to_2d(src.dims(), num_col_dims));
diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index a3bec3da45136bca5cb2763e7ffd6b67703a1813..578ab63bc380ee62d76e34b7cf3cbd590bfa2eda 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -28,23 +28,26 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
+                      "Input(X) and Input(Label) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "Input(X) and Input(Label) should have the same shape "
+                      "except the last dimension.");
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "If Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "If Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
-                        "If Attr(softLabel) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
+                        "If Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
 
-    ctx->SetOutputDim("Y", {x_dims[0], 1});
+    auto y_dims = x_dims;
+    y_dims[rank - 1] = 1;
+    ctx->SetOutputDim("Y", y_dims);
     ctx->ShareLoD("X", /*->*/ "Y");
   }
 
@@ -74,24 +77,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(dy_dims[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) should be 1.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
+                      "Input(Y@Grad) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
+                      "Input(Label) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "The Input(X) and Input(Label) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(dy_dims, 0, rank - 1),
+                      "The Input(X) and Input(Y@Grad) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "When Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "When Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1,
-                        "When Attr(soft_label) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
+                        "When Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
@@ -113,18 +120,20 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
-             " where N is the batch size and D is the number of classes. "
-             "This input is a probability computed by the previous operator, "
-             "which is almost always the result of a softmax operator.");
-    AddInput("Label",
-             "(Tensor), the ground truth which is a 2-D tensor. When "
-             "soft_label is set to false, Label is a Tensor with shape "
-             "[N x 1]. When soft_label is set to true, Label is a "
-             "Tensor with shape [N x D].");
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except for the last dimension. When soft_label is "
+        "set to false, the last dimension size is 1; when soft_label is set "
+        "to true, the last dimension size is equal to the number of classes.");
     AddOutput("Y",
-              "(Tensor, default Tensor<float>), a 2-D tensor with shape "
-              "[N x 1]. The cross entropy loss.");
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
     AddAttr<bool>("soft_label",
                   "(bool, default false), a flag indicating whether to "
                   "interpretate the given labels as soft labels.")
@@ -132,6 +141,12 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 CrossEntropy Operator.
 
+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the original last
+dimension, and the first dimension (column length) is the product of all the
+other original dimensions. Then the cross-entropy computation will take place
+on each row of the flattened matrices.
+
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h
index 19a2aec92b267ece94685ce34604b7d1cfa5d209..36b58d80144d242277f6fc970a3a61a6721d4b50 100644
--- a/paddle/fluid/operators/cross_entropy_op.h
+++ b/paddle/fluid/operators/cross_entropy_op.h
@@ -33,8 +33,13 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
     auto* y = ctx.Output<Tensor>("Y");
     y->mutable_data<T>(ctx.GetPlace());
 
+    int rank = x->dims().size();
+    Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
+    Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+    Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
+
     math::CrossEntropyFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), y, x, labels,
+        ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,
         ctx.Attr<bool>("soft_label"));
   }
 };
@@ -98,9 +103,12 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
     auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
     auto* label = ctx.Input<Tensor>("Label");
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
 
-    int64_t class_num = x->dims()[1];
+    // The following computation only depends on the last dimension size, so
+    // it is unnecessary to convert the tensors to 2-D views.
+    int rank = x->dims().size();
+    int64_t class_num = x->dims()[rank - 1];
     if (ctx.Attr<bool>("soft_label")) {
       XeSoftlabelGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
                                         label->data<T>(),
diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h
index 1205bd0587f32caae04c27ecea581fc17988507f..cf1eeb017d666f605a431aa54637d8cbc99c7c46 100644
--- a/paddle/fluid/operators/softmax_op.h
+++ b/paddle/fluid/operators/softmax_op.h
@@ -31,16 +31,12 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());
 
-    auto dims = X->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_x;
-    framework::LoDTensor flattened_out;
-    flattened_x.ShareDataWith(*X).Resize(flattened_dims);
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
+    int rank = X->dims().size();
+    Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
+    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
 
     math::SoftmaxFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &flattened_x,
-        &flattened_out);
+        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
   }
 };
 
@@ -55,18 +51,14 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());
 
-    auto dims = Out->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_out;
-    framework::LoDTensor flattened_d_out;
-    framework::LoDTensor flattened_d_x;
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
-    flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
-    flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
+    int rank = Out->dims().size();
+    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+    Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
+    Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
 
     math::SoftmaxGradFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &flattened_out,
-        &flattened_d_out, &flattened_d_x);
+        context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
+        &dX_2d);
   }
 };
 
diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
index c5b9e92d69133e593a2ce223e83006eda590daa5..86ac159323a5f9f6149ce5ed4437402eb885c6bc 100644
--- a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
+++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
@@ -105,5 +105,107 @@ class TestCrossEntropyOp3(OpTest):
             ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
 
 
+class TestCrossEntropyOp4(OpTest):
+    """Test high rank tensor cross-entropy with discrete one-hot labels.
+    """
+
+    def setUp(self):
+        self.op_type = "cross_entropy"
+        shape = [10, 2, 4]
+        ins_num = np.prod(np.array(shape))
+        class_num = 10
+
+        X_2d = randomize_probability(ins_num, class_num, dtype='float64')
+
+        label_2d = np.random.randint(0, class_num, (ins_num, 1), dtype="int64")
+        cross_entropy_2d = np.asmatrix(
+            [[-np.log(X_2d[i][label_2d[i][0]])] for i in range(X_2d.shape[0])],
+            dtype="float64")
+
+        X = X_2d.reshape(shape + [class_num])
+        label = label_2d.reshape(shape + [1])
+        cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1])
+
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"soft_label": False}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
+
+
+class TestCrossEntropyOp5(OpTest):
+    """Test high rank tensor cross-entropy with vectorized soft labels.
+    """
+
+    def setUp(self):
+        self.op_type = "cross_entropy"
+        shape = [4, 3]
+        ins_num = np.prod(np.array(shape))
+        class_num = 37
+
+        X_2d = randomize_probability(ins_num, class_num)
+        label_2d = np.random.uniform(0.1, 1.0,
+                                     [ins_num, class_num]).astype("float32")
+        label_2d /= label_2d.sum(axis=1, keepdims=True)
+        cross_entropy_2d = (-label_2d * np.log(X_2d)).sum(
+            axis=1, keepdims=True).astype("float32")
+
+        X = X_2d.reshape(shape + [class_num])
+        label = label_2d.reshape(shape + [class_num])
+        cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1])
+
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"soft_label": True}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
+
+
+class TestCrossEntropyOp6(OpTest):
+    """Test high rank tensor cross-entropy with vectorized one-hot representation of labels.
+ """ + + def setUp(self): + self.op_type = "cross_entropy" + shape = [4, 3, 2] + ins_num = np.prod(np.array(shape)) + class_num = 17 + + X_2d = randomize_probability(ins_num, class_num) + label_index_2d = np.random.randint( + 0, class_num, (ins_num), dtype="int32") + label_2d = np.zeros(X_2d.shape) + label_2d[np.arange(ins_num), label_index_2d] = 1 + + cross_entropy_2d = np.asmatrix( + [[-np.log(X_2d[i][label_index_2d[i]])] + for i in range(X_2d.shape[0])], + dtype="float32") + + X = X_2d.reshape(shape + [class_num]) + label = label_2d.reshape(shape + [class_num]) + cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1]) + + self.inputs = {"X": X, "Label": label.astype(np.float32)} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": True} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001) + + if __name__ == "__main__": unittest.main()