提交 c6366c81 编写于 作者: C caoying03

softmax as functor.

上级 2507bcaa
...@@ -60,7 +60,7 @@ set(DEPS_OPS ...@@ -60,7 +60,7 @@ set(DEPS_OPS
op_library(identity_op DEPS scale_op) op_library(identity_op DEPS scale_op)
op_library(minus_op DEPS scale_op) op_library(minus_op DEPS scale_op)
op_library(mul_op DEPS math_function) op_library(mul_op DEPS math_function)
op_library(softmax_op DEPS math_function) op_library(softmax_op DEPS softmax_function)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor operator net_op) DEPS framework_proto tensor operator net_op)
op_library(scale_op DEPS net_op) op_library(scale_op DEPS net_op)
......
...@@ -14,31 +14,13 @@ limitations under the License. */ ...@@ -14,31 +14,13 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T>
inline T tolerable_value(const T x) {
static_assert(std::is_floating_point<T>::value,
"tolerable_value works only on float, "
"double and double double.");
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
}
template <typename T> template <typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel { class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public: public:
...@@ -55,12 +37,12 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel { ...@@ -55,12 +37,12 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
T* Ydata = Y->data<T>(); T* Ydata = Y->data<T>();
int batch_size = X->dims()[0]; const int batch_size = X->dims()[0];
int class_num = X->dims()[1]; const int class_num = X->dims()[1];
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; int index = i * class_num + label_data[i];
Ydata[i] = -tolerable_value(std::log(Xdata[index])); Ydata[i] = -math::tolerable_value(std::log(Xdata[index]));
} }
} }
}; };
...@@ -89,7 +71,7 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { ...@@ -89,7 +71,7 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
memset(dXdata, 0, sizeof(T) * batch_size * class_num); memset(dXdata, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; int index = i * class_num + label_data[i];
dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); dXdata[index] = -math::tolerable_value(dYdata[i] / Xdata[index]);
} }
} }
}; };
......
if(WITH_GPU) if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu softmax_function.cc DEPS cblas device_context operator) im2col.cu DEPS cblas device_context operator)
nv_library(softmax_function SRCS softmax_function.cc softmax_function.cu
DEPS operator)
else() else()
cc_library(math_function SRCS math_function.cc im2col.cc cc_library(math_function SRCS math_function.cc im2col.cc
softmax_function.cc DEPS cblas device_context operator) DEPS cblas device_context operator)
cc_library(softmax_function SRCS softmax_function.cc DEPS operator)
endif() endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#define EIGEN_USE_GPU
#endif
#include "paddle/operators/math/softmax_function.h" #include "paddle/operators/math/softmax_function.h"
...@@ -22,41 +18,7 @@ namespace paddle { ...@@ -22,41 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template <typename T, int MajorType = Eigen::RowMajor, template class SoftmaxFunctor<platform::CPUPlace, float>;
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
void softmax(const framework::Tensor* X, framework::Tensor* Y,
const framework::ExecutionContext& context) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits -
logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/math/softmax_function.h"
namespace paddle {
namespace operators {
namespace math {
template class SoftmaxFunctor<platform::GPUPlace, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
...@@ -21,9 +21,44 @@ namespace paddle { ...@@ -21,9 +21,44 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
void softmax(const framework::Tensor* X, framework::Tensor* Y, class SoftmaxFunctor {
const framework::ExecutionContext& context); public:
void operator()(const framework::Tensor* X, framework::Tensor* Y,
const framework::ExecutionContext& context) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits -
logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
}
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel { ...@@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel {
// allocate memory on device. // allocate memory on device.
Y->mutable_data<T>(context.GetPlace()); Y->mutable_data<T>(context.GetPlace());
math::softmax<Place, T>(X, Y, context); math::SoftmaxFunctor<Place, T>()(X, Y, context);
} }
}; };
......
...@@ -23,13 +23,13 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -23,13 +23,13 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto logits = ctx.Input<Tensor>("logits"); auto logits = ctx.Input<Tensor>("Logits");
PADDLE_ENFORCE( PADDLE_ENFORCE(
logits->dims().size() == 2UL, logits->dims().size() == 2UL,
"The input of softmax_with_cross_entropy should be a 2-d tensor."); "The input of softmax_with_cross_entropy should be a 2-d tensor.");
PADDLE_ENFORCE(ctx.Input<Tensor>("lables")->dims().size() == 1UL, PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 1UL,
"The label should be a 1-d tensor."); "The label should be a 1-d tensor.");
ctx.Output<Tensor>("Y")->Resize({logits->dims()[0]}); ctx.Output<Tensor>("Label")->Resize({logits->dims()[0]});
} }
}; };
...@@ -39,11 +39,15 @@ class SoftmaxWithCrossEntropyOpMaker ...@@ -39,11 +39,15 @@ class SoftmaxWithCrossEntropyOpMaker
SoftmaxWithCrossEntropyOpMaker(framework::OpProto *proto, SoftmaxWithCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("logits", AddInput("Logits",
"The unscaled log probabilities which is a 2-D tensor<float> with" "The unscaled log probabilities which is a 2-D tensor<float> with"
"shape [N x K]. N is the batch_size, and K is the class number."); "shape [N x K]. N is the batch_size, and K is the class number.");
AddInput("label", "The ground truth. A 1-D tensor<int> with shape N."); AddInput("Label", "The ground truth. A 1-D tensor<int> with shape N.");
AddOutput("Y", "A 1-D tensor<float> with shape N."); AddOutput("Softmax",
"Store the outputs of softmax function, "
"which will be used in backward calculation.")
.AsIntermediate();
AddOutput("Loss", "A 1-D tensor<float> with shape N.");
AddComment(R"DOC( AddComment(R"DOC(
Cross entropy loss with softmax are used as the output layer extensively. This Cross entropy loss with softmax are used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input operator computes the softmax normalized values for each row of the input
...@@ -67,21 +71,21 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -67,21 +71,21 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Loss"),
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), "Input(Loss) should be not null.");
"Input(Y@GRAD) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Y")->dims(), "Input(Loss@GRAD) should be not null.");
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
"Input(Y) and its gradients should have a same shape.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("labels"),
"Input(lables) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("logits")),
"Input(logits@GRAD) should be not null.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx.Input<Tensor>("logits")->dims(), ctx.Input<Tensor>("Logits")->dims(),
ctx.Input<Tensor>(framework::GradVarName("logits"))->dims(), ctx.Input<Tensor>(framework::GradVarName("Logits"))->dims(),
"Input(logits) and its gradients should have a same shape."); "Input(Logits) and its gradients should have a same shape.");
PADDLE_ENFORCE_EQ(
ctx.Input<Tensor>("Logits")->dims(),
ctx.Input<Tensor>(framework::GradVarName("Logits"))->dims(),
"Input(Logits) and its gradients should have a same shape.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Lable) should be not null.");
} }
}; };
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax_function.h"
#include "paddle/operators/math/utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,7 +29,30 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; ...@@ -27,7 +29,30 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override {} void Compute(const framework::ExecutionContext& context) const override {
// Calculate ths softmax outputs.
const Tensor* logits = context.Input<Tensor>("Logits");
Tensor* softmax = context.Output<Tensor>("Softmax");
// allocate memory on device.
softmax->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<Place, T>()(logits, softmax, context);
// Calculate the cross entropy loss based on hard labels.
T* softmax_out = softmax->data<T>();
const int* label_data = context.Input<Tensor>("label")->data<int>();
Tensor* loss = context.Output<Tensor>("Loss");
loss->mutable_data<T>(context.GetPlace());
T* loss_data = loss->data<T>();
const int batch_size = logits->dims()[0];
const int class_num = logits->dims()[1];
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
loss_data[i] = -math::tolerable_value(std::log(softmax_out[index]));
}
}
}; };
template <typename Place, typename T> template <typename Place, typename T>
......
import unittest import unittest
import numpy import numpy
from op_test import OpTest from op_test import OpTest
import pdb
class TestCrossEntropy(OpTest): class TestCrossEntropy(OpTest):
...@@ -10,18 +11,20 @@ class TestCrossEntropy(OpTest): ...@@ -10,18 +11,20 @@ class TestCrossEntropy(OpTest):
class_num = 10 class_num = 10
X = numpy.random.uniform(0.1, 1.0, X = numpy.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
self.inputs = {'X': X, 'label': label} labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")
self.inputs = {"X": X, "label": labels}
Y = [] Y = []
for i in range(0, batch_size): for i in range(0, batch_size):
Y.append(-numpy.log(X[i][label[i]])) Y.append(-numpy.log(X[i][labels[i]]))
self.outputs = {'Y': numpy.array(Y).astype("float32")} self.outputs = {"Y": numpy.array(Y).astype("float32")}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Y') self.check_grad(["X"], "Y")
if __name__ == "__main__": if __name__ == "__main__":
......
import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
class TestSoftmaxWithLossOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
pass
class SoftmaxWithLossGradOpTest(GradientChecker):
def test_softmax(self):
pass
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
import pdb
from op_test import OpTest
from test_softmax_op import stable_softmax
class TestSoftmaxWithCrossEntropyOp(OpTest):
def setUp(self):
self.op_type = "softmax_with_cross_entropy"
MAX_BATCH_SIZE = 23
MAX_CLASS_NUM = 255
batch_size = np.random.randint(1, MAX_BATCH_SIZE, 1)[0]
class_num = np.random.randint(2, MAX_CLASS_NUM, 1)[0]
logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, batch_size, dtype="int32")
cross_entropy = [
-np.log(softmax[i][labels[i]]) for i in range(softmax.shape[0])
]
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {"Loss": cross_entropy}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
pass
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册