未验证 提交 af886995 编写于 作者: G gongweibao 提交者: GitHub

Revert "[NPU] refine nan check (#34508)" (#34530)

上级 c7cc5ac2
...@@ -123,30 +123,32 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> { ...@@ -123,30 +123,32 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
// return true if found_inf_or_nan or return false; // return true if found_inf_or_nan or return false;
template <typename T> template <typename T>
bool ContainsNan(const framework::ExecutionContext& exe_ctx, aclrtStream stream, bool CheckNumerics(const framework::ExecutionContext& exe_ctx,
const paddle::framework::Tensor* in) { aclrtStream stream, const paddle::framework::Tensor* in) {
auto& dev_ctx = auto& dev_ctx =
exe_ctx.template device_context<paddle::platform::NPUDeviceContext>(); exe_ctx.template device_context<paddle::platform::NPUDeviceContext>();
using Tensor = paddle::framework::Tensor; using Tensor = paddle::framework::Tensor;
Tensor out(in->type()); Tensor out(in->type());
out.Resize(in->dims());
Tensor mean(in->type()); out.mutable_data<T>(dev_ctx.GetPlace());
mean.Resize({1});
mean.mutable_data<T>(dev_ctx.GetPlace()); bool found_inf_data = false;
std::vector<int> axes;
for (int i = 0; i < in->dims().size(); ++i) { try {
axes.push_back(i); const auto& runner =
NpuOpRunner("CheckNumerics", {*in}, {out},
{{"message", std::string("check_numberics")}});
runner.Run(stream);
dev_ctx.Wait();
} catch (platform::EnforceNotMet& exception) {
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
found_inf_data = true;
} catch (...) {
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
found_inf_data = true;
} }
const auto& runner_mean = NpuOpRunner("ReduceMeanD", {*in}, {mean},
{{"axes", axes}, {"keep_dims", false}});
std::vector<T> vec;
TensorToVector(mean, exe_ctx.device_context(), &vec);
if (std::isnan(static_cast<float>(vec[0]))) { return found_inf_data;
return true;
}
return false;
} }
#endif #endif
...@@ -214,22 +216,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> { ...@@ -214,22 +216,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
framework::Tensor tmp; framework::Tensor tmp;
tmp.mutable_data<float>({8}, ctx.GetPlace()); tmp.mutable_data<float>({8}, ctx.GetPlace());
bool has_nan = false; bool check_numerics = false;
auto d_type = in->type(); auto d_type = in->type();
switch (d_type) { switch (d_type) {
case framework::proto::VarType::FP16: case framework::proto::VarType::FP16:
case framework::proto::VarType::FP32: { case framework::proto::VarType::FP32: {
VLOG(4) << "prepare to check nan"; VLOG(4) << "prepare to FoundNanInf";
has_nan = ContainsNan<T>(ctx, dev_ctx->stream(), in); check_numerics = CheckNumerics<T>(ctx, dev_ctx->stream(), in);
VLOG(4) << "ContainsNan:" << has_nan; VLOG(4) << "check_numerics:" << check_numerics;
break; break;
} }
default: default:
break; break;
} }
if (has_nan) { if (check_numerics) {
T inf = static_cast<T>(std::numeric_limits<float>::infinity()); T inf = static_cast<T>(std::numeric_limits<float>::infinity());
VLOG(4) << "fill input data constant inf"; VLOG(4) << "fill input data constant inf";
auto dims = in->dims(); auto dims = in->dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册