未验证 提交 7d1332f6 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #11006 from chengduoZH/fix_add_check_nan_inf_in_operator

Move check_nan_inf to operator
......@@ -24,9 +24,6 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark);
DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
namespace paddle {
namespace framework {
......@@ -78,21 +75,6 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
}
}
static void CheckTensorNANOrInf(const std::string& name,
const framework::Tensor& tensor) {
if (tensor.memory_size() == 0) {
return;
}
if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
return;
}
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
"Tensor %s contains Inf", name);
PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
"Tensor %s contains NAN", name);
}
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
int block_id) {
auto& global_block = pdesc.Block(block_id);
......@@ -340,15 +322,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_);
}
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
auto* var = local_scope->FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
}
}
}
}
platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (create_vars && create_local_scope) {
......
......@@ -24,6 +24,9 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark);
DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
namespace paddle {
namespace framework {
......@@ -513,6 +516,21 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_;
};
static void CheckTensorNANOrInf(const std::string& name,
const framework::Tensor& tensor) {
if (tensor.memory_size() == 0) {
return;
}
if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
return;
}
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
"Tensor %s contains Inf", name);
PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
"Tensor %s contains NAN", name);
}
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
......@@ -597,6 +615,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (FLAGS_benchmark) {
new_dev_ctx->Wait();
}
if (FLAGS_check_nan_inf) {
for (auto& vname : OutputVars(true)) {
auto* var = new_scope.FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
}
}
}
}
proto::VarType::Type OperatorWithKernel::IndicateDataType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册