Unverified commit 5c3656bb authored by xujiaqi01, committed by GitHub

add check nan / inf in downpour worker (#20694) (#20925)

* add check nan / inf in downpour worker during training
* test=develop
Parent 781d2844
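For context (not part of the commit itself): a minimal sketch of how a user would turn the new check on through the PSLib fleet strategy dict that `DistributedAdam` reads in the diff below. The import path follows the PSLib fleet API of this Paddle version, and the gradient variable names are hypothetical placeholders.

```python
# Minimal usage sketch, assuming the PSLib fleet API of this Paddle
# version; the gradient variable names are hypothetical placeholders.
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
# DistributedAdam reads this key via strategy.get("check_nan_var_names", [])
# and forwards it through the TrainerDesc proto to every DownpourWorker,
# which then fails fast (PADDLE_ENFORCE_EQ) if a listed tensor contains
# Inf or NaN during training.
strategy = {"check_nan_var_names": ["fc_0.w_0@GRAD", "fc_0.b_0@GRAD"]}
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
```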
@@ -230,6 +230,8 @@ class DownpourWorker : public HogwildWorker {
   // adjust ins weight
   AdjustInsWeightConfig adjust_ins_weight_config_;
   std::vector<float> nid_show_;
+  // check nan and inf during training
+  std::vector<std::string> check_nan_var_names_;
 };
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
@@ -81,6 +81,9 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
     dump_fields_[i] = desc.dump_fields(i);
   }
   adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
+  for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
+    check_nan_var_names_.push_back(desc.check_nan_var_names(i));
+  }
 }
 void DownpourWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
@@ -468,6 +471,22 @@ void DownpourWorker::TrainFilesWithProfiler() {
     }
   }
+  // check inf and nan
+  for (std::string& var_name : check_nan_var_names_) {
+    Variable* var = thread_scope_->FindVar(var_name);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    if (tensor == nullptr) {
+      continue;
+    }
+    PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false,
+                      "Tensor %s contains Inf", var_name);
+    PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false,
+                      "Tensor %s contains NAN", var_name);
+  }
   if (need_to_push_sparse_) {
     for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size();
          ++i) {
@@ -655,6 +674,22 @@ void DownpourWorker::TrainFiles() {
     }
   }
+  // check inf and nan
+  for (std::string& var_name : check_nan_var_names_) {
+    Variable* var = thread_scope_->FindVar(var_name);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    if (tensor == nullptr) {
+      continue;
+    }
+    PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false,
+                      "Tensor %s contains Inf", var_name);
+    PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false,
+                      "Tensor %s contains NAN", var_name);
+  }
   if (need_to_push_sparse_) {
     // push gradients here
     for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size();
......
@@ -42,6 +42,7 @@ message TrainerDesc {
   optional int32 mpi_size = 16 [ default = -1 ];
   optional int32 dump_file_num = 17 [ default = 16 ];
+  repeated string check_nan_var_names = 18;
   // device worker parameters
   optional HogwildWorkerParameter hogwild_param = 101;
......
@@ -248,6 +248,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
         opt_info["use_cvm"] = strategy.get("use_cvm", False)
         opt_info["stat_var_names"] = strategy.get("stat_var_names", [])
         opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1)
+        opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names",
+                                                       [])
         opt_info["dump_slot"] = False
         opt_info["dump_converter"] = ""
         opt_info["dump_fields"] = strategy.get("dump_fields", [])
......
@@ -100,6 +100,10 @@ class TrainerDesc(object):
     def _set_dump_converter(self, converter):
         self.proto_desc.dump_converter = converter
 
+    def _set_check_nan_var_names(self, check_nan_var_names):
+        for var in check_nan_var_names:
+            self.proto_desc.check_nan_var_names.append(var)
+
     def _set_adjust_ins_weight(self, config_dict):
         self.proto_desc.adjust_ins_weight_config.need_adjust = \
             config_dict.get("need_adjust", False)
......
@@ -53,6 +53,8 @@ class TrainerFactory(object):
                 trainer._set_dump_file_num(opt_info["dump_file_num"])
                 trainer._set_dump_converter(opt_info["dump_converter"])
                 trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
+                trainer._set_check_nan_var_names(opt_info[
+                    "check_nan_var_names"])
                 trainer._set_device_worker(device_worker)
         return trainer
......