diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index fae5160cc82cd9a93c8a80b74293da82175c9a43..d51d638e0c19f43f9b0a91adbac15dffcdf14588 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -37,67 +37,130 @@ struct ValueClip {
 };
 
 template <typename DeviceContext, typename T, bool is_test>
-void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
+class SoftmaxEigen {
+ public:
+  void operator()(const DeviceContext& context, const int axis_dim,
                   const framework::Tensor* X, framework::Tensor* Y) {
-  constexpr int kBatchDim = 0;
-  constexpr int kClassDim = 1;
-  constexpr int kAxisDim = 1;
-
-  auto logits = EigenMatrix<T>::From(*X);
-  auto softmax = EigenMatrix<T>::From(*Y);
-
-  const int batch_size = logits.dimension(kBatchDim);
-  const int num_classes = logits.dimension(kClassDim);
-  const int num_remain = num_classes / axis_dim;
-
-  Eigen::DSizes<int, 1> along_axis(kAxisDim);
-  Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
-  Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-  Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-  Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
-  Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
-  Eigen::DSizes<int, 2> one_axis(1, axis_dim);
-  Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-
-  // For numerical stability, logits should be shifted by maximum number along
-  // axis, calculate shifted_logits into softmax tensor for memory reuse.
-  if (num_remain == 1) {
-    // axis == -1, axis and class in same dimension, calculate along
-    // class dimension directly for higher performance
-    softmax.device(*context.eigen_device()) = (logits -
-                                               logits.maximum(along_axis)
-                                                   .eval()
-                                                   .reshape(batch_by_one)
-                                                   .broadcast(one_by_class))
-                                                  .unaryExpr(ValueClip<T>());
-  } else {
-    // axis != -1, class dimension split into (axis, remain), max and sum
-    // should be calculated along axis dimension
+    constexpr int kBatchDim = 0;
+    constexpr int kClassDim = 1;
+    constexpr int kAxisDim = 1;
+
+    auto logits = EigenMatrix<T>::From(*X);
+    auto softmax = EigenMatrix<T>::From(*Y);
+
+    const int batch_size = logits.dimension(kBatchDim);
+    const int num_classes = logits.dimension(kClassDim);
+    const int num_remain = num_classes / axis_dim;
+
+    Eigen::DSizes<int, 1> along_axis(kAxisDim);
+    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+
+    // For numerical stability, logits should be shifted by maximum number along
+    // axis, calculate shifted_logits into softmax tensor for memory reuse.
+    if (num_remain == 1) {
+      // axis == -1, axis and class in same dimension, calculate along
+      // class dimension directly for higher performance
+      softmax.device(*context.eigen_device()) = (logits -
+                                                 logits.maximum(along_axis)
+                                                     .eval()
+                                                     .reshape(batch_by_one)
+                                                     .broadcast(one_by_class))
+                                                    .unaryExpr(ValueClip<T>());
+    } else {
+      // axis != -1, class dimension split into (axis, remain), max and sum
+      // should be calculated along axis dimension
+      softmax.device(*context.eigen_device()) =
+          (logits.reshape(batch_axis_remain) -
+           logits.reshape(batch_axis_remain)
+               .maximum(along_axis)
+               .eval()
+               .reshape(batch_one_remain)
+               .broadcast(one_axis_one)
+               .reshape(batch_classes))
+              .unaryExpr(ValueClip<T>());
+    }
+
+    softmax.device(*context.eigen_device()) = softmax.exp();
     softmax.device(*context.eigen_device()) =
-        (logits.reshape(batch_axis_remain) -
-         logits.reshape(batch_axis_remain)
             .maximum(along_axis)
+        (softmax *
+         softmax.reshape(batch_axis_remain)
+             .sum(along_axis)
+             .inverse()
             .eval()
-             .reshape(batch_one_remain)
-             .broadcast(one_axis_one)
-             .reshape(batch_classes))
-            .unaryExpr(ValueClip<T>());
+             .broadcast(one_axis));
   }
+};
-  softmax.device(*context.eigen_device()) = softmax.exp();
-  softmax.device(*context.eigen_device()) = (softmax *
-                                             softmax.reshape(batch_axis_remain)
-                                                 .sum(along_axis)
-                                                 .inverse()
-                                                 .eval()
-                                                 .broadcast(one_axis));
-}
 
+template <typename DeviceContext, bool is_test>
+class SoftmaxEigen<DeviceContext, platform::float16, is_test> {
+ public:
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* X, framework::Tensor* Y) {
+    constexpr int kBatchDim = 0;
+    constexpr int kClassDim = 1;
+    constexpr int kAxisDim = 1;
+
+    auto logits = EigenMatrix<platform::float16>::From(*X);
+    auto softmax = EigenMatrix<platform::float16>::From(*Y);
+
+    const int batch_size = logits.dimension(kBatchDim);
+    const int num_classes = logits.dimension(kClassDim);
+    const int num_remain = num_classes / axis_dim;
+
+    Eigen::DSizes<int, 1> along_axis(kAxisDim);
+    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+
+    // For numerical stability, logits should be shifted by maximum number along
+    // axis, calculate shifted_logits into softmax tensor for memory reuse.
+    if (num_remain == 1) {
+      // axis == -1, axis and class in same dimension, calculate along
+      // class dimension directly for higher performance
+      softmax.device(*context.eigen_device()) =
+          (logits -
+           logits.maximum(along_axis)
+               .reshape(batch_by_one)
+               .broadcast(one_by_class))
+              .unaryExpr(ValueClip<platform::float16>());
+    } else {
+      // axis != -1, class dimension split into (axis, remain), max and sum
+      // should be calculated along axis dimension
+      softmax.device(*context.eigen_device()) =
+          (logits.reshape(batch_axis_remain) -
+           logits.reshape(batch_axis_remain)
+               .maximum(along_axis)
+               .reshape(batch_one_remain)
+               .broadcast(one_axis_one)
+               .reshape(batch_classes))
+              .unaryExpr(ValueClip<platform::float16>());
+    }
+
+    softmax.device(*context.eigen_device()) = softmax.exp();
+    softmax.device(*context.eigen_device()) =
+        (softmax *
+         softmax.reshape(batch_axis_remain)
+             .sum(along_axis)
+             .inverse()
+             .broadcast(one_axis));
+  }
+};
 
 template <typename DeviceContext, typename T, bool is_test, typename Enable>
 void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
     const DeviceContext& context, const int axis_dim,
     const framework::Tensor* X, framework::Tensor* Y) {
-  SoftmaxEigen<DeviceContext, T, is_test>(context, axis_dim, X, Y);
+  SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
 }
 
 template <class DeviceContext>
@@ -137,7 +200,7 @@ class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
         out_data += num_classes;
       }
     } else {
-      SoftmaxEigen<DeviceContext, T, is_test>(context, axis_dim, X, Y);
+      SoftmaxEigen<DeviceContext, T, is_test>()(context, axis_dim, X, Y);
     }
   }
 };
@@ -162,41 +225,76 @@ class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
 };
 
 template <typename DeviceContext, typename T>
-void SoftmaxGradEigen(const DeviceContext& context, const int axis_dim,
-                      const framework::Tensor* y,
-                      const framework::Tensor* y_grad,
-                      framework::Tensor* x_grad) {
-  auto softmax = EigenMatrix<T>::From(*y);
-  auto softmax_grad = EigenMatrix<T>::From(*y_grad);
-  auto logits_grad = EigenMatrix<T>::From(*x_grad);
-
-  constexpr int kBatchDim = 0;
-  constexpr int kClassDim = 1;
-
-  const int batch_size = softmax.dimension(kBatchDim);
-  const int num_classes = softmax.dimension(kClassDim);
-  const int num_remain = num_classes / axis_dim;
-
-  Eigen::DSizes<int, 1> along_class(kClassDim);
-  Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-  Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-  Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-  Eigen::DSizes<int, 2> one_axis(1, axis_dim);
-
-  auto dot = (softmax * softmax_grad)
-                 .reshape(batch_axis_remain)
-                 .sum(along_class)
-                 .eval()
-                 .broadcast(one_axis);
-  logits_grad.device(*context.eigen_device()) = (softmax_grad - dot) * softmax;
-}
+class SoftmaxGradEigen {
+ public:
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad) {
+    auto softmax = EigenMatrix<T>::From(*y);
+    auto softmax_grad = EigenMatrix<T>::From(*y_grad);
+    auto logits_grad = EigenMatrix<T>::From(*x_grad);
+
+    constexpr int kBatchDim = 0;
+    constexpr int kClassDim = 1;
+
+    const int batch_size = softmax.dimension(kBatchDim);
+    const int num_classes = softmax.dimension(kClassDim);
+    const int num_remain = num_classes / axis_dim;
+
+    Eigen::DSizes<int, 1> along_class(kClassDim);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+
+    auto dot = (softmax * softmax_grad)
+                   .reshape(batch_axis_remain)
+                   .sum(along_class)
+                   .eval()
+                   .broadcast(one_axis);
+    logits_grad.device(*context.eigen_device()) =
+        (softmax_grad - dot) * softmax;
+  }
+};
+
+template <typename DeviceContext>
+class SoftmaxGradEigen<DeviceContext, platform::float16> {
+ public:
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad) {
+    auto softmax = EigenMatrix<platform::float16>::From(*y);
+    auto softmax_grad = EigenMatrix<platform::float16>::From(*y_grad);
+    auto logits_grad = EigenMatrix<platform::float16>::From(*x_grad);
+
+    constexpr int kBatchDim = 0;
+    constexpr int kClassDim = 1;
+
+    const int batch_size = softmax.dimension(kBatchDim);
+    const int num_classes = softmax.dimension(kClassDim);
+    const int num_remain = num_classes / axis_dim;
+
+    Eigen::DSizes<int, 1> along_class(kClassDim);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+
+    auto dot = (softmax * softmax_grad)
+                   .reshape(batch_axis_remain)
+                   .sum(along_class)
+                   .broadcast(one_axis);
+    logits_grad.device(*context.eigen_device()) =
+        (softmax_grad - dot) * softmax;
+  }
+};
 
 template <typename DeviceContext, typename T, typename Enable>
 void SoftmaxGradFunctor<DeviceContext, T, Enable>::operator()(
     const DeviceContext& context, const int axis_dim,
     const framework::Tensor* y, const framework::Tensor* y_grad,
     framework::Tensor* x_grad) {
-  SoftmaxGradEigen<DeviceContext, T>(context, axis_dim, y, y_grad, x_grad);
+  SoftmaxGradEigen<DeviceContext, T>()(context, axis_dim, y, y_grad, x_grad);
 }
 
 template <typename DeviceContext, typename T>
@@ -228,7 +326,8 @@ class SoftmaxGradFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
       in_grad += num_classes;
     }
   } else {
-    SoftmaxGradEigen<DeviceContext, T>(context, axis_dim, y, y_grad, x_grad);
+    SoftmaxGradEigen<DeviceContext, T>()(context, axis_dim, y, y_grad,
+                                         x_grad);
   }
 }
};
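
Note on the pattern used above: a function template cannot be partially specialized in C++, but a class template with an operator() can, which is presumably why SoftmaxEigen and SoftmaxGradEigen become classes in this patch -- the generic path keeps the .eval() calls while a partial specialization for platform::float16 drops them. Below is a minimal, self-contained sketch of the same idea using plain std::vector instead of framework::Tensor and Eigen; the names StableSoftmax, half_like, and CPUDevice are illustrative stand-ins, not part of the patch. It also spells out the "shift by the maximum for numerical stability" step that the in-code comments refer to.

// softmax_sketch.cc -- illustrative only; not Paddle code.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Primary template: numerically stable softmax over one row.
// Subtracting the row maximum before exp() keeps every exponential in (0, 1],
// mirroring the "for numerical stability" comment in the patch.
template <typename Device, typename T>
struct StableSoftmax {
  void operator()(const std::vector<T>& logits, std::vector<T>* out) const {
    const T max_val = *std::max_element(logits.begin(), logits.end());
    out->resize(logits.size());
    T sum = T(0);
    for (size_t i = 0; i < logits.size(); ++i) {
      (*out)[i] = std::exp(logits[i] - max_val);
      sum += (*out)[i];
    }
    for (T& v : *out) v /= sum;  // normalize by the row sum
  }
};

// Hypothetical reduced-precision type, standing in for platform::float16.
struct half_like {
  float v;
};

// Partial specialization: the element type is fixed while Device stays
// generic. This is what a plain function template cannot express, and the
// capability gained by turning SoftmaxEigen into a class.
template <typename Device>
struct StableSoftmax<Device, half_like> {
  void operator()(const std::vector<half_like>& logits,
                  std::vector<half_like>* out) const {
    std::vector<float> tmp(logits.size()), res;
    for (size_t i = 0; i < logits.size(); ++i) tmp[i] = logits[i].v;
    StableSoftmax<Device, float>()(tmp, &res);  // accumulate in float
    out->resize(res.size());
    for (size_t i = 0; i < res.size(); ++i) (*out)[i].v = res[i];
  }
};

struct CPUDevice {};  // stand-in for a DeviceContext tag type

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f}, y;
  // Construct a temporary functor, then invoke it -- the same call shape the
  // patch switches to: SoftmaxEigen<...>(args) becomes SoftmaxEigen<...>()(args).
  StableSoftmax<CPUDevice, float>()(x, &y);
  for (float v : y) std::printf("%.4f ", v);  // ~0.0900 0.2447 0.6652
  std::printf("\n");
  return 0;
}

The float16 specializations in the patch differ from this sketch in substance (they drop .eval() from the Eigen expression chains rather than promoting to float), but the mechanics are the same: the call sites only change from SoftmaxEigen<...>(...) and SoftmaxGradEigen<...>(...) to SoftmaxEigen<...>()(...) and SoftmaxGradEigen<...>()(...), and template specialization selects the right body.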