From 3ae781eb2bc139a946b7f195183e31304af49822 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 27 Dec 2017 15:45:13 +0800 Subject: [PATCH] Executor check nan --- paddle/framework/executor.cc | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 997773c1689..9ee2ddb7c3e 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -14,18 +14,17 @@ limitations under the License. */ #include "paddle/framework/executor.h" -#include -#include -#include #include -#include +#include "gflags/gflags.h" #include "paddle/framework/feed_fetch_type.h" #include "paddle/framework/lod_rank_table.h" -#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/op_registry.h" -#include "paddle/framework/scope.h" + +DEFINE_bool(check_nan_inf, false, + "Checking whether operator produce NAN/INF or not. It will be " + "extremely slow so please use this flag wisely."); namespace paddle { namespace framework { @@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { } } +static void CheckTensorNANOrInf(const std::string& name, + const framework::Tensor& tensor) { + if (tensor.type().hash_code() != typeid(float).hash_code() && + tensor.type().hash_code() != typeid(double).hash_code()) { + return; + } + if (tensor.memory_size() == 0) { + return; + } + PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name); + PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name); +} + void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars) { // TODO(tonyyang-svail): @@ -101,6 +113,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); VLOG(3) << op->DebugString(); op->Run(*local_scope, place_); + if (FLAGS_check_nan_inf) { + for (auto& vname : op->OutputVars(true)) { + auto* var = local_scope->FindVar(vname); + if (var == nullptr) continue; + if (var->IsType()) { + CheckTensorNANOrInf(vname, var->Get()); + } + } + } } if (create_local_scope) { scope->DeleteScope(local_scope); -- GitLab