未验证 提交 6d5a04c1 编写于 作者: T tangwei12 提交者: GitHub

add op type in check nan/inf (#15986)

* add op name in check nan/inf, test=develop
上级 187cffd0
...@@ -882,7 +882,8 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -882,7 +882,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
const RuntimeContext& ctx_; const RuntimeContext& ctx_;
}; };
static void CheckTensorNANOrInf(const std::string& name, static void CheckTensorNANOrInf(const std::string& op_type,
const std::string& name,
const framework::Tensor& tensor) { const framework::Tensor& tensor) {
if (tensor.memory_size() == 0) { if (tensor.memory_size() == 0) {
return; return;
...@@ -892,9 +893,9 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -892,9 +893,9 @@ static void CheckTensorNANOrInf(const std::string& name,
return; return;
} }
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
"Tensor %s contains Inf", name); "Operator %s output Tensor %s contains Inf", op_type, name);
PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor), PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
"Tensor %s contains NAN", name); "Operator %s output Tensor %s contains NAN", op_type, name);
} }
void OperatorWithKernel::RuntimeInferShape(const Scope& scope, void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
...@@ -988,9 +989,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -988,9 +989,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
auto* var = exec_scope.FindVar(vname); auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue; if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>()); CheckTensorNANOrInf(type_, vname, var->Get<framework::LoDTensor>());
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value()); CheckTensorNANOrInf(type_, vname,
var->Get<framework::SelectedRows>().value());
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册