未验证 提交 b94edba7 编写于 作者: Y yaoxuefeng 提交者: GitHub

add user_define_dump (#28596) (#29162)

上级 e9961bc3
...@@ -644,7 +644,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe( ...@@ -644,7 +644,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
"characters.\nplease check this error line: %s, \n Specifically, " "characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)" "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots." "when we parse the %d th slots."
"Maybe something wrong around this slot", "Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
...@@ -700,7 +700,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { ...@@ -700,7 +700,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
"characters.\nplease check this error line: %s, \n Specifically, " "characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)" "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots." "when we parse the %d th slots."
"Maybe something wrong around this slot", "Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
...@@ -921,7 +921,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -921,7 +921,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"characters.\nplease check this error line: %s, \n Specifically, " "characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)" "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots." "when we parse the %d th slots."
"Maybe something wrong around this slot", "Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
...@@ -991,7 +991,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { ...@@ -991,7 +991,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
"characters.\nplease check this error line: %s, \n Specifically, " "characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)" "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots." "when we parse the %d th slots."
"Maybe something wrong around this slot", "Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
......
...@@ -44,6 +44,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, ...@@ -44,6 +44,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
mpi_rank_ = trainer_desc.mpi_rank(); mpi_rank_ = trainer_desc.mpi_rank();
mpi_size_ = trainer_desc.mpi_size(); mpi_size_ = trainer_desc.mpi_size();
dump_file_num_ = trainer_desc.dump_file_num(); dump_file_num_ = trainer_desc.dump_file_num();
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
const std::vector<paddle::framework::DataFeed *> readers = const std::vector<paddle::framework::DataFeed *> readers =
dataset->GetReaders(); dataset->GetReaders();
......
...@@ -83,6 +83,10 @@ void MultiTrainer::DumpWork(int tid) { ...@@ -83,6 +83,10 @@ void MultiTrainer::DumpWork(int tid) {
int err_no = 0; int err_no = 0;
std::string path = string::format_string( std::string path = string::format_string(
"%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid); "%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
if (user_define_dump_filename_ != "") {
path = string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(),
user_define_dump_filename_.c_str(), tid);
}
std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_); std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
while (1) { while (1) {
......
...@@ -56,6 +56,19 @@ class TrainerBase { ...@@ -56,6 +56,19 @@ class TrainerBase {
Scope* root_scope_; Scope* root_scope_;
bool debug_; bool debug_;
Dataset* dataset_ptr_; Dataset* dataset_ptr_;
TrainerDesc trainer_desc_;
// For dump param or field
bool need_dump_field_ = false;
std::string user_define_dump_filename_;
bool need_dump_param_ = false;
std::string dump_fields_path_;
std::string dump_converter_;
std::vector<std::string> dump_param_;
std::vector<std::string> dump_fields_;
int dump_thread_num_;
std::vector<std::thread> dump_thread_;
std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
}; };
// general trainer for async execution // general trainer for async execution
......
...@@ -50,6 +50,8 @@ message TrainerDesc { ...@@ -50,6 +50,8 @@ message TrainerDesc {
optional bool thread_barrier = 22; optional bool thread_barrier = 22;
repeated string loss_names = 23; repeated string loss_names = 23;
optional string user_define_dump_filename = 24;
// device worker parameters // device worker parameters
optional HogwildWorkerParameter hogwild_param = 101; optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103; optional DownpourWorkerParameter downpour_param = 103;
......
...@@ -527,6 +527,8 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -527,6 +527,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_converter"] = "" opt_info["dump_converter"] = ""
opt_info["dump_fields"] = strategy.get("dump_fields", []) opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_file_num"] = strategy.get("dump_file_num", 16) opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
opt_info["user_define_dump_filename"] = strategy.get(
"user_define_dump_filename", "")
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["dump_param"] = strategy.get("dump_param", []) opt_info["dump_param"] = strategy.get("dump_param", [])
if server._server.downpour_server_param.downpour_table_param[ if server._server.downpour_server_param.downpour_table_param[
......
...@@ -106,6 +106,9 @@ class TrainerDesc(object): ...@@ -106,6 +106,9 @@ class TrainerDesc(object):
def _set_dump_file_num(self, dump_file_num): def _set_dump_file_num(self, dump_file_num):
self.proto_desc.dump_file_num = dump_file_num self.proto_desc.dump_file_num = dump_file_num
def _set_user_define_dump_filename(self, user_define_dump_filename):
self.proto_desc.user_define_dump_filename = user_define_dump_filename
def _set_dump_converter(self, converter): def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter self.proto_desc.dump_converter = converter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册