未验证 提交 59bcdc8a 编写于 作者: T Thunderbrook 提交者: GitHub

support dump param of model into afs (#20302)

* support dump param to afs
test=develop

* code style
test=develop

* code style
test=develop

* dump param
test=develop

* dump param
test=develop

* dump param
test=develop

* dump param
test=develop
上级 768551b2
...@@ -194,8 +194,11 @@ class DownpourWorker : public HogwildWorker { ...@@ -194,8 +194,11 @@ class DownpourWorker : public HogwildWorker {
void PushGradients(); void PushGradients();
void CollectLabelInfo(size_t table_id); void CollectLabelInfo(size_t table_id);
void AdjustInsWeight(); void AdjustInsWeight();
void DumpParam();
private: private:
bool need_dump_param_;
std::vector<std::string> dump_param_;
bool need_to_push_dense_; bool need_to_push_dense_;
bool need_dump_field_; bool need_dump_field_;
bool dump_slot_; bool dump_slot_;
......
...@@ -82,6 +82,14 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -82,6 +82,14 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
dump_fields_[i] = desc.dump_fields(i); dump_fields_[i] = desc.dump_fields(i);
} }
adjust_ins_weight_config_ = desc.adjust_ins_weight_config(); adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
need_dump_param_ = false;
dump_param_.resize(desc.dump_param_size());
for (int i = 0; i < desc.dump_param_size(); ++i) {
dump_param_[i] = desc.dump_param(i);
}
if (desc.dump_param_size() != 0) {
need_dump_param_ = true;
}
for (int i = 0; i < desc.check_nan_var_names_size(); ++i) { for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
check_nan_var_names_.push_back(desc.check_nan_var_names(i)); check_nan_var_names_.push_back(desc.check_nan_var_names(i));
} }
...@@ -163,6 +171,22 @@ bool CheckValidOutput(LoDTensor* tensor, int batch_size) { ...@@ -163,6 +171,22 @@ bool CheckValidOutput(LoDTensor* tensor, int batch_size) {
return true; return true;
} }
void DownpourWorker::DumpParam() {
std::string os;
for (auto& param : dump_param_) {
os.clear();
os = param;
Variable* var = thread_scope_->FindVar(param);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t len = tensor->numel();
os += PrintLodTensor(tensor, 0, len);
writer_ << os;
}
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) { void DownpourWorker::CollectLabelInfo(size_t table_idx) {
uint64_t table_id = static_cast<uint64_t>( uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx)); param_.program_config(0).pull_sparse_table_id(table_idx));
...@@ -814,6 +838,9 @@ void DownpourWorker::TrainFiles() { ...@@ -814,6 +838,9 @@ void DownpourWorker::TrainFiles() {
} }
writer_ << ars[i]; writer_ << ars[i];
} }
if (need_dump_param_ && thread_id_ == 0) {
DumpParam();
}
} }
PrintFetchVars(); PrintFetchVars();
......
...@@ -105,7 +105,6 @@ class DistMultiTrainer : public MultiTrainer { ...@@ -105,7 +105,6 @@ class DistMultiTrainer : public MultiTrainer {
bool need_dump_field_; bool need_dump_field_;
std::string dump_fields_path_; std::string dump_fields_path_;
std::string dump_converter_; std::string dump_converter_;
std::vector<std::string> dump_fields_;
int mpi_rank_; int mpi_rank_;
int mpi_size_; int mpi_size_;
int dump_file_num_; int dump_file_num_;
......
...@@ -39,6 +39,7 @@ message TrainerDesc { ...@@ -39,6 +39,7 @@ message TrainerDesc {
optional string dump_fields_path = 12; optional string dump_fields_path = 12;
repeated string dump_fields = 13; repeated string dump_fields = 13;
optional string dump_converter = 14; optional string dump_converter = 14;
repeated string dump_param = 15;
optional int32 mpi_size = 16 [ default = -1 ]; optional int32 mpi_size = 16 [ default = -1 ];
optional int32 dump_file_num = 17 [ default = 16 ]; optional int32 dump_file_num = 17 [ default = 16 ];
......
...@@ -358,6 +358,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -358,6 +358,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_fields"] = strategy.get("dump_fields", []) opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_file_num"] = strategy.get("dump_file_num", 16) opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["dump_param"] = strategy.get("dump_param", [])
if server._server.downpour_server_param.downpour_table_param[ if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class == "DownpourCtrAccessor": 0].accessor.accessor_class == "DownpourCtrAccessor":
opt_info["dump_slot"] = True opt_info["dump_slot"] = True
......
...@@ -100,6 +100,10 @@ class TrainerDesc(object): ...@@ -100,6 +100,10 @@ class TrainerDesc(object):
def _set_dump_converter(self, converter): def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter self.proto_desc.dump_converter = converter
def _set_dump_param(self, dump_param):
for param in dump_param:
self.proto_desc.dump_param.append(param)
def _set_check_nan_var_names(self, check_nan_var_names): def _set_check_nan_var_names(self, check_nan_var_names):
for var in check_nan_var_names: for var in check_nan_var_names:
self.proto_desc.check_nan_var_names.append(var) self.proto_desc.check_nan_var_names.append(var)
......
...@@ -53,6 +53,7 @@ class TrainerFactory(object): ...@@ -53,6 +53,7 @@ class TrainerFactory(object):
trainer._set_dump_file_num(opt_info["dump_file_num"]) trainer._set_dump_file_num(opt_info["dump_file_num"])
trainer._set_dump_converter(opt_info["dump_converter"]) trainer._set_dump_converter(opt_info["dump_converter"])
trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"]) trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
trainer._set_dump_param(opt_info["dump_param"])
trainer._set_check_nan_var_names(opt_info[ trainer._set_check_nan_var_names(opt_info[
"check_nan_var_names"]) "check_nan_var_names"])
trainer._set_device_worker(device_worker) trainer._set_device_worker(device_worker)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册