From b94edba7fd08873926801fdef22502355f6e5c5f Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Tue, 1 Dec 2020 21:05:10 +0800 Subject: [PATCH] add user_define_dump (#28596) (#29162) --- paddle/fluid/framework/data_feed.cc | 8 ++++---- paddle/fluid/framework/dist_multi_trainer.cc | 1 + paddle/fluid/framework/multi_trainer.cc | 4 ++++ paddle/fluid/framework/trainer.h | 13 +++++++++++++ paddle/fluid/framework/trainer_desc.proto | 2 ++ .../parameter_server/pslib/optimizer_factory.py | 2 ++ python/paddle/fluid/trainer_desc.py | 3 +++ 7 files changed, 29 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index e77053389fc..41c59c5893f 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -644,7 +644,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe( "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -700,7 +700,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -921,7 +921,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -991,7 +991,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 9fe28bddd1f..d16699babae 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -44,6 +44,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, mpi_rank_ = trainer_desc.mpi_rank(); mpi_size_ = trainer_desc.mpi_size(); dump_file_num_ = trainer_desc.dump_file_num(); + user_define_dump_filename_ = trainer_desc.user_define_dump_filename(); const std::vector readers = dataset->GetReaders(); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 0faf9619540..9bb7c4c8740 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -83,6 +83,10 @@ void MultiTrainer::DumpWork(int tid) { int err_no = 0; std::string path = string::format_string( "%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid); + if (user_define_dump_filename_ != "") { + path = string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(), + user_define_dump_filename_.c_str(), tid); + } std::shared_ptr fp = fs_open_write(path, &err_no, dump_converter_); while (1) { diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index e22d659a367..5af372f7a6c 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -56,6 +56,19 @@ class TrainerBase { Scope* root_scope_; bool debug_; Dataset* dataset_ptr_; + TrainerDesc trainer_desc_; + + // For dump param or field + bool need_dump_field_ = false; + std::string user_define_dump_filename_; + bool need_dump_param_ = false; + std::string dump_fields_path_; + std::string dump_converter_; + std::vector dump_param_; + std::vector dump_fields_; + int dump_thread_num_; + std::vector dump_thread_; + std::shared_ptr> queue_; }; // general trainer for async execution diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index f442063313f..b38572681dc 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -50,6 +50,8 @@ message TrainerDesc { optional bool thread_barrier = 22; repeated string loss_names = 23; + optional string user_define_dump_filename = 24; + // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; optional DownpourWorkerParameter downpour_param = 103; diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index c0be2ca66ca..8447d0f9608 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -527,6 +527,8 @@ class DistributedAdam(DistributedOptimizerImplBase): opt_info["dump_converter"] = "" opt_info["dump_fields"] = strategy.get("dump_fields", []) opt_info["dump_file_num"] = strategy.get("dump_file_num", 16) + opt_info["user_define_dump_filename"] = strategy.get( + "user_define_dump_filename", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_param"] = strategy.get("dump_param", []) if server._server.downpour_server_param.downpour_table_param[ diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index c70d4f7b731..2dda17072bf 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -106,6 +106,9 @@ class TrainerDesc(object): def _set_dump_file_num(self, dump_file_num): self.proto_desc.dump_file_num = dump_file_num + def _set_user_define_dump_filename(self, user_define_dump_filename): + self.proto_desc.user_define_dump_filename = user_define_dump_filename + def _set_dump_converter(self, converter): self.proto_desc.dump_converter = converter -- GitLab