未验证 提交 7d1332f6 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #11006 from chengduoZH/fix_add_check_nan_inf_in_operator

Move check_nan_inf to operator
...@@ -24,9 +24,6 @@ limitations under the License. */ ...@@ -24,9 +24,6 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
// Boolean flag `check_nan_inf`, off (false) by default. Code guarded by
// FLAGS_check_nan_inf scans operator output tensors for NAN/Inf values after
// an operator has run; the help text below warns that this is very slow.
DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -78,21 +75,6 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { ...@@ -78,21 +75,6 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
} }
} }
// Verifies that `tensor` holds neither Inf nor NAN values; each condition is
// enforced with PADDLE_ENFORCE, and `name` identifies the tensor in the
// failure message. Tensors that occupy no memory, or whose element type is
// neither float nor double, are accepted without being scanned.
static void CheckTensorNANOrInf(const std::string& name,
                                const framework::Tensor& tensor) {
  // Nothing is stored in the tensor, so there is nothing to inspect.
  if (tensor.memory_size() == 0) return;

  // Only floating point element types are inspected.
  const auto elem_type_hash = tensor.type().hash_code();
  const bool holds_float =
      elem_type_hash == typeid(float).hash_code();  // NOLINT
  const bool holds_double =
      elem_type_hash == typeid(double).hash_code();  // NOLINT
  if (!(holds_float || holds_double)) return;

  // Inf is checked first, then NAN.
  PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
                 "Tensor %s contains Inf", name);
  PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
                 "Tensor %s contains NAN", name);
}
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
int block_id) { int block_id) {
auto& global_block = pdesc.Block(block_id); auto& global_block = pdesc.Block(block_id);
...@@ -340,15 +322,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -340,15 +322,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_); << memory::memory_usage(place_);
} }
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
auto* var = local_scope->FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
}
}
}
} }
platform::DeviceContextPool::Instance().Get(place_)->Wait(); platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (create_vars && create_local_scope) { if (create_vars && create_local_scope) {
......
...@@ -24,6 +24,9 @@ limitations under the License. */ ...@@ -24,6 +24,9 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
// Boolean flag `check_nan_inf`, off (false) by default. When it is set,
// OperatorWithKernel::RunImpl looks up each output variable of the operator
// and checks every LoDTensor among them for NAN/Inf values; the help text
// below warns that this is very slow.
DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -513,6 +516,21 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -513,6 +516,21 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_; const Scope& scope_;
}; };
// Enforces, via PADDLE_ENFORCE, that `tensor` contains no Inf and no NAN
// value. `name` is used only to label the tensor in the failure message.
// The scan is skipped when the tensor occupies no memory or when its element
// type is something other than float or double.
static void CheckTensorNANOrInf(const std::string& name,
                                const framework::Tensor& tensor) {
  if (tensor.memory_size() != 0) {
    const auto type_hash = tensor.type().hash_code();
    const bool is_floating_point =
        type_hash == typeid(float).hash_code() ||  // NOLINT
        type_hash == typeid(double).hash_code();   // NOLINT
    if (is_floating_point) {
      // Inf is reported before NAN.
      PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
                     "Tensor %s contains Inf", name);
      PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
                     "Tensor %s contains NAN", name);
    }
  }
}
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
...@@ -597,6 +615,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -597,6 +615,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
new_dev_ctx->Wait(); new_dev_ctx->Wait();
} }
if (FLAGS_check_nan_inf) {
for (auto& vname : OutputVars(true)) {
auto* var = new_scope.FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
}
}
}
} }
proto::VarType::Type OperatorWithKernel::IndicateDataType( proto::VarType::Type OperatorWithKernel::IndicateDataType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册