diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 3c51c65b073904aedaa8b4c6777aaecc7bc223c2..d9a4f0da13c77e9376b7d1497218936cb0e11698 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -123,32 +123,30 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
 
 #if defined(PADDLE_WITH_ASCEND_CL)
 // return true if found_inf_or_nan or return false;
 template <typename T>
-bool CheckNumerics(const framework::ExecutionContext& exe_ctx,
-                   aclrtStream stream, const paddle::framework::Tensor* in) {
+bool ContainsNan(const framework::ExecutionContext& exe_ctx, aclrtStream stream,
+                 const paddle::framework::Tensor* in) {
   auto& dev_ctx =
       exe_ctx.template device_context<paddle::platform::NPUDeviceContext>();
   using Tensor = paddle::framework::Tensor;
   Tensor out(in->type());
-  out.Resize(in->dims());
-  out.mutable_data<T>(dev_ctx.GetPlace());
-
-  bool found_inf_data = false;
-
-  try {
-    const auto& runner =
-        NpuOpRunner("CheckNumerics", {*in}, {out},
-                    {{"message", std::string("check_numberics")}});
-    runner.Run(stream);
-    dev_ctx.Wait();
-  } catch (platform::EnforceNotMet& exception) {
-    LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
-    found_inf_data = true;
-  } catch (...) {
-    LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
-    found_inf_data = true;
+
+  Tensor mean(in->type());
+  mean.Resize({1});
+  mean.mutable_data<T>(dev_ctx.GetPlace());
+  std::vector<int> axes;
+  for (int i = 0; i < in->dims().size(); ++i) {
+    axes.push_back(i);
   }
+  const auto& runner_mean = NpuOpRunner("ReduceMeanD", {*in}, {mean},
+                                        {{"axes", axes}, {"keep_dims", false}});
+
+  std::vector<T> vec;
+  TensorToVector(mean, exe_ctx.device_context(), &vec);
-  return found_inf_data;
+  if (std::isnan(static_cast<float>(vec[0]))) {
+    return true;
+  }
+  return false;
 }
 #endif
 
@@ -216,22 +214,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
     framework::Tensor tmp;
     tmp.mutable_data<T>({8}, ctx.GetPlace());
 
-    bool check_numerics = false;
+    bool has_nan = false;
 
     auto d_type = in->type();
     switch (d_type) {
       case framework::proto::VarType::FP16:
      case framework::proto::VarType::FP32: {
-        VLOG(4) << "prepare to FoundNanInf";
-        check_numerics = CheckNumerics<T>(ctx, dev_ctx->stream(), in);
-        VLOG(4) << "check_numerics:" << check_numerics;
+        VLOG(4) << "prepare to check nan";
+        has_nan = ContainsNan<T>(ctx, dev_ctx->stream(), in);
+        VLOG(4) << "ContainsNan:" << has_nan;
         break;
       }
       default:
         break;
     }
 
-    if (check_numerics) {
+    if (has_nan) {
       T inf = static_cast<T>(std::numeric_limits<float>::infinity());
       VLOG(4) << "fill input data constant inf";
       auto dims = in->dims();