diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index d9a4f0da13c77e9376b7d1497218936cb0e11698..3c51c65b073904aedaa8b4c6777aaecc7bc223c2 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -123,30 +123,32 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
 #if defined(PADDLE_WITH_ASCEND_CL)
 // return true if found_inf_or_nan or return false;
 template <typename T>
-bool ContainsNan(const framework::ExecutionContext& exe_ctx, aclrtStream stream,
-                 const paddle::framework::Tensor* in) {
+bool CheckNumerics(const framework::ExecutionContext& exe_ctx,
+                   aclrtStream stream, const paddle::framework::Tensor* in) {
   auto& dev_ctx = exe_ctx.template device_context<paddle::platform::NPUDeviceContext>();
   using Tensor = paddle::framework::Tensor;
   Tensor out(in->type());
-
-  Tensor mean(in->type());
-  mean.Resize({1});
-  mean.mutable_data<T>(dev_ctx.GetPlace());
-  std::vector<int> axes;
-  for (int i = 0; i < in->dims().size(); ++i) {
-    axes.push_back(i);
+  out.Resize(in->dims());
+  out.mutable_data<T>(dev_ctx.GetPlace());
+
+  bool found_inf_data = false;
+
+  try {
+    const auto& runner =
+        NpuOpRunner("CheckNumerics", {*in}, {out},
+                    {{"message", std::string("check_numberics")}});
+    runner.Run(stream);
+    dev_ctx.Wait();
+  } catch (platform::EnforceNotMet& exception) {
+    LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
+    found_inf_data = true;
+  } catch (...) {
+    LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
+    found_inf_data = true;
   }
-  const auto& runner_mean = NpuOpRunner("ReduceMeanD", {*in}, {mean},
-                                        {{"axes", axes}, {"keep_dims", false}});
-
-  std::vector<T> vec;
-  TensorToVector(mean, exe_ctx.device_context(), &vec);
-  if (std::isnan(static_cast<float>(vec[0]))) {
-    return true;
-  }
-  return false;
+
+  return found_inf_data;
 }
 #endif
@@ -214,22 +216,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
     framework::Tensor tmp;
     tmp.mutable_data<float>({8}, ctx.GetPlace());
 
-    bool has_nan = false;
+    bool check_numerics = false;
 
     auto d_type = in->type();
     switch (d_type) {
       case framework::proto::VarType::FP16:
      case framework::proto::VarType::FP32: {
-        VLOG(4) << "prepare to check nan";
-        has_nan = ContainsNan(ctx, dev_ctx->stream(), in);
-        VLOG(4) << "ContainsNan:" << has_nan;
+        VLOG(4) << "prepare to FoundNanInf";
+        check_numerics = CheckNumerics(ctx, dev_ctx->stream(), in);
+        VLOG(4) << "check_numerics:" << check_numerics;
         break;
       }
       default:
         break;
     }
 
-    if (has_nan) {
+    if (check_numerics) {
       T inf = static_cast<T>(std::numeric_limits<float>::infinity());
       VLOG(4) << "fill input data constant inf";
       auto dims = in->dims();