未验证 提交 545df287 作者: yaoxuefeng 提交者: GitHub

add user_define_dump (#28596)

上级 216e0856
...@@ -661,7 +661,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe( ...@@ -661,7 +661,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
"characters.\nplease check this error line: %s, \n Specifically, " "characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)" "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots." "when we parse the %d th slots."
"Maybe something wrong around this slot", "Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
...@@ -717,7 +717,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { ...@@ -717,7 +717,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
"characters.\nplease check this error line: %s, \n Specifically, " "characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)" "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots." "when we parse the %d th slots."
"Maybe something wrong around this slot", "Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
...@@ -955,7 +955,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -955,7 +955,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"characters.\nplease check this error line: %s, \n Specifically, " "characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)" "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots." "when we parse the %d th slots."
"Maybe something wrong around this slot", "Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
...@@ -1026,7 +1026,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { ...@@ -1026,7 +1026,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
"characters.\nplease check this error line: %s, \n Specifically, " "characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)" "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots." "when we parse the %d th slots."
"Maybe something wrong around this slot", "Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
......
...@@ -33,6 +33,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, ...@@ -33,6 +33,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
mpi_rank_ = trainer_desc.mpi_rank(); mpi_rank_ = trainer_desc.mpi_rank();
mpi_size_ = trainer_desc.mpi_size(); mpi_size_ = trainer_desc.mpi_size();
dump_file_num_ = trainer_desc.dump_file_num(); dump_file_num_ = trainer_desc.dump_file_num();
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
const std::vector<paddle::framework::DataFeed *> readers = const std::vector<paddle::framework::DataFeed *> readers =
dataset->GetReaders(); dataset->GetReaders();
RegisterHeterCallback(); RegisterHeterCallback();
......
...@@ -71,6 +71,10 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -71,6 +71,10 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
} }
std::string MultiTrainer::GetDumpPath(int tid) { std::string MultiTrainer::GetDumpPath(int tid) {
if (user_define_dump_filename_ != "") {
return string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(),
user_define_dump_filename_.c_str(), tid);
}
return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(), return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(),
mpi_rank_, tid); mpi_rank_, tid);
} }
......
...@@ -79,6 +79,7 @@ class TrainerBase { ...@@ -79,6 +79,7 @@ class TrainerBase {
// For dump param or field // For dump param or field
bool need_dump_field_ = false; bool need_dump_field_ = false;
std::string user_define_dump_filename_;
bool need_dump_param_ = false; bool need_dump_param_ = false;
std::string dump_fields_path_; std::string dump_fields_path_;
std::string dump_converter_; std::string dump_converter_;
......
...@@ -60,6 +60,7 @@ message TrainerDesc { ...@@ -60,6 +60,7 @@ message TrainerDesc {
optional int32 xpu_end_idx = 31; optional int32 xpu_end_idx = 31;
optional bool use_ps_gpu = 32 [ default = false ]; optional bool use_ps_gpu = 32 [ default = false ];
optional string user_define_dump_filename = 33;
// device worker parameters // device worker parameters
optional HogwildWorkerParameter hogwild_param = 101; optional HogwildWorkerParameter hogwild_param = 101;
......
...@@ -760,6 +760,8 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -760,6 +760,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_converter"] = "" opt_info["dump_converter"] = ""
opt_info["dump_fields"] = strategy.get("dump_fields", []) opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_file_num"] = strategy.get("dump_file_num", 16) opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
opt_info["user_define_dump_filename"] = strategy.get(
"user_define_dump_filename", "")
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["dump_param"] = strategy.get("dump_param", []) opt_info["dump_param"] = strategy.get("dump_param", [])
opt_info["worker_places"] = strategy.get("worker_places", []) opt_info["worker_places"] = strategy.get("worker_places", [])
......
...@@ -146,6 +146,9 @@ class TrainerDesc(object): ...@@ -146,6 +146,9 @@ class TrainerDesc(object):
def _set_dump_file_num(self, dump_file_num): def _set_dump_file_num(self, dump_file_num):
self.proto_desc.dump_file_num = dump_file_num self.proto_desc.dump_file_num = dump_file_num
def _set_user_define_dump_filename(self, user_define_dump_filename):
self.proto_desc.user_define_dump_filename = user_define_dump_filename
def _set_dump_converter(self, converter): def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter self.proto_desc.dump_converter = converter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册