/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { namespace math { template using EigenMatrix = framework::EigenMatrix; template struct ValueClip { HOSTDEVICE T operator()(const T& x) const { const T kThreshold = static_cast(-64.); return x < kThreshold ? kThreshold : x; } }; template void SoftmaxEigen(const DeviceContext& context, const int axis_dim, const framework::Tensor* X, framework::Tensor* Y) { constexpr int kBatchDim = 0; constexpr int kClassDim = 1; constexpr int kAxisDim = 1; auto logits = EigenMatrix::From(*X); auto softmax = EigenMatrix::From(*Y); const int batch_size = logits.dimension(kBatchDim); const int num_classes = logits.dimension(kClassDim); const int num_remain = num_classes / axis_dim; Eigen::DSizes along_axis(kAxisDim); Eigen::DSizes batch_classes(batch_size, num_classes); Eigen::DSizes batch_by_one(batch_size, 1); Eigen::DSizes one_by_class(1, num_classes); Eigen::DSizes batch_one_remain(batch_size, 1, num_remain); Eigen::DSizes one_axis_one(1, axis_dim, 1); Eigen::DSizes one_axis(1, axis_dim); Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); // For numerical stability, logits should be shifted by maximum number along // axis, calculate shifted_logits into softmax tensor for memory reuse. if (num_remain == 1) { // axis == -1, axis and class in same dimension, calculate along // class dimension directly for higher performance softmax.device(*context.eigen_device()) = (logits - logits.maximum(along_axis) .eval() .reshape(batch_by_one) .broadcast(one_by_class)) .unaryExpr(ValueClip()); } else { // axis != -1, class dimension split into (axis, remain), max and sum // should be calculated along axis dimension softmax.device(*context.eigen_device()) = (logits.reshape(batch_axis_remain) - logits.reshape(batch_axis_remain) .maximum(along_axis) .eval() .reshape(batch_one_remain) .broadcast(one_axis_one) .reshape(batch_classes)) .unaryExpr(ValueClip()); } softmax.device(*context.eigen_device()) = softmax.exp(); softmax.device(*context.eigen_device()) = (softmax * softmax.reshape(batch_axis_remain) .sum(along_axis) .inverse() .eval() .broadcast(one_axis)); } template void SoftmaxFunctor::operator()( const DeviceContext& context, const int axis_dim, const framework::Tensor* X, framework::Tensor* Y) { SoftmaxEigen(context, axis_dim, X, Y); } template using enable_if_CPU = typename std::enable_if< std::is_same::value>::type; template class SoftmaxFunctor> { public: void operator()(const DeviceContext& context, const int axis_dim, const framework::Tensor* X, framework::Tensor* Y) { auto in_dims = X->dims(); constexpr int kBatchDim = 0; constexpr int kClassDim = 1; const int num_classes = in_dims[kClassDim]; const int batch_size = in_dims[kBatchDim]; const int num_remain = num_classes / axis_dim; if (num_remain == 1 && platform::MayIUse(platform::avx)) { const T* in_data = X->data(); T* out_data = Y->data(); for (int bs = 0; bs < batch_size; ++bs) { T max_val = *std::max_element(in_data, in_data + num_classes); max_val *= static_cast(-1); vec_add_bias(num_classes, max_val, in_data, out_data); vec_clip(num_classes, static_cast(-64), out_data, out_data); vec_exp(num_classes, out_data, out_data); T sum = 0; vec_sum(num_classes, out_data, &sum); sum = static_cast(1) / sum; vec_scal(num_classes, sum, out_data, out_data); in_data += num_classes; out_data += num_classes; } } else { SoftmaxEigen(context, axis_dim, X, Y); } } }; template class SoftmaxFunctor> { public: void operator()(const DeviceContext& context, const int axis_dim, const framework::Tensor* X, framework::Tensor* Y) { auto in_dims = X->dims(); const float* in_data = X->data(); float* out_data = Y->data(); const int kBatchDim = 0; const int kClassDim = 1; // 2D data. Batch x C auto compute_softmax = jit::KernelFuncs, platform::CPUPlace>::Cache() .At(in_dims[kClassDim]); compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim], in_dims[kClassDim] / axis_dim); } }; template void SoftmaxGradEigen(const DeviceContext& context, const int axis_dim, const framework::Tensor* y, const framework::Tensor* y_grad, framework::Tensor* x_grad) { auto softmax = EigenMatrix::From(*y); auto softmax_grad = EigenMatrix::From(*y_grad); auto logits_grad = EigenMatrix::From(*x_grad); constexpr int kBatchDim = 0; constexpr int kClassDim = 1; const int batch_size = softmax.dimension(kBatchDim); const int num_classes = softmax.dimension(kClassDim); const int num_remain = num_classes / axis_dim; Eigen::DSizes along_class(kClassDim); Eigen::DSizes batch_by_one(batch_size, 1); Eigen::DSizes one_by_class(1, num_classes); Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); Eigen::DSizes one_axis(1, axis_dim); auto dot = (softmax * softmax_grad) .reshape(batch_axis_remain) .sum(along_class) .eval() .broadcast(one_axis); logits_grad.device(*context.eigen_device()) = (softmax_grad - dot) * softmax; } template void SoftmaxGradFunctor::operator()( const DeviceContext& context, const int axis_dim, const framework::Tensor* y, const framework::Tensor* y_grad, framework::Tensor* x_grad) { SoftmaxGradEigen(context, axis_dim, y, y_grad, x_grad); } template class SoftmaxGradFunctor> { public: void operator()(const DeviceContext& context, const int axis_dim, const framework::Tensor* y, const framework::Tensor* y_grad, framework::Tensor* x_grad) { auto out_dims = y->dims(); constexpr int kBatchDim = 0; constexpr int kClassDim = 1; const int num_classes = out_dims[kClassDim]; const int batch_size = out_dims[kBatchDim]; const int num_remain = num_classes / axis_dim; if (num_remain == 1 && platform::MayIUse(platform::avx)) { const T* out_data = y->data(); const T* out_grad = y_grad->data(); T* in_grad = x_grad->data(); for (int bs = 0; bs < batch_size; ++bs) { T scalar; vec_mul_reduce(num_classes, out_grad, out_data, &scalar); scalar *= static_cast(-1); vec_add_bias(num_classes, scalar, out_grad, in_grad); vec_mul(num_classes, out_data, in_grad, in_grad); out_data += num_classes; out_grad += num_classes; in_grad += num_classes; } } else { SoftmaxGradEigen(context, axis_dim, y, y_grad, x_grad); } } }; } // namespace math } // namespace operators } // namespace paddle