diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index 4fb03cdce0c78ed72e69e3d70e836ee8a914110a..d568b186a067a997b043d82bd02b938e7c609eb5 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -41,6 +41,7 @@ void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
                   const framework::Tensor* X, framework::Tensor* Y) {
   constexpr int kBatchDim = 0;
   constexpr int kClassDim = 1;
+  constexpr int kAxisDim = 1;
 
   auto logits = EigenMatrix<T>::From(*X);
   auto softmax = EigenMatrix<T>::From(*Y);
@@ -49,26 +50,28 @@ void SoftmaxEigen(const DeviceContext& context, const int axis_dim,
   const int num_classes = logits.dimension(kClassDim);
   const int num_remain = num_classes / axis_dim;
 
-  Eigen::DSizes<int, 1> along_class(kClassDim);
-  Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-  Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+  Eigen::DSizes<int, 1> along_axis(kAxisDim);
+  Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
+  Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
+  Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
   Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-  Eigen::DSizes<int, 2> one_axis(1, axis_dim);
 
-  auto shifted_logits = (logits -
-                         logits.maximum(along_class)
+  auto logits_reshape = logits.reshape(batch_axis_remain);
+  auto shifted_logits = (logits_reshape -
+                         logits_reshape.maximum(along_axis)
                              .eval()
-                             .reshape(batch_by_one)
-                             .broadcast(one_by_class))
+                             .reshape(batch_one_remain)
+                             .broadcast(one_axis_one))
                             .unaryExpr(ValueClip<T>());
 
-  softmax.device(*context.eigen_device()) = shifted_logits.exp();
-  softmax.device(*context.eigen_device()) = (softmax *
-                                             softmax.reshape(batch_axis_remain)
-                                                 .sum(along_class)
+  auto exp = shifted_logits.exp();
+  softmax.device(*context.eigen_device()) = (exp *
+                                             exp.sum(along_axis)
                                                  .inverse()
                                                  .eval()
-                                                 .broadcast(one_axis))
+                                                 .reshape(batch_one_remain)
+                                                 .broadcast(one_axis_one))
+                                             .reshape(batch_classes);
 }
 
 template
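
The patch reshapes the logits into a [batch, axis, remain] view up front, so the max-shift, exponentiation, and normalization all reduce along the middle (axis) dimension instead of the flat class dimension, and the two-pass write into softmax collapses into a single fused Eigen expression. Below is a minimal plain-C++ sketch of that computation, without Eigen, to make the index arithmetic concrete. The function name SoftmaxAlongAxis is hypothetical and not part of the Paddle source; the parameter names mirror the patch's variables.

// Hypothetical stand-in for the patched SoftmaxEigen: numerically stable
// softmax along the middle dimension of a flattened [batch, axis, remain]
// array. Not the Paddle implementation; a sketch of the same reduction.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void SoftmaxAlongAxis(const std::vector<float>& logits,
                      std::vector<float>& out,
                      int batch_size, int axis_dim, int num_remain) {
  for (int b = 0; b < batch_size; ++b) {
    for (int r = 0; r < num_remain; ++r) {
      // Element (b, a, r) in the flattened [batch, axis, remain] layout
      // sits at (b * axis_dim + a) * num_remain + r.
      auto idx = [&](int a) { return (b * axis_dim + a) * num_remain + r; };

      // 1) Max along the axis dimension (the patch's maximum(along_axis)).
      float max_val = logits[idx(0)];
      for (int a = 1; a < axis_dim; ++a)
        max_val = std::max(max_val, logits[idx(a)]);

      // 2) Shift, exponentiate, and accumulate the normalizer
      //    (shifted_logits.exp() and exp.sum(along_axis) in the patch).
      float sum = 0.0f;
      for (int a = 0; a < axis_dim; ++a) {
        out[idx(a)] = std::exp(logits[idx(a)] - max_val);
        sum += out[idx(a)];
      }

      // 3) Normalize (the patch multiplies by sum.inverse() broadcast
      //    back over the axis dimension).
      for (int a = 0; a < axis_dim; ++a) out[idx(a)] /= sum;
    }
  }
}

int main() {
  // batch_size = 1, axis_dim = 3, num_remain = 2: softmax is taken over
  // the three values in each (batch, remain) column independently.
  std::vector<float> logits = {1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f};
  std::vector<float> out(logits.size());
  SoftmaxAlongAxis(logits, out, /*batch_size=*/1, /*axis_dim=*/3,
                   /*num_remain=*/2);
  for (float v : out) std::printf("%.4f ", v);
  std::printf("\n");
}

Each (batch, remain) column is normalized on its own, which is what the broadcast shapes batch_one_remain and one_axis_one arrange in the Eigen version; the final reshape(batch_classes) only flattens the 3-D view back to the 2-D output tensor.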