From f1d5fb3b9a6201f3eaf92b12d84b3e3727a3a575 Mon Sep 17 00:00:00 2001
From: caoying03
Date: Thu, 21 Sep 2017 17:47:52 +0800
Subject: [PATCH] support soft labels.

---
 paddle/operators/math/CMakeLists.txt          |  4 +-
 .../math/{softmax_function.cc => softmax.cc}  |  2 +-
 .../math/{softmax_function.cu => softmax.cu}  |  2 +-
 .../math/{softmax_function.h => softmax.h}    |  0
 paddle/operators/math/utils.h                 | 42 -----------
 paddle/operators/softmax_op.h                 |  2 +-
 .../softmax_with_cross_entropy_op.cc          | 75 +++++++++++++------
 .../softmax_with_cross_entropy_op.cu          | 22 +++---
 .../operators/softmax_with_cross_entropy_op.h |  8 +-
 .../test_softmax_with_cross_entropy_op.py     |  4 +-
 10 files changed, 74 insertions(+), 87 deletions(-)
 rename paddle/operators/math/{softmax_function.cc => softmax.cc} (93%)
 rename paddle/operators/math/{softmax_function.cu => softmax.cu} (94%)
 rename paddle/operators/math/{softmax_function.h => softmax.h} (100%)
 delete mode 100644 paddle/operators/math/utils.h

diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index 832a954e3a..074ca47d7f 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -1,12 +1,12 @@
 if(WITH_GPU)
     nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
         im2col.cu DEPS cblas device_context operator)
-    nv_library(softmax_function SRCS softmax_function.cc softmax_function.cu
+    nv_library(softmax_function SRCS softmax.cc softmax.cu
         DEPS operator)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc
         DEPS cblas device_context operator)
-    cc_library(softmax_function SRCS softmax_function.cc DEPS operator)
+    cc_library(softmax_function SRCS softmax.cc DEPS operator)
 endif()

 nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
diff --git a/paddle/operators/math/softmax_function.cc b/paddle/operators/math/softmax.cc
similarity index 93%
rename from paddle/operators/math/softmax_function.cc
rename to paddle/operators/math/softmax.cc
index cd46ed96ca..ac9f3c4bf6 100644
--- a/paddle/operators/math/softmax_function.cc
+++ b/paddle/operators/math/softmax.cc
@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/operators/math/softmax_function.h"
+#include "paddle/operators/math/softmax.h"

 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/math/softmax_function.cu b/paddle/operators/math/softmax.cu
similarity index 94%
rename from paddle/operators/math/softmax_function.cu
rename to paddle/operators/math/softmax.cu
index 486697a161..4c3df0550e 100644
--- a/paddle/operators/math/softmax_function.cu
+++ b/paddle/operators/math/softmax.cu
@@ -14,7 +14,7 @@

 #define EIGEN_USE_GPU

-#include "paddle/operators/math/softmax_function.h"
+#include "paddle/operators/math/softmax.h"

 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/math/softmax_function.h b/paddle/operators/math/softmax.h
similarity index 100%
rename from paddle/operators/math/softmax_function.h
rename to paddle/operators/math/softmax.h
diff --git a/paddle/operators/math/utils.h b/paddle/operators/math/utils.h
deleted file mode 100644
index 1e72c8e0ca..0000000000
--- a/paddle/operators/math/utils.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#pragma once
-#include "paddle/platform/assert.h"
-#include "paddle/platform/hostdevice.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename T>
-T HOSTDEVICE tolerable_value(const T x) {
-  PADDLE_ASSERT(std::is_floating_point<T>::value);
-
-  const T kApproInf = 1e20;
-
-  if (x == INFINITY) {
-    return kApproInf;
-  }
-
-  if (x == -INFINITY) {
-    return -kApproInf;
-  }
-
-  return x;
-}
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 68d05fc215..18494e470a 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/softmax_function.h"
+#include "paddle/operators/math/softmax.h"

 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc
index a0941bb624..3dd21279ad 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/operators/softmax_with_cross_entropy_op.cc
@@ -23,16 +23,32 @@ class SoftmaxWithCrossEntropyOpMaker
   SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
                                  framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    // TODO(caoying): replace int with bool.
+    AddAttr<int>("soft_label",
+                 "(int, default 0), A flag to indicate whether to interpret "
+                 "the given labels as soft labels.")
+        .SetDefault(0);
     AddInput("Logits",
-             "The unscaled log probabilities which is a 2-D tensor with"
-             "shape [N x K]. N is the batch_size, and K is the class number.")
+             "(Tensor, default Tensor<float>), The unscaled log probabilities, "
+             "a 2-D tensor with shape [N x K]. N is the batch_size, "
+             "and K is the class number.")
         .NotInGradient();
-    AddInput("Label", "The ground truth. A 1-D tensor with shape N.");
-    AddOutput("Softmax",
-              "Store the outputs of softmax function, "
-              "which will be used in backward calculation.")
+    AddInput(
+        "Label",
+        "(Tensor, default Tensor<int>), The ground truth, which is "
+        "a 1-D or 2-D tensor. "
+        "If soft_label is set to 0, Label is a Tensor<int> with shape "
+        "[N x 1]. If soft_label is set to 1, Label is a Tensor<float> "
+        "with shape [N x K].");
+    AddOutput(
+        "Softmax",
+        "(Tensor, default Tensor<float>), A 2-D tensor with shape [N x K]. "
+        "The output values of the softmax activation for the input batch, "
+        "which will be used in the backward calculation.")
         .AsIntermediate();
-    AddOutput("Out", "A 1-D tensor with shape N.");
+    AddOutput("Loss",
+              "(Tensor, default Tensor<float>), A 2-D tensor with shape "
+              "[N x 1]. The cross entropy loss of the input batch.");
     AddComment(R"DOC(
Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
@@ -46,25 +62,18 @@ which will produce incorrect results.

This operator expects mutually exclusive hard labels: each sample in a batch
is in exactly one class with probability 1.
Each sample in the batch has one and only one label.
-)DOC");
-  }
-};

-class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
+Equation:

- protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@Grad) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
-                            "Input(Softmax) should be not null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Lable) should be not null.");
+1) hard label (one-hot label)

-    ctx.Output<Tensor>(framework::GradVarName("Logits"))
-        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
+Loss_j = -\text{Logit}_{\text{Label}_j} + \log\left(\sum_{i=1}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., N
+
+2) soft label (a distribution over all classes)
+
+Loss_j = -\sum_{i=1}^{K}\text{Label}_i\left(\text{Logit}_i - \log\left(\sum_{k=1}^{K}\exp(\text{Logit}_k)\right)\right), j = 1, ..., N
+
+)DOC");
   }
 };
@@ -82,7 +91,25 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
                    "The label should be a 1-d tensor.");

     ctx.Output<Tensor>("Softmax")->Resize(logits->dims());
-    ctx.Output<Tensor>("Out")->Resize({logits->dims()[0], 1});
+    ctx.Output<Tensor>("Loss")->Resize({logits->dims()[0], 1});
+  }
+};
+
+class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
+                            "Input(Loss@Grad) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
+                            "Input(Softmax) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) should not be null.");
+
+    ctx.Output<Tensor>(framework::GradVarName("Logits"))
+        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
   }
 };
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu
index 5af6a521a8..68bb85fa8a 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/operators/softmax_with_cross_entropy_op.cu
@@ -13,9 +13,10 @@ limitations under the License.
 */
 #define EIGEN_USE_GPU
+
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/softmax_function.h"
-#include "paddle/operators/math/utils.h"
+#include "paddle/operators/cross_entropy_op.h"
+#include "paddle/operators/math/softmax.h"

 namespace paddle {
 namespace operators {
@@ -27,9 +28,10 @@ __global__ void CrossEntropyKernel(T* out, const T* softmax_out,
                                    const int* label, const int batch_size,
                                    const int class_num) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i >= batch_size) return;
-  PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
-  out[i] = -math::tolerable_value(log(softmax_out[i * class_num + label[i]]));
+  if (i < batch_size) {
+    PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
+    out[i] = -tolerable_value(std::log(softmax_out[i * class_num + label[i]]));
+  }
 }

 template <typename T>
@@ -38,10 +40,10 @@ __global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
                                                   const int* label,
                                                   const int batch_size,
                                                   const int class_num) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i >= batch_size) return;
-
-  PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
-  softmax_out[i * class_num + label[i]] -= 1.;
+  if (i < batch_size) {
+    PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
+    softmax_out[i * class_num + label[i]] -= 1.;
+  }
 }

 template <typename T>
@@ -60,7 +62,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {

     // Calculate the cross entropy loss based on hard labels.
     const int* label_data = context.Input<Tensor>("Label")->data<int>();
-    Tensor* loss = context.Output<Tensor>("Out");
+    Tensor* loss = context.Output<Tensor>("Loss");
     loss->mutable_data<T>(context.GetPlace());
     T* loss_data = loss->data<T>();
diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h
index 38b92a0bcd..0ad48dae2c 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.h
+++ b/paddle/operators/softmax_with_cross_entropy_op.h
@@ -15,8 +15,8 @@
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/softmax_function.h"
-#include "paddle/operators/math/utils.h"
+#include "paddle/operators/cross_entropy_op.h"
+#include "paddle/operators/math/softmax.h"

 namespace paddle {
 namespace operators {
@@ -44,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
     T* softmax_out = softmax->data<T>();

     const int* label_data = context.Input<Tensor>("Label")->data<int>();
-    Tensor* loss = context.Output<Tensor>("Out");
+    Tensor* loss = context.Output<Tensor>("Loss");
     loss->mutable_data<T>(context.GetPlace());
     T* loss_data = loss->data<T>();
@@ -53,7 +53,7 @@
     for (int i = 0; i < batch_size; ++i) {
       int index = i * class_num + label_data[i];
-      loss_data[i] = -math::tolerable_value(std::log(softmax_out[index]));
+      loss_data[i] = -tolerable_value(std::log(softmax_out[index]));
     }
   }
 };
diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
index e965dd0482..9c9ee77b73 100644
--- a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
@@ -25,13 +25,13 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
             dtype="float32")

         self.inputs = {"Logits": logits, "Label": labels}
-        self.outputs = {"Softmax": softmax, "Out": cross_entropy}
+        self.outputs = {"Softmax": softmax, "Loss": cross_entropy}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(["Logits"], "Out", max_relative_error=0.05)
+        self.check_grad(["Logits"], "Loss", max_relative_error=0.05)


 if __name__ == "__main__":
-- 
GitLab
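
For reference, the forward computation the DOC block describes can be sanity-checked with a few lines of NumPy. This is a minimal sketch, not part of the patch: the helper name and the one-hot cross-check are illustrative only, assuming the documented shapes (Logits [N x K]; Label [N x 1] integer indices for hard labels, or [N x K] distributions for soft labels).

import numpy as np


def softmax_with_cross_entropy(logits, label, soft_label=False):
    # Numerically stable softmax: subtract each row's max before exp().
    shifted = logits - logits.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    softmax = exps / exps.sum(axis=1, keepdims=True)

    # log(softmax) = shifted - log(sum(exp(shifted))), computed row-wise.
    log_prob = shifted - np.log(exps.sum(axis=1, keepdims=True))

    if soft_label:
        # Soft labels: each row of `label` is a distribution over K classes.
        loss = -(label * log_prob).sum(axis=1, keepdims=True)
    else:
        # Hard labels: `label` holds one class index per sample.
        n = logits.shape[0]
        loss = -log_prob[np.arange(n), label.ravel()].reshape(n, 1)
    return softmax, loss


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    logits = rng.uniform(0.1, 1.0, (4, 5)).astype("float64")
    hard = rng.randint(0, 5, (4, 1))
    _, hard_loss = softmax_with_cross_entropy(logits, hard)
    # A one-hot "soft" label must reproduce the hard-label loss exactly,
    # which is the invariant the new soft_label attribute relies on.
    one_hot = np.eye(5)[hard.ravel()]
    _, soft_loss = softmax_with_cross_entropy(logits, one_hot, soft_label=True)
    assert np.allclose(hard_loss, soft_loss)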