Unverified · Commit 59bcdc8a authored by Thunderbrook, committed by GitHub

support dump param of model into afs (#20302)

* support dump param to afs
test=develop

* code style
test=develop

* code style
test=develop

* dump param
test=develop

* dump param
test=develop

* dump param
test=develop

* dump param
test=develop
Parent 768551b2
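For reference, a minimal sketch of how the new option is driven from user code, assuming the pslib fleet API of this release. The strategy keys ("dump_param", "dump_fields", "dump_fields_path", "dump_file_num") are the ones read by DistributedAdam in the hunks below; the parameter names, the AFS path, and the surrounding program are hypothetical.

```python
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

# Build the network and a loss variable first (omitted here).
adam = fluid.optimizer.Adam(learning_rate=0.001)
strategy = {
    "dump_fields_path": "afs:/user/demo/model_dump",  # hypothetical AFS output dir
    "dump_fields": ["click", "predict"],              # hypothetical instance fields
    "dump_param": ["fc_0.w_0", "fc_0.b_0"],           # hypothetical parameter names
    "dump_file_num": 16,                              # default used by DistributedAdam
}
optimizer = fleet.distributed_optimizer(adam, strategy)
# optimizer.minimize(loss) then carries dump_param through opt_info and
# TrainerDesc into DownpourWorker (see the hunks below).
```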
@@ -194,8 +194,11 @@ class DownpourWorker : public HogwildWorker {
   void PushGradients();
   void CollectLabelInfo(size_t table_id);
   void AdjustInsWeight();
+  void DumpParam();

  private:
+  bool need_dump_param_;
+  std::vector<std::string> dump_param_;
   bool need_to_push_dense_;
   bool need_dump_field_;
   bool dump_slot_;
......
@@ -82,6 +82,14 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
     dump_fields_[i] = desc.dump_fields(i);
   }
   adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
+  need_dump_param_ = false;
+  dump_param_.resize(desc.dump_param_size());
+  for (int i = 0; i < desc.dump_param_size(); ++i) {
+    dump_param_[i] = desc.dump_param(i);
+  }
+  if (desc.dump_param_size() != 0) {
+    need_dump_param_ = true;
+  }
   for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
     check_nan_var_names_.push_back(desc.check_nan_var_names(i));
   }
@@ -163,6 +171,22 @@ bool CheckValidOutput(LoDTensor* tensor, int batch_size) {
   return true;
 }

+void DownpourWorker::DumpParam() {
+  std::string os;
+  for (auto& param : dump_param_) {
+    os.clear();
+    os = param;
+    Variable* var = thread_scope_->FindVar(param);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    int64_t len = tensor->numel();
+    os += PrintLodTensor(tensor, 0, len);
+    writer_ << os;
+  }
+}
+
 void DownpourWorker::CollectLabelInfo(size_t table_idx) {
   uint64_t table_id = static_cast<uint64_t>(
       param_.program_config(0).pull_sparse_table_id(table_idx));
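DumpParam emits one record per listed parameter: the variable name followed by the flattened tensor values serialized by PrintLodTensor, written through the same writer_ channel as the field dump. A hypothetical reader for one such record, assuming PrintLodTensor joins values with ':' separators; check the actual implementation in this file before relying on that format.

```python
# Hypothetical parser for a dumped parameter line of the assumed
# form "name:v0:v1:...". Adjust the separator if the real
# PrintLodTensor output differs.
def parse_dump_line(line):
    name, *values = line.rstrip("\n").split(":")
    return name, [float(v) for v in values]

name, values = parse_dump_line("fc_0.b_0:0.01:-0.02:0.03")
assert name == "fc_0.b_0" and len(values) == 3
```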
@@ -814,6 +838,9 @@ void DownpourWorker::TrainFiles() {
         }
         writer_ << ars[i];
       }
+      if (need_dump_param_ && thread_id_ == 0) {
+        DumpParam();
+      }
     }

     PrintFetchVars();
......
@@ -105,7 +105,6 @@ class DistMultiTrainer : public MultiTrainer {
   bool need_dump_field_;
   std::string dump_fields_path_;
   std::string dump_converter_;
-  std::vector<std::string> dump_fields_;
   int mpi_rank_;
   int mpi_size_;
   int dump_file_num_;
......
@@ -39,6 +39,7 @@ message TrainerDesc {
   optional string dump_fields_path = 12;
   repeated string dump_fields = 13;
   optional string dump_converter = 14;
+  repeated string dump_param = 15;
   optional int32 mpi_size = 16 [ default = -1 ];
   optional int32 dump_file_num = 17 [ default = 16 ];
......
@@ -358,6 +358,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
         opt_info["dump_fields"] = strategy.get("dump_fields", [])
         opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
         opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
+        opt_info["dump_param"] = strategy.get("dump_param", [])
         if server._server.downpour_server_param.downpour_table_param[
                 0].accessor.accessor_class == "DownpourCtrAccessor":
             opt_info["dump_slot"] = True
......
@@ -100,6 +100,10 @@ class TrainerDesc(object):
     def _set_dump_converter(self, converter):
         self.proto_desc.dump_converter = converter

+    def _set_dump_param(self, dump_param):
+        for param in dump_param:
+            self.proto_desc.dump_param.append(param)
+
     def _set_check_nan_var_names(self, check_nan_var_names):
         for var in check_nan_var_names:
             self.proto_desc.check_nan_var_names.append(var)
......
@@ -53,6 +53,7 @@ class TrainerFactory(object):
                 trainer._set_dump_file_num(opt_info["dump_file_num"])
                 trainer._set_dump_converter(opt_info["dump_converter"])
                 trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
+                trainer._set_dump_param(opt_info["dump_param"])
                 trainer._set_check_nan_var_names(opt_info[
                     "check_nan_var_names"])
             trainer._set_device_worker(device_worker)
......