提交 f1d5fb3b 编写于 作者: C caoying03

support soft labels.

上级 a2a0d6f8
if(WITH_GPU) if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu DEPS cblas device_context operator) im2col.cu DEPS cblas device_context operator)
nv_library(softmax_function SRCS softmax_function.cc softmax_function.cu nv_library(softmax_function SRCS softmax.cc softmax.cu
DEPS operator) DEPS operator)
else() else()
cc_library(math_function SRCS math_function.cc im2col.cc cc_library(math_function SRCS math_function.cc im2col.cc
DEPS cblas device_context operator) DEPS cblas device_context operator)
cc_library(softmax_function SRCS softmax_function.cc DEPS operator) cc_library(softmax_function SRCS softmax.cc DEPS operator)
endif() endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/softmax_function.h" #include "paddle/operators/math/softmax.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/math/softmax_function.h" #include "paddle/operators/math/softmax.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
T HOSTDEVICE tolerable_value(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
}
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax_function.h" #include "paddle/operators/math/softmax.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -23,16 +23,32 @@ class SoftmaxWithCrossEntropyOpMaker ...@@ -23,16 +23,32 @@ class SoftmaxWithCrossEntropyOpMaker
SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto, SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
//(TODO caoying) replace int with boolean
AddAttr<int>("soft_label",
"(int, default 0), A flag to indicate whether to interpretate "
"the given labels as soft labels.")
.SetDefault(0);
AddInput("Logits", AddInput("Logits",
"The unscaled log probabilities which is a 2-D tensor<float> with" "(Tensor, default Tensor<float>), The unscaled log probabilities "
"shape [N x K]. N is the batch_size, and K is the class number.") "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number.")
.NotInGradient(); .NotInGradient();
AddInput("Label", "The ground truth. A 1-D tensor<int> with shape N."); AddInput(
AddOutput("Softmax", "Label",
"Store the outputs of softmax function, " "(Tensor, default Tensor<int>), The ground truth which is "
"a 1-D or 2-D tensor. "
"If soft_label is set to 0, Label is a Tensor<int> with shape [N x 1]. "
"If soft_label is set to 1, Label is a Tensor<float/double> "
"with shape [N x K].");
AddOutput(
"Softmax",
"(Tensor, default Tensor<float>), A 2-D tensor with shape [N x K]. "
"The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation.") "which will be used in backward calculation.")
.AsIntermediate(); .AsIntermediate();
AddOutput("Out", "A 1-D tensor<float> with shape N."); AddOutput("Loss",
"(Tensor, default Tensor<float>), A 1-D tensor. The cross "
"entropy loss with shape [N x 1].");
AddComment(R"DOC( AddComment(R"DOC(
Cross entropy loss with softmax are used as the output layer extensively. This Cross entropy loss with softmax are used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input operator computes the softmax normalized values for each row of the input
...@@ -46,25 +62,18 @@ which will produce incorrect results. ...@@ -46,25 +62,18 @@ which will produce incorrect results.
This operators expects mutually exclusive hard labels, each sample in a batch This operators expects mutually exclusive hard labels, each sample in a batch
is in exactly one class with probabilities 1. Each sample in the batch with one is in exactly one class with probabilities 1. Each sample in the batch with one
and only one label. and only one label.
)DOC");
}
};
class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { Equation:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: 1) hard label (one-hot label)
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
"Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Lable) should be not null.");
ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits")) Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., K
->Resize(ctx.Input<Tensor>("Softmax")->dims());
2) soft label (a distribution over all classes)
Loss_j = -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i-\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), j = 1,...,K
)DOC");
} }
}; };
...@@ -82,7 +91,25 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -82,7 +91,25 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
"The label should be a 1-d tensor."); "The label should be a 1-d tensor.");
ctx.Output<framework::LoDTensor>("Softmax")->Resize(logits->dims()); ctx.Output<framework::LoDTensor>("Softmax")->Resize(logits->dims());
ctx.Output<framework::LoDTensor>("Out")->Resize({logits->dims()[0], 1}); ctx.Output<framework::LoDTensor>("Loss")->Resize({logits->dims()[0], 1});
}
};
class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
"Input(Loss@Grad) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
"Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Lable) should be not null.");
ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
->Resize(ctx.Input<Tensor>("Softmax")->dims());
} }
}; };
......
...@@ -13,9 +13,10 @@ ...@@ -13,9 +13,10 @@
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax_function.h" #include "paddle/operators/cross_entropy_op.h"
#include "paddle/operators/math/utils.h" #include "paddle/operators/math/softmax.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,9 +28,10 @@ __global__ void CrossEntropyKernel(T* out, const T* softmax_out, ...@@ -27,9 +28,10 @@ __global__ void CrossEntropyKernel(T* out, const T* softmax_out,
const int* label, const int batch_size, const int* label, const int batch_size,
const int class_num) { const int class_num) {
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size) return; if (i < batch_size) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num); PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
out[i] = -math::tolerable_value(log(softmax_out[i * class_num + label[i]])); out[i] = -tolerable_value(std::log(softmax_out[i * class_num + label[i]]));
}
} }
template <typename T> template <typename T>
...@@ -38,10 +40,10 @@ __global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out, ...@@ -38,10 +40,10 @@ __global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
const int batch_size, const int batch_size,
const int class_num) { const int class_num) {
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size) return; if (i < batch_size) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num); PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
softmax_out[i * class_num + label[i]] -= 1.; softmax_out[i * class_num + label[i]] -= 1.;
}
} }
template <typename T> template <typename T>
...@@ -60,7 +62,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { ...@@ -60,7 +62,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
// Calculate the cross entropy loss based on hard labels. // Calculate the cross entropy loss based on hard labels.
const int* label_data = context.Input<Tensor>("Label")->data<int>(); const int* label_data = context.Input<Tensor>("Label")->data<int>();
Tensor* loss = context.Output<Tensor>("Out"); Tensor* loss = context.Output<Tensor>("Loss");
loss->mutable_data<T>(context.GetPlace()); loss->mutable_data<T>(context.GetPlace());
T* loss_data = loss->data<T>(); T* loss_data = loss->data<T>();
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax_function.h" #include "paddle/operators/cross_entropy_op.h"
#include "paddle/operators/math/utils.h" #include "paddle/operators/math/softmax.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -44,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { ...@@ -44,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
T* softmax_out = softmax->data<T>(); T* softmax_out = softmax->data<T>();
const int* label_data = context.Input<Tensor>("Label")->data<int>(); const int* label_data = context.Input<Tensor>("Label")->data<int>();
Tensor* loss = context.Output<Tensor>("Out"); Tensor* loss = context.Output<Tensor>("Loss");
loss->mutable_data<T>(context.GetPlace()); loss->mutable_data<T>(context.GetPlace());
T* loss_data = loss->data<T>(); T* loss_data = loss->data<T>();
...@@ -53,7 +53,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { ...@@ -53,7 +53,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; int index = i * class_num + label_data[i];
loss_data[i] = -math::tolerable_value(std::log(softmax_out[index])); loss_data[i] = -tolerable_value(std::log(softmax_out[index]));
} }
} }
}; };
......
...@@ -25,13 +25,13 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -25,13 +25,13 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
dtype="float32") dtype="float32")
self.inputs = {"Logits": logits, "Label": labels} self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {"Softmax": softmax, "Out": cross_entropy} self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["Logits"], "Out", max_relative_error=0.05) self.check_grad(["Logits"], "Loss", max_relative_error=0.05)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册