diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
index 6a2dab9005a63122c4b15c92c9c428cfd2d95570..114c52f608f3416d275478d2d81c42c4deaee211 100644
--- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/platform/collective_helper.h"
@@ -337,9 +338,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor {
     eigen_logits_max.device(*dev_ctx.eigen_device()) =
         eigen_logits.maximum(along_axis);
 
-    std::vector<phi::DenseTensor> in_out;
-    in_out.push_back(logits_max);
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
+    pg->AllReduce(&logits_max, logits_max, opts, true, true);
 
     // step 2, obtain logit - logit_max
     Eigen::DSizes<int, 2> batch_by_one(N, 1);
@@ -390,10 +389,8 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor {
           nranks);
     }
 
-    in_out.clear();
-    in_out.push_back(predicted_logits);
     opts.reduce_op = distributed::ReduceOp::SUM;
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
+    pg->AllReduce(&predicted_logits, predicted_logits, opts, true, true);
 
     // step 4, obtain exp(logit)
     eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp();
@@ -403,15 +400,11 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor {
     sum_exp_logits =
        ctx.AllocateTmpTensor<T, phi::GPUContext>({N, 1}, dev_ctx);
     void* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
-    auto eigen_sum_exp_logits =
-        phi::funcs::EigenMatrix<T>::From(sum_exp_logits);
-    eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) =
-        eigen_softmax.sum(along_axis);
+    phi::SumKernel<T, phi::GPUContext>(
+        dev_ctx, softmax_2d, {-1}, softmax_2d.dtype(), true, &sum_exp_logits);
 
-    in_out.clear();
-    in_out.push_back(sum_exp_logits);
     opts.reduce_op = distributed::ReduceOp::SUM;
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
+    pg->AllReduce(&sum_exp_logits, sum_exp_logits, opts, true, true);
 
     if (label_type == framework::proto::VarType::INT32) {
       CaculateLoss<T, int32_t>
@@ -431,6 +424,8 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor {
           N);
     }
 
+    auto eigen_sum_exp_logits =
+        phi::funcs::EigenMatrix<T>::From(sum_exp_logits);
     eigen_softmax.device(*dev_ctx.eigen_device()) =
         (eigen_softmax *
          eigen_sum_exp_logits.inverse().broadcast(one_by_class));