diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index d568b186a067a997b043d82bd02b938e7c609eb5..fae5160cc82cd9a93c8a80b74293da82175c9a43 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -52,26 +52,45 @@ void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
 
   Eigen::DSizes<int, 1> along_axis(kAxisDim);
   Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
+  Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+  Eigen::DSizes<int, 2> one_by_class(1, num_classes);
   Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
   Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
+  Eigen::DSizes<int, 2> one_axis(1, axis_dim);
   Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
 
-  auto logits_reshape = logits.reshape(batch_axis_remain);
-  auto shifted_logits = (logits_reshape -
-                         logits_reshape.maximum(along_axis)
-                             .eval()
-                             .reshape(batch_one_remain)
-                             .broadcast(one_axis_one))
-                            .unaryExpr(ValueClip<T>());
-
-  auto exp = shifted_logits.exp();
-  softmax.device(*context.eigen_device()) = (exp *
-                                             exp.sum(along_axis)
+  // For numerical stability, logits should be shifted by maximum number along
+  // axis, calculate shifted_logits into softmax tensor for memory reuse.
+  if (num_remain == 1) {
+    // axis == -1, axis and class in same dimension, calculate along
+    // class dimension directly for higher performance
+    softmax.device(*context.eigen_device()) = (logits -
+                                               logits.maximum(along_axis)
+                                                   .eval()
+                                                   .reshape(batch_by_one)
+                                                   .broadcast(one_by_class))
+                                                  .unaryExpr(ValueClip<T>());
+  } else {
+    // axis != -1, class dimension split into (axis, remain), max and sum
+    // should be calculated along axis dimension
+    softmax.device(*context.eigen_device()) =
+        (logits.reshape(batch_axis_remain) -
+         logits.reshape(batch_axis_remain)
+             .maximum(along_axis)
+             .eval()
+             .reshape(batch_one_remain)
+             .broadcast(one_axis_one)
+             .reshape(batch_classes))
+            .unaryExpr(ValueClip<T>());
+  }
+
+  softmax.device(*context.eigen_device()) = softmax.exp();
+  softmax.device(*context.eigen_device()) = (softmax *
+                                             softmax.reshape(batch_axis_remain)
+                                                 .sum(along_axis)
                                                  .inverse()
                                                  .eval()
-                                                 .reshape(batch_one_remain)
-                                                 .broadcast(one_axis_one))
-                                                .reshape(batch_classes);
+                                                 .broadcast(one_axis));
 }
 
 template
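
Note (not part of the patch): the change keeps the usual numerically stable softmax recipe, shift by the per-axis maximum, clip, exponentiate, multiply by the inverse sum, but writes each intermediate into the output tensor instead of materializing the shifted_logits and exp temporaries, and it skips the 3-D reshape/broadcast machinery when num_remain == 1 (i.e. axis == -1). The sketch below restates the same scheme with plain loops over a row-major [batch, axis, remain] view; the StableSoftmax name and the -64 clip bound are assumptions for this example (the patch takes its bound from ValueClip<T>).

// Illustrative sketch only: shift-by-max / exp / multiply-by-inverse-sum
// passes over a row-major [batch, axis, remain] view of the class dimension.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

constexpr float kClipMin = -64.0f;  // assumed lower bound for shifted logits

void StableSoftmax(std::vector<float>* data, int batch, int axis_dim,
                   int remain) {
  float* x = data->data();
  for (int b = 0; b < batch; ++b) {
    for (int r = 0; r < remain; ++r) {
      // Pass 1: find the maximum along the axis dimension.
      float max_val = -std::numeric_limits<float>::infinity();
      for (int a = 0; a < axis_dim; ++a) {
        max_val = std::max(max_val, x[(b * axis_dim + a) * remain + r]);
      }
      // Pass 2: shift, clip, exponentiate in place, accumulate the sum.
      float sum = 0.0f;
      for (int a = 0; a < axis_dim; ++a) {
        float& v = x[(b * axis_dim + a) * remain + r];
        v = std::exp(std::max(v - max_val, kClipMin));
        sum += v;
      }
      // Pass 3: scale by the inverse sum, as the patch does with .inverse().
      const float inv_sum = 1.0f / sum;
      for (int a = 0; a < axis_dim; ++a) {
        x[(b * axis_dim + a) * remain + r] *= inv_sum;
      }
    }
  }
}

int main() {
  // batch == 1, axis_dim == 3, remain == 2: softmax runs over the middle
  // axis, mirroring the num_remain != 1 branch of the patch.
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  StableSoftmax(&x, /*batch=*/1, /*axis_dim=*/3, /*remain=*/2);
  for (float v : x) std::printf("%.4f ", v);
  std::printf("\n");
  return 0;
}

Built with a C++11 compiler, the toy main prints two 3-way distributions interleaved in memory order, one per remain slot, which is the layout the batch_axis_remain reshape gives the Eigen version.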