diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index aec27bd9d91e5afb6bf11037e60ff213162ad97f..e006bf7c33f6a3d2e0c29604a784b8379b0ef095 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -661,7 +661,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe( "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -717,7 +717,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -955,7 +955,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); @@ -1026,7 +1026,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { "characters.\nplease check this error line: %s, \n Specifically, " "something wrong happened(the length of this slot's feasign is 0)" "when we parse the %d th slots." - "Maybe something wrong around this slot", + "Maybe something wrong around this slot" "\nWe detect the feasign number of this slot is %d, " "which is illegal.", str, i, num)); diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 4d55d2987f3f39525c1070e3213f3a2e84e18dff..e84a62a09de24ddec8d6e2333efa739cf1ce72a3 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -33,6 +33,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, mpi_rank_ = trainer_desc.mpi_rank(); mpi_size_ = trainer_desc.mpi_size(); dump_file_num_ = trainer_desc.dump_file_num(); + user_define_dump_filename_ = trainer_desc.user_define_dump_filename(); const std::vector readers = dataset->GetReaders(); RegisterHeterCallback(); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 030e80c0b3fa12ea2dd8f0dcc676a42ef68db3ea..7c900dcfc64631602ca799c7cf9ac32dc7146e07 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -71,6 +71,10 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, } std::string MultiTrainer::GetDumpPath(int tid) { + if (user_define_dump_filename_ != "") { + return string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(), + user_define_dump_filename_.c_str(), tid); + } return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid); } diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index f4c8246938e9a9e2fdbe1cf27b29ade404fac64c..be85247c7ea1fcd336115ad8a8bdf98653340244 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -79,6 +79,7 @@ class TrainerBase { // For dump param or field bool need_dump_field_ = false; + std::string user_define_dump_filename_; bool need_dump_param_ = false; std::string dump_fields_path_; std::string dump_converter_; diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index c4e9064d0556c2ce17a2ec0fe6880cead11cf1d4..70481cf3727012e4cf41d235154eb277d92cc92f 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -60,6 +60,7 @@ message TrainerDesc { optional int32 xpu_end_idx = 31; optional bool use_ps_gpu = 32 [ default = false ]; + optional string user_define_dump_filename = 33; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 61fbc7fdf6633fb9cf73141b5944a515e4155595..727cc2b1b54bc523f0fc39f30c2bb46e45445b47 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -760,6 +760,8 @@ class DistributedAdam(DistributedOptimizerImplBase): opt_info["dump_converter"] = "" opt_info["dump_fields"] = strategy.get("dump_fields", []) opt_info["dump_file_num"] = strategy.get("dump_file_num", 16) + opt_info["user_define_dump_filename"] = strategy.get( + "user_define_dump_filename", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_param"] = strategy.get("dump_param", []) opt_info["worker_places"] = strategy.get("worker_places", []) diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index ac7c8c0a687bb96d1eb0fc26b1b36d9e7ff9cf1f..d1fb843b566014d2cc0c18609cacae0df61d29a8 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -146,6 +146,9 @@ class TrainerDesc(object): def _set_dump_file_num(self, dump_file_num): self.proto_desc.dump_file_num = dump_file_num + def _set_user_define_dump_filename(self, user_define_dump_filename): + self.proto_desc.user_define_dump_filename = user_define_dump_filename + def _set_dump_converter(self, converter): self.proto_desc.dump_converter = converter