diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index e56895c63a426b782f7b46091bc86c367d49899d..da39c2cb550c1faebbfec0d9214d2ae607d71f9e 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -94,8 +94,8 @@ set(DEPS_OPS op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor net_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) -op_library(cross_entropy_op DEPS cross_entropy_function) -op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function) +op_library(cross_entropy_op DEPS cross_entropy) +op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 91ae3d49f1df51d9524547f7765285bff9dbb5c5..b60e945aa86463ce7d69afeb76de15edd83a44d6 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,15 +1,14 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) - nv_library(softmax_function SRCS softmax.cc softmax.cu - DEPS operator) - nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu + nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) + nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) - cc_library(softmax_function SRCS softmax.cc DEPS operator) - cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) + cc_library(softmax SRCS softmax.cc DEPS operator) + cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc index ac9f3c4bf61bf8e13faa17387f1112756db9a100..0ba8197ab8b64649c8adcf67771ba01eca7f1d10 100644 --- a/paddle/operators/math/softmax.cc +++ b/paddle/operators/math/softmax.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include "paddle/operators/math/softmax.h" @@ -19,6 +19,7 @@ namespace operators { namespace math { template class SoftmaxFunctor; +template class SoftmaxGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.cu b/paddle/operators/math/softmax.cu index 4c3df0550e7ca6f4310db1d35cc34d5c73a2dd16..99f988d51e4b16c3f3bfd9c76b411bb53619603e 100644 --- a/paddle/operators/math/softmax.cu +++ b/paddle/operators/math/softmax.cu @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #define EIGEN_USE_GPU @@ -21,6 +21,7 @@ namespace operators { namespace math { template class SoftmaxFunctor; +template class SoftmaxGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h index 3d2f0d0aecffcd0fe51166c3d863aa8b91bba196..3c05a86bce96231511e9d45068659659090738f8 100644 --- a/paddle/operators/math/softmax.h +++ b/paddle/operators/math/softmax.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" @@ -68,6 +68,37 @@ class SoftmaxFunctor { .broadcast(one_by_class)); } }; + +template +class SoftmaxGradFunctor { + public: + void operator()(const framework::ExecutionContext& context, + const framework::Tensor* y, const framework::Tensor* y_grad, + framework::Tensor* x_grad) { + auto softmax = EigenMatrix::From(*y); + auto softmax_grad = EigenMatrix::From(*y_grad); + auto logits_grad = EigenMatrix::From(*x_grad); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = softmax.dimension(kBatchDim); + const int num_classes = softmax.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto dot = (softmax * softmax_grad) + .sum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class); + logits_grad.device(context.GetEigenDevice()) = + (softmax_grad - dot) * softmax; + } +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 7220f486be055e1b841a06b15f519717c54f575c..3d35507a9ac6963939e125e6d3ab6f5eb67300a0 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -29,8 +29,8 @@ template class SoftmaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto X = context.Input("X"); - auto Y = context.Output("Y"); + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); // allocate memory on device. Y->mutable_data(context.GetPlace()); @@ -43,29 +43,14 @@ template class SoftmaxGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto Y = context.Input("Y"); - auto dY = context.Input(framework::GradVarName("Y")); - auto dX = context.Output(framework::GradVarName("X")); - dX->mutable_data(context.GetPlace()); - - const int batch_size = Y->dims()[0]; - const int class_num = Y->dims()[1]; - - Eigen::DSizes along_class(1); - Eigen::DSizes batch_by_one(batch_size, 1); - Eigen::DSizes one_by_class(1, class_num); + auto* Y = context.Input("Y"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); - auto Y_eigen = EigenMatrix::From(*Y); - auto dY_eigen = EigenMatrix::From(*dY); - auto dX_eigen = EigenMatrix::From(*dX); - auto place = context.GetEigenDevice(); + // allocate memory on device. + dX->mutable_data(context.GetPlace()); - auto dot = (Y_eigen * dY_eigen) - .sum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class); - dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen; + math::SoftmaxGradFunctor()(context, Y, dY, dX); } };