From 545df287fc9ca8c5ee9106b9e4bb585981d518b4 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Fri, 27 Nov 2020 15:00:54 +0800 Subject: [PATCH] add user_define_dump (#28596) --- paddle/fluid/framework/data_feed.cc | 8 ++++---- paddle/fluid/framework/dist_multi_trainer.cc | 1 + paddle/fluid/framework/multi_trainer.cc | 4 ++++ paddle/fluid/framework/trainer.h | 1 + paddle/fluid/framework/trainer_desc.proto | 1 + .../fleet/parameter_server/pslib/optimizer_factory.py | 2 ++ python/paddle/fluid/trainer_desc.py | 3 +++ 7 files changed, 16 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index aec27bd9d9..e006bf7c33 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -661,7 +661,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe( "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -717,7 +717,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -955,7 +955,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -1026,7 +1026,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 4d55d2987f..e84a62a09d 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -33,6 +33,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, mpi_rank_ = trainer_desc.mpi_rank(); mpi_size_ = trainer_desc.mpi_size(); dump_file_num_ = trainer_desc.dump_file_num(); + user_define_dump_filename_ = trainer_desc.user_define_dump_filename(); const std::vector readers = dataset->GetReaders(); RegisterHeterCallback(); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 030e80c0b3..7c900dcfc6 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -71,6 +71,10 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, } std::string MultiTrainer::GetDumpPath(int tid) { + if (user_define_dump_filename_ != "") { + return string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(), + user_define_dump_filename_.c_str(), tid); + } return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid); } diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index f4c8246938..be85247c7e 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -79,6 +79,7 @@ class TrainerBase { // For dump param or field bool need_dump_field_ = false; + std::string user_define_dump_filename_; bool need_dump_param_ = false; std::string dump_fields_path_; std::string dump_converter_; diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index c4e9064d05..70481cf372 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -60,6 +60,7 @@ message TrainerDesc { optional int32 xpu_end_idx = 31; optional bool use_ps_gpu = 32 [ default = false ]; + optional string user_define_dump_filename = 33; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 61fbc7fdf6..727cc2b1b5 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -760,6 +760,8 @@ class DistributedAdam(DistributedOptimizerImplBase): opt_info["dump_converter"] = "" opt_info["dump_fields"] = strategy.get("dump_fields", []) opt_info["dump_file_num"] = strategy.get("dump_file_num", 16) + opt_info["user_define_dump_filename"] = strategy.get( + "user_define_dump_filename", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_param"] = strategy.get("dump_param", []) opt_info["worker_places"] = strategy.get("worker_places", []) diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index ac7c8c0a68..d1fb843b56 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -146,6 +146,9 @@ class TrainerDesc(object): def _set_dump_file_num(self, dump_file_num): self.proto_desc.dump_file_num = dump_file_num + def _set_user_define_dump_filename(self, user_define_dump_filename): + self.proto_desc.user_define_dump_filename = user_define_dump_filename + def _set_dump_converter(self, converter): self.proto_desc.dump_converter = converter -- GitLab