From e08d0646c6009438c05408eb5a3b2cd3fc9ba38d Mon Sep 17 00:00:00 2001 From: XiaociZhang Date: Tue, 5 Sep 2023 15:18:42 +0800 Subject: [PATCH] [XPU] remove range check for ignore index (#56869) * [XPU] remove range check for ignore index * add a log --- .../c_softmax_with_cross_entropy_op_xpu.cc | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc index fa17b99e79f..ec4f872a7a8 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc @@ -34,12 +34,9 @@ class CSoftmaxWithCrossEntropyOp : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { const int64_t ignore_index = ctx.Attr("ignore_index"); - PADDLE_ENFORCE_LT(ignore_index, - 0, - platform::errors::InvalidArgument( - "When SoftmaxWithCrossEntropy run on XPU, " - "ignore_index should be <=0, however it's %ld", - ignore_index)); + if (ignore_index >= 0) { + LOG_FIRST_N(INFO, 1) << "XPU does not support ignore_index in mp."; + } const int rid = ctx.Attr("ring_id"); auto map = distributed::ProcessGroupMapFromGid::getInstance(); if (map->has(rid)) { @@ -467,12 +464,9 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel { const phi::DenseTensor* softmax = context.Input("Softmax"); const int64_t ignore_index = context.Attr("ignore_index"); - PADDLE_ENFORCE_LT(ignore_index, - 0, - platform::errors::InvalidArgument( - "When SoftmaxWithCrossEntropy run on XPU, " - "ignore_index should be <=0, however it's %ld", - ignore_index)); + if (ignore_index >= 0) { + LOG_FIRST_N(INFO, 1) << "XPU does not support ignore_index in mp."; + } const int rank = context.Attr("rank"); auto& dev_ctx = context.template device_context(); -- GitLab