diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index f9ea25ab045a02be5ab9ed81ef9c679126d3a188..94e00ac3824943a0f834dd13cedb43d9387d885e 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -55,10 +55,12 @@ set(DEPS_OPS
     minus_op
     mul_op
     recurrent_op
-    scale_op)
+    scale_op
+    softmax_op)
 op_library(identity_op DEPS scale_op)
 op_library(minus_op DEPS scale_op)
 op_library(mul_op DEPS math_function)
+op_library(softmax_op DEPS math_function)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
   DEPS framework_proto tensor operator net_op)
 op_library(scale_op DEPS net_op)
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index f8333f34f7b4c7b0f9a0f14a7a33f9d98e1d331c..8ce39db6214b9550a512963e85efebb262121336 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -1,9 +1,9 @@
-
 if(WITH_GPU)
-    nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
-    im2col.cu DEPS cblas device_context)
+    nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
+    im2col.cu softmax_function.cc DEPS cblas device_context operator)
 else()
-    cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
+    cc_library(math_function SRCS math_function.cc im2col.cc
+    softmax_function.cc DEPS cblas device_context operator)
 endif()
 
 nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
diff --git a/paddle/operators/math/softmax_function.cc b/paddle/operators/math/softmax_function.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7edb632d3195e217abd3bc951a6e206c3ddc71d9
--- /dev/null
+++ b/paddle/operators/math/softmax_function.cc
@@ -0,0 +1,63 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef PADDLE_ONLY_CPU
+#define EIGEN_USE_GPU
+#endif
+
+#include "paddle/operators/math/softmax_function.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+void softmax(const framework::Tensor* X, framework::Tensor* Y,
+             const framework::ExecutionContext& context) {
+  auto logits = EigenMatrix<T>::From(*X);
+  auto softmax = EigenMatrix<T>::From(*Y);
+
+  const int kBatchDim = 0;
+  const int kClassDim = 1;
+
+  const int batch_size = logits.dimension(kBatchDim);
+  const int num_classes = logits.dimension(kClassDim);
+
+  Eigen::DSizes<int, 1> along_class(kClassDim);
+  Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+  Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+
+  auto shifted_logits = (logits -
+                         logits.maximum(along_class)
+                             .eval()
+                             .reshape(batch_by_one)
+                             .broadcast(one_by_class));
+
+  softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
+  softmax.device(context.GetEigenDevice<Place>()) =
+      (softmax *
+       softmax.sum(along_class)
+           .inverse()
+           .eval()
+           .reshape(batch_by_one)
+           .broadcast(one_by_class));
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/softmax_function.h b/paddle/operators/math/softmax_function.h
new file mode 100644
index 0000000000000000000000000000000000000000..2e1b2a7ad00a616eac02772c9f278f24f8155111
--- /dev/null
+++ b/paddle/operators/math/softmax_function.h
@@ -0,0 +1,29 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
+#include "paddle/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename Place, typename T>
+void softmax(const framework::Tensor* X, framework::Tensor* Y,
+             const framework::ExecutionContext& context);
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
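One thing worth flagging in this split: `math::softmax` is declared in `softmax_function.h` but defined only in `softmax_function.cc`, so its definition is not visible at the call site in `SoftmaxKernel`. For the kernels to link, the translation unit generally has to explicitly instantiate every `(Place, T)` combination the operators use. A minimal sketch of what such instantiations could look like at the bottom of `softmax_function.cc` (hypothetical, not part of this diff; the actual list depends on which kernels are registered):

```cpp
// Hypothetical explicit instantiations -- NOT part of this diff. With the
// template defined in a .cc file, each (Place, T) pair used by a caller
// must be instantiated here so the symbol exists at link time.
namespace paddle {
namespace operators {
namespace math {

template void softmax<platform::CPUPlace, float>(
    const framework::Tensor* X, framework::Tensor* Y,
    const framework::ExecutionContext& context);
#ifndef PADDLE_ONLY_CPU
template void softmax<platform::GPUPlace, float>(
    const framework::Tensor* X, framework::Tensor* Y,
    const framework::ExecutionContext& context);
#endif

}  // namespace math
}  // namespace operators
}  // namespace paddle
```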
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 8a3a5ab927c0e2937936fcc973f000d4d95c3dbc..ff054a59aed1862a8aff9eb58bc902d7c98b42cc 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/softmax_function.h"
 
 namespace paddle {
 namespace operators {
@@ -30,36 +31,11 @@ class SoftmaxKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto X = context.Input<Tensor>("X");
     auto Y = context.Output<Tensor>("Y");
-    Y->mutable_data<T>(context.GetPlace());
-
-    auto logits = EigenMatrix<T>::From(*X);
-    auto softmax = EigenMatrix<T>::From(*Y);
-
-    const int kBatchDim = 0;
-    const int kClassDim = 1;
-
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
+
+    // allocate memory on device.
+    Y->mutable_data<T>(context.GetPlace());
 
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-
-    auto shifted_logits = (logits -
-                           logits.maximum(along_class)
-                               .eval()
-                               .reshape(batch_by_one)
-                               .broadcast(one_by_class));
-
-    softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
-
-    softmax.device(context.GetEigenDevice<Place>()) =
-        (softmax *
-         softmax.sum(along_class)
-             .inverse()
-             .eval()
-             .reshape(batch_by_one)
-             .broadcast(one_by_class));
+    math::softmax<Place, T>(X, Y, context);
   }
 };
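For reference, the expression being factored out is the standard numerically stable softmax: it computes softmax(x)_j = exp(x_j - m) / Σ_k exp(x_k - m) with m = max_i x_i, which is mathematically identical to the naive form but keeps every exponent non-positive so exp cannot overflow. A self-contained sketch of the same shift-by-max trick in plain C++ (no Paddle or Eigen types; `softmax_row` is a name invented here for illustration):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one row: subtract the row max before
// exponentiating so every argument to std::exp() is <= 0.
std::vector<float> softmax_row(const std::vector<float>& logits) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> out(logits.size());
  float sum = 0.f;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - max_logit);  // in (0, 1], never inf
    sum += out[i];
  }
  for (float& v : out) v /= sum;  // normalize so the row sums to 1
  return out;
}

int main() {
  // Naively, std::exp(1000.f) overflows float to inf and the row becomes
  // NaN; the shifted form yields finite probabilities (~0.665, 0.245, 0.090).
  for (float p : softmax_row({1000.f, 999.f, 998.f})) std::printf("%f ", p);
  std::printf("\n");
  return 0;
}
```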