Unverified commit be55bac3, authored by Leo Chen, committed by GitHub

[new-exec] enable check_nan_inf (#36802)

* enable check_nan_inf and fix variable scope

* add ut

* fix bug

* update ut

* revert doc change

* fix npu compile
Parent 82fb63eb
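In short, this commit lets the existing nan/inf debugging utilities run under the new executor: a minimal ScopeBase interface is extracted from Scope, the checkers in nan_inf_utils are widened to accept any ScopeBase, the interpreter core's VariableScope implements that interface, and InterpreterCore::RunInstruction invokes the check after each kernel launch when FLAGS_check_nan_inf is set. A new unit test feeds a NaN through the new executor and expects a RuntimeError.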
@@ -117,7 +117,7 @@ cc_test(reader_test SRCS reader_test.cc DEPS reader)
 cc_library(threadpool SRCS threadpool.cc DEPS enforce)
 cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
-cc_library(var_type_traits SRCS var_type_traits.cc DEPS lod_tensor selected_rows framework_proto)
+cc_library(var_type_traits SRCS var_type_traits.cc DEPS lod_tensor selected_rows framework_proto scope)
 if (WITH_GPU)
   target_link_libraries(var_type_traits dynload_cuda)
 endif()
......
@@ -27,7 +27,7 @@ namespace framework {
 namespace details {
 // assert false when meets NAN or inf
 void CheckVarHasNanOrInf(const std::string& op_type,
-                         const framework::Scope& scope,
+                         const framework::ScopeBase& scope,
                          const std::string& var_name,
                          const platform::Place& place);
@@ -37,7 +37,7 @@ void CheckVarHasNanOrInf(const std::string& op_type,
                          const platform::Place& place);
 void CheckOpHasNanOrInf(const framework::OperatorBase& op,
-                        const framework::Scope& scope,
+                        const framework::ScopeBase& scope,
                         const platform::Place& place);
 template <typename VarType>
@@ -55,7 +55,7 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
 #ifdef PADDLE_WITH_ASCEND_CL
 void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
-                                 const framework::Scope& scope,
+                                 const framework::ScopeBase& scope,
                                  const platform::Place& place);
 #endif
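The header change above is mechanical: every checker now takes the abstract framework::ScopeBase& instead of the concrete framework::Scope&. A hedged call-site sketch of what this buys (the variable names here are illustrative, not from the patch):

```cpp
// Illustrative only -- assumes an OperatorBase `op` and a Place `place`
// already exist; namespaces abbreviated as in the header above.
framework::Scope classic_scope;         // tree-structured scope (old executor)
framework::VariableScope interp_scope;  // flat scope (new interpreter core)

// Both calls now compile against the same ScopeBase& parameter:
framework::details::CheckOpHasNanOrInf(op, classic_scope, place);
framework::details::CheckOpHasNanOrInf(op, interp_scope, place);
```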
......
@@ -407,7 +407,7 @@ void CheckVarHasNanOrInf(const std::string& op_type,
 }
 void CheckVarHasNanOrInf(const std::string& op_type,
-                         const framework::Scope& scope,
+                         const framework::ScopeBase& scope,
                          const std::string& var_name,
                          const platform::Place& place) {
   auto* var = scope.FindVar(var_name);
@@ -440,7 +440,7 @@ static framework::Tensor& npu_float_status() {
 }
 void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
-                                 const framework::Scope& scope,
+                                 const framework::ScopeBase& scope,
                                  const platform::Place& place) {
   if (!platform::is_npu_place(place)) return;
@@ -505,7 +505,7 @@ void PrintNpuVarInfo(const std::string& op_type, const std::string& var_name,
 }
 void PrintNPUOpValueInfo(const framework::OperatorBase& op,
-                         const framework::Scope& scope,
+                         const framework::ScopeBase& scope,
                          const platform::Place& place) {
   LOG(WARNING) << "There are `nan` or `inf` in operator (" << op.Type()
                << "), here we print some tensor value info of this op.";
@@ -523,7 +523,7 @@ void PrintNPUOpValueInfo(const framework::OperatorBase& op,
 }
 static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op,
-                                  const framework::Scope& scope,
+                                  const framework::ScopeBase& scope,
                                   const platform::Place& place) {
   if (!platform::is_npu_place(place)) return;
@@ -551,14 +551,13 @@ static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op,
   if (sum >= 1.0) PrintNPUOpValueInfo(op, scope, place);
-  PADDLE_ENFORCE_LT(
-      sum, 1.0, platform::errors::PreconditionNotMet(
-                    "Operator %s contains Nan/Inf.", op.DebugStringEx(&scope)));
+  PADDLE_ENFORCE_LT(sum, 1.0, platform::errors::PreconditionNotMet(
+                                  "Operator %s contains Nan/Inf.", op.Type()));
 }
 #endif
 void CheckOpHasNanOrInf(const framework::OperatorBase& op,
-                        const framework::Scope& exec_scope,
+                        const framework::ScopeBase& exec_scope,
                         const platform::Place& place) {
   std::call_once(white_list_init_flag, InitWhiteListFormEnv);
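One detail of the hunk above: the NPU enforce message switches from op.DebugStringEx(&scope) to op.Type(). OperatorBase::DebugStringEx expects a const Scope*, which a ScopeBase& can no longer supply, so the message now reports only the operator type; the surrounding reflow appears to be clang-format fallout from that change.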
......
 set(INTERPRETERCORE_DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
     lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
-    graph_to_program_pass variable_helper timer monitor)
+    graph_to_program_pass variable_helper timer monitor nan_inf_utils)
 cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc DEPS enforce)
 cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS})
......
@@ -17,12 +17,15 @@
 #include <unordered_set>
+#include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
 #include "paddle/fluid/platform/profiler.h"
 PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
                             "Use inplace in new executor");
+DECLARE_bool(check_nan_inf);
 constexpr const char* kExceptionCaught = "ExceptionCaught";
 namespace paddle {
@@ -80,7 +83,6 @@ paddle::framework::FetchList InterpreterCore::Run(
   auto FeedInput = [&] {
     for (size_t i = 0; i < feed_names_.size(); ++i) {
       auto* feed_var = global_scope_->Var(feed_names_[i]);
       auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
       feed_tensor->ShareDataWith(feed_tensors[i]);
     }
@@ -246,10 +248,10 @@ void InterpreterCore::BuildInplace() {
       auto outvar = global_scope_->Var(iterout->second[0]);
       if (invar && outvar) {
         instr.AddInplace(invar, outvar);
-        VLOG(3) << "inplace " << op_base->Type() << " "
-                << global_scope_->VarDesc(iter->second[0])->Name()
-                << " -> "
-                << global_scope_->VarDesc(iterout->second[0])->Name()
+        VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
+                << " " << global_scope_->GetNameById(iter->second[0])
+                << " -> "
+                << global_scope_->GetNameById(iterout->second[0])
                 << std::endl;
       }
     }
@@ -330,6 +332,14 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
     platform::RecordEvent compute_event("Compute");
     instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
   }
+
+  // for debug nan/inf
+  if (FLAGS_check_nan_inf) {
+    VLOG(4) << "Check nan/inf";
+    framework::details::CheckOpHasNanOrInf(
+        *instr_node.OpBase(), *global_scope_,
+        instr_node.DeviceContext().GetPlace());
+  }
 }

 void InterpreterCore::ExecuteInstructionList(
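RunInstruction above gates the scan on FLAGS_check_nan_inf, which is defined elsewhere in the framework and only declared in this file. A minimal sketch of that pattern, assuming plain gflags rather than Paddle's wrapper macros (names here are illustrative):

```cpp
#include <gflags/gflags.h>

// In exactly one translation unit (Paddle defines this flag in its own
// flags file; shown here only to make the sketch self-contained):
DEFINE_bool(check_nan_inf, false,
            "Scan operator outputs for NaN/Inf after each kernel run");

// At each use site, declare the flag and gate the optional debug pass on it,
// so the check costs nothing unless explicitly enabled:
void AfterKernelRun(/* op, scope, place */) {
  if (FLAGS_check_nan_inf) {
    // run the NaN/Inf scan over the op's outputs
  }
}
```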
......
@@ -471,44 +471,73 @@ struct VariableMetaInfo {
   paddle::framework::VarDesc* vardesc_;
 };

-// TODO(Aurelius84): Consider inherit ScopeBase to unify interface.
-class VariableScope {
+// TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?
+class VariableScope : public ScopeBase {
  public:
   Variable* FindVar(const std::string& name) const {
-    if (!HasVar(name)) {
-      return nullptr;
+    auto it = name2id_.find(name);
+    if (it != name2id_.end()) {
+      PADDLE_ENFORCE_LT(it->second, var_list_.size(),
+                        platform::errors::NotFound(
+                            "The id(%d) of variable(%s) should not be larger "
+                            "than the size of variable list(%d).",
+                            it->second, name, var_list_.size()));
+      return var_list_[it->second];
     }
-    auto var_id = VarId(name);
-    CheckExist(var_id);
-    return var_list[var_id];
+    return nullptr;
+  }
+
+  // Get variable id by name, return -1 if not found
+  int GetIdByName(const std::string& name) const {
+    auto it = name2id_.find(name);
+    if (it != name2id_.end()) {
+      return it->second;
+    }
+    return -1;
+  }
+
+  // Get variable name by id, return "" if not found
+  std::string GetNameById(int id) const {
+    // NOTE(zhiqiu): do not use vec_meta_info_[id].vardesc_->Name() since
+    // vec_meta_info_[id] may be nullptr,
+    // typically when the target variable is not existed in the original
+    // program desc, but created by interpretercore.
+    // For example, created and used by d2h_copy or h2d_copy operator.
+    auto it =
+        std::find_if(name2id_.begin(), name2id_.end(),
+                     [id](const auto& pair) { return pair.second == id; });
+    if (it != name2id_.end()) {
+      return it->first;
+    }
+    return "";
   }

   bool HasVar(const std::string& name) const {
-    return name2id.find(name) != name2id.end();
+    return name2id_.find(name) != name2id_.end();
   }

   int VarId(const std::string& name) const {
     CheckExist(name);
-    return name2id.at(name);
+    return name2id_.at(name);
   }

-  Variable* Var(int id) const { return var_list.at(id); }
+  Variable* Var(int id) const { return var_list_.at(id); }

   Variable* Var(const std::string& name) const {
-    return var_list.at(VarId(name));
+    return var_list_.at(VarId(name));
   }

-  size_t VarSize() const { return var_list.size(); }
+  size_t VarSize() const { return var_list_.size(); }

   void AddVar(const std::string& name, VarDesc* var_desc) {  // NOLINT
-    name2id[name] = VarSize();
+    name2id_[name] = VarSize();
     auto v = new Variable();
     if (nullptr == var_desc) {
       v->GetMutable<LoDTensor>();
     } else {
       InitializeVariable(v, var_desc->GetType());
     }
-    var_list.push_back(v);
+    var_list_.push_back(v);

     VariableMetaInfo info;
     info.var_ref_count_ = 0;
@@ -517,8 +546,8 @@ class VariableScope {
   }

   void AddVar(const std::string& name, Variable& var) {  // NOLINT
-    name2id[name] = VarSize();
-    var_list.push_back(&var);
+    name2id_[name] = VarSize();
+    var_list_.push_back(&var);

     VariableMetaInfo info;
     info.var_ref_count_ = 0;
@@ -540,10 +569,10 @@ class VariableScope {
   }

   void CheckExist(int id) const {
-    PADDLE_ENFORCE_LT(id, var_list.size(),
+    PADDLE_ENFORCE_LT(id, var_list_.size(),
                       platform::errors::PreconditionNotMet(
                           "Required var_id < %d, but received var_id = %d.",
-                          var_list.size(), id));
+                          var_list_.size(), id));
   }

   void CheckExist(const std::string& name) const {
@@ -553,8 +582,8 @@ class VariableScope {
   }

  private:
-  std::vector<Variable*> var_list;
-  std::map<std::string, int> name2id;
+  std::vector<Variable*> var_list_;
+  std::map<std::string, int> name2id_;
   std::vector<VariableMetaInfo> vec_meta_info_;
 };
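Besides the member renames (trailing underscores, per the usual style convention), the class gains id/name lookups in both directions. GetNameById is a linear scan over name2id_, which is acceptable here since it is only hit on logging and error paths. A small usage sketch of the API as it stands after this change (illustrative values, not from the patch):

```cpp
VariableScope var_scope;
var_scope.AddVar("x", /*var_desc=*/nullptr);  // default-creates a LoDTensor

const ScopeBase& base = var_scope;  // usable wherever a ScopeBase& is expected
Variable* v = base.FindVar("x");    // virtual dispatch; nullptr if not found

int id = var_scope.GetIdByName("x");           // -1 if the name is unknown
std::string name = var_scope.GetNameById(id);  // "" if the id is unknown
```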
......
@@ -39,6 +39,16 @@ class Variable;
 namespace paddle {
 namespace framework {

+// TODO(zhiqiu): add more function in base class
+class ScopeBase {
+ public:
+  /// Find a variable in the scope or any of its ancestors. Returns
+  /// nullptr if cannot find.
+  /// Caller doesn't own the returned Variable.
+  virtual Variable* FindVar(const std::string& name) const = 0;
+  virtual ~ScopeBase() {}
+};
+
 class Scope;

 /**
@@ -49,7 +59,7 @@ class Scope;
  * One net can run in different scopes and update different variable in the
  * scope.
  */
-class Scope {
+class Scope : public ScopeBase {
  public:
   Scope() {}
   ~Scope();
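Extracting ScopeBase leaves existing Scope users untouched: Scope::FindVar already matches the pure virtual's signature exactly, so it becomes the override with no further edits, and the virtual destructor makes deletion through a ScopeBase pointer safe. Per the TODO, more of Scope's interface may migrate into the base class later.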
......
@@ -69,6 +69,7 @@ class BKCLCommunicator;
 namespace framework {
 class LoDRankTable;
+class ScopeBase;
 class LoDTensor;
 class ReaderHolder;
 class Scope;
......
@@ -256,10 +256,12 @@ class TestException(unittest.TestCase):
         main_program = paddle.static.Program()
         startup_program = paddle.static.Program()
         with paddle.static.program_guard(main_program, startup_program):
-            w = paddle.rand([10, 20])
+            w = paddle.rand([10, 3])
             ids = paddle.static.data(name="id", shape=[5], dtype='int64')
+            data = paddle.static.data(name="data", shape=[3], dtype='float32')
             emb = paddle.nn.functional.embedding(
                 x=ids, weight=w, sparse=False, name="embedding")
+            emb = emb + data
             return main_program, startup_program, emb
@@ -273,7 +275,7 @@ class TestException(unittest.TestCase):
         for feed in feeds:
             out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
-            print(out)
         return out

     def run_new_executor(self, feed):
@@ -284,12 +286,27 @@ class TestException(unittest.TestCase):
     def test_exception(self):
         feed = [{
-            'id': np.array([1, 2, 3, 4, 5]).astype(np.int64)
+            'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
+            'data': np.array([1, 2, 3, 4]).astype(np.float32),
         }, {
-            'id': np.array([1, 2, 3, 4, 11]).astype(np.int64)
+            'id': np.array([1, 2, 3, 4, 11]).astype(np.int64),
+            'data': np.array([1, 2, 3, 4]).astype(np.float32),
         }]
         self.assertRaises(ValueError, self.run_new_executor, feed)

+    def test_nan(self):
+        flags = {'FLAGS_check_nan_inf': True}
+        paddle.fluid.set_flags(flags)
+        feed = [{
+            'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
+            'data': np.array([1, 2, 3]).astype(np.float32),
+        }, {
+            'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
+            'data': np.array([1, 2, 3]).astype(np.float32),
+        }]
+        feed[1]['data'][0] = np.nan
+        self.assertRaises(RuntimeError, self.run_new_executor, feed)
+
 if __name__ == "__main__":
     unittest.main()
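Note that test_nan flips FLAGS_check_nan_inf via paddle.fluid.set_flags before poisoning the second feed with np.nan, and expects the enforce in the nan/inf checker to surface as a RuntimeError under the new executor; test_exception still covers the invalid-feed path, which surfaces as a ValueError.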