From c6366c819e668c21a822122086ad72008357dd66 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Tue, 12 Sep 2017 16:38:11 +0800 Subject: [PATCH] softmax as functor. --- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/cross_entropy_op.h | 28 ++------- paddle/operators/math/CMakeLists.txt | 7 ++- paddle/operators/math/softmax_function.cc | 58 ++++--------------- paddle/operators/math/softmax_function.cu | 27 +++++++++ paddle/operators/math/softmax_function.h | 57 ++++++++++++++---- paddle/operators/softmax_op.h | 2 +- .../softmax_with_cross_entropy_op.cc | 44 +++++++------- .../operators/softmax_with_cross_entropy_op.h | 27 ++++++++- .../framework/tests/test_cross_entropy_op.py | 13 +++-- .../tests/test_softmax_with_cost_op.py | 22 ------- .../test_softmax_with_cross_entropy_op.py | 39 +++++++++++++ 12 files changed, 192 insertions(+), 134 deletions(-) create mode 100644 paddle/operators/math/softmax_function.cu delete mode 100644 python/paddle/v2/framework/tests/test_softmax_with_cost_op.py create mode 100644 python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 94e00ac38..8863ffe8e 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -60,7 +60,7 @@ set(DEPS_OPS op_library(identity_op DEPS scale_op) op_library(minus_op DEPS scale_op) op_library(mul_op DEPS math_function) -op_library(softmax_op DEPS math_function) +op_library(softmax_op DEPS softmax_function) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor operator net_op) op_library(scale_op DEPS net_op) diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index eb4d1348d..6de23bbe0 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -14,31 +14,13 @@ limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/utils.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -inline T tolerable_value(const T x) { - static_assert(std::is_floating_point::value, - "tolerable_value works only on float, " - "double and double double."); - - const T kApproInf = 1e20; - - if (x == INFINITY) { - return kApproInf; - } - - if (x == -INFINITY) { - return -kApproInf; - } - - return x; -} - template class OnehotCrossEntropyOpKernel : public framework::OpKernel { public: @@ -55,12 +37,12 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel { T* Ydata = Y->data(); - int batch_size = X->dims()[0]; - int class_num = X->dims()[1]; + const int batch_size = X->dims()[0]; + const int class_num = X->dims()[1]; for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; - Ydata[i] = -tolerable_value(std::log(Xdata[index])); + Ydata[i] = -math::tolerable_value(std::log(Xdata[index])); } } }; @@ -89,7 +71,7 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { memset(dXdata, 0, sizeof(T) * batch_size * class_num); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; - dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); + dXdata[index] = -math::tolerable_value(dYdata[i] / Xdata[index]); } } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 8ce39db62..832a954e3 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,9 +1,12 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc - im2col.cu softmax_function.cc DEPS cblas device_context operator) + im2col.cu DEPS cblas device_context operator) + nv_library(softmax_function SRCS softmax_function.cc softmax_function.cu + DEPS operator) else() cc_library(math_function SRCS math_function.cc im2col.cc - softmax_function.cc DEPS cblas device_context operator) + DEPS cblas device_context operator) + cc_library(softmax_function SRCS softmax_function.cc DEPS operator) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/softmax_function.cc b/paddle/operators/math/softmax_function.cc index 7edb632d3..cd46ed96c 100644 --- a/paddle/operators/math/softmax_function.cc +++ b/paddle/operators/math/softmax_function.cc @@ -1,20 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef PADDLE_ONLY_CPU -#define EIGEN_USE_GPU -#endif + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #include "paddle/operators/math/softmax_function.h" @@ -22,41 +18,7 @@ namespace paddle { namespace operators { namespace math { -template -using EigenMatrix = framework::EigenMatrix; - -template -void softmax(const framework::Tensor* X, framework::Tensor* Y, - const framework::ExecutionContext& context) { - auto logits = EigenMatrix::From(*X); - auto softmax = EigenMatrix::From(*Y); - - const int kBatchDim = 0; - const int kClassDim = 1; - - const int batch_size = logits.dimension(kBatchDim); - const int num_classes = logits.dimension(kClassDim); - - Eigen::DSizes along_class(kClassDim); - Eigen::DSizes batch_by_one(batch_size, 1); - Eigen::DSizes one_by_class(1, num_classes); - - auto shifted_logits = (logits - - logits.maximum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); - - softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); - softmax.device(context.GetEigenDevice()) = - (softmax * - softmax.sum(along_class) - .inverse() - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); -} +template class SoftmaxFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax_function.cu b/paddle/operators/math/softmax_function.cu new file mode 100644 index 000000000..486697a16 --- /dev/null +++ b/paddle/operators/math/softmax_function.cu @@ -0,0 +1,27 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/math/softmax_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template class SoftmaxFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/softmax_function.h b/paddle/operators/math/softmax_function.h index 2e1b2a7ad..ce29a69bc 100644 --- a/paddle/operators/math/softmax_function.h +++ b/paddle/operators/math/softmax_function.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" @@ -21,9 +21,44 @@ namespace paddle { namespace operators { namespace math { +template +using EigenMatrix = framework::EigenMatrix; + template -void softmax(const framework::Tensor* X, framework::Tensor* Y, - const framework::ExecutionContext& context); +class SoftmaxFunctor { + public: + void operator()(const framework::Tensor* X, framework::Tensor* Y, + const framework::ExecutionContext& context) { + auto logits = EigenMatrix::From(*X); + auto softmax = EigenMatrix::From(*Y); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto shifted_logits = (logits - + logits.maximum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); + + softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); + softmax.device(context.GetEigenDevice()) = + (softmax * + softmax.sum(along_class) + .inverse() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); + } +}; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index ff054a59a..6d14542a7 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Y->mutable_data(context.GetPlace()); - math::softmax(X, Y, context); + math::SoftmaxFunctor()(X, Y, context); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index 2edf00766..b4aa9aab4 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -23,13 +23,13 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto logits = ctx.Input("logits"); + auto logits = ctx.Input("Logits"); PADDLE_ENFORCE( logits->dims().size() == 2UL, "The input of softmax_with_cross_entropy should be a 2-d tensor."); - PADDLE_ENFORCE(ctx.Input("lables")->dims().size() == 1UL, + PADDLE_ENFORCE(ctx.Input("Label")->dims().size() == 1UL, "The label should be a 1-d tensor."); - ctx.Output("Y")->Resize({logits->dims()[0]}); + ctx.Output("Label")->Resize({logits->dims()[0]}); } }; @@ -39,11 +39,15 @@ class SoftmaxWithCrossEntropyOpMaker SoftmaxWithCrossEntropyOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("logits", + AddInput("Logits", "The unscaled log probabilities which is a 2-D tensor with" "shape [N x K]. N is the batch_size, and K is the class number."); - AddInput("label", "The ground truth. A 1-D tensor with shape N."); - AddOutput("Y", "A 1-D tensor with shape N."); + AddInput("Label", "The ground truth. A 1-D tensor with shape N."); + AddOutput("Softmax", + "Store the outputs of softmax function, " + "which will be used in backward calculation.") + .AsIntermediate(); + AddOutput("Loss", "A 1-D tensor with shape N."); AddComment(R"DOC( Cross entropy loss with softmax are used as the output layer extensively. This operator computes the softmax normalized values for each row of the input @@ -67,21 +71,21 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), - "Input(Y@GRAD) should be not null."); - PADDLE_ENFORCE_EQ(ctx.Input("Y")->dims(), - ctx.Input(framework::GradVarName("Y"))->dims(), - "Input(Y) and its gradients should have a same shape."); - - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("labels"), - "Input(lables) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("logits")), - "Input(logits@GRAD) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Loss"), + "Input(Loss) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")), + "Input(Loss@GRAD) should be not null."); PADDLE_ENFORCE_EQ( - ctx.Input("logits")->dims(), - ctx.Input(framework::GradVarName("logits"))->dims(), - "Input(logits) and its gradients should have a same shape."); + ctx.Input("Logits")->dims(), + ctx.Input(framework::GradVarName("Logits"))->dims(), + "Input(Logits) and its gradients should have a same shape."); + PADDLE_ENFORCE_EQ( + ctx.Input("Logits")->dims(), + ctx.Input(framework::GradVarName("Logits"))->dims(), + "Input(Logits) and its gradients should have a same shape."); + + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Lable) should be not null."); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index 418fb540b..4c019a759 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -15,6 +15,8 @@ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/softmax_function.h" +#include "paddle/operators/math/utils.h" namespace paddle { namespace operators { @@ -27,7 +29,30 @@ using EigenMatrix = framework::EigenMatrix; template class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override {} + void Compute(const framework::ExecutionContext& context) const override { + // Calculate ths softmax outputs. + const Tensor* logits = context.Input("Logits"); + Tensor* softmax = context.Output("Softmax"); + // allocate memory on device. + softmax->mutable_data(context.GetPlace()); + math::SoftmaxFunctor()(logits, softmax, context); + + // Calculate the cross entropy loss based on hard labels. + T* softmax_out = softmax->data(); + const int* label_data = context.Input("label")->data(); + + Tensor* loss = context.Output("Loss"); + loss->mutable_data(context.GetPlace()); + T* loss_data = loss->data(); + + const int batch_size = logits->dims()[0]; + const int class_num = logits->dims()[1]; + + for (int i = 0; i < batch_size; ++i) { + int index = i * class_num + label_data[i]; + loss_data[i] = -math::tolerable_value(std::log(softmax_out[index])); + } + } }; template diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index c2fc102a8..6c1dc4044 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -1,6 +1,7 @@ import unittest import numpy from op_test import OpTest +import pdb class TestCrossEntropy(OpTest): @@ -10,18 +11,20 @@ class TestCrossEntropy(OpTest): class_num = 10 X = numpy.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") - label = (class_num / 2) * numpy.ones(batch_size).astype("int32") - self.inputs = {'X': X, 'label': label} + + labels = numpy.random.randint(0, class_num, batch_size, dtype="int32") + + self.inputs = {"X": X, "label": labels} Y = [] for i in range(0, batch_size): - Y.append(-numpy.log(X[i][label[i]])) - self.outputs = {'Y': numpy.array(Y).astype("float32")} + Y.append(-numpy.log(X[i][labels[i]])) + self.outputs = {"Y": numpy.array(Y).astype("float32")} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y") if __name__ == "__main__": diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cost_op.py b/python/paddle/v2/framework/tests/test_softmax_with_cost_op.py deleted file mode 100644 index f7b9f54a9..000000000 --- a/python/paddle/v2/framework/tests/test_softmax_with_cost_op.py +++ /dev/null @@ -1,22 +0,0 @@ -import unittest - -import numpy as np - -from gradient_checker import GradientChecker, create_op -from op_test_util import OpTestMeta - - -class TestSoftmaxWithLossOp(unittest.TestCase): - __metaclass__ = OpTestMeta - - def setUp(self): - pass - - -class SoftmaxWithLossGradOpTest(GradientChecker): - def test_softmax(self): - pass - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py new file mode 100644 index 000000000..611611056 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py @@ -0,0 +1,39 @@ +import unittest +import numpy as np +import pdb + +from op_test import OpTest +from test_softmax_op import stable_softmax + + +class TestSoftmaxWithCrossEntropyOp(OpTest): + def setUp(self): + self.op_type = "softmax_with_cross_entropy" + + MAX_BATCH_SIZE = 23 + MAX_CLASS_NUM = 255 + + batch_size = np.random.randint(1, MAX_BATCH_SIZE, 1)[0] + class_num = np.random.randint(2, MAX_CLASS_NUM, 1)[0] + + logits = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + softmax = np.apply_along_axis(stable_softmax, 1, logits) + labels = np.random.randint(0, class_num, batch_size, dtype="int32") + + cross_entropy = [ + -np.log(softmax[i][labels[i]]) for i in range(softmax.shape[0]) + ] + + self.inputs = {"Logits": logits, "Label": labels} + self.outputs = {"Loss": cross_entropy} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + pass + + +if __name__ == "__main__": + unittest.main() -- GitLab