diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index e77053389fc3f09ee0590696e693d615164d5155..41c59c5893f2172e7da77dbadd1588224ef6cc4e 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -644,7 +644,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
               "characters.\nplease check this error line: %s, \n Specifically, "
               "something wrong happened(the length of this slot's feasign is 0)"
               "when we parse the %d th slots."
-              "Maybe something wrong around this slot",
+              "Maybe something wrong around this slot"
               "\nWe detect the feasign number of this slot is %d, "
               "which is illegal.",
               str, i, num));
@@ -700,7 +700,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
               "characters.\nplease check this error line: %s, \n Specifically, "
               "something wrong happened(the length of this slot's feasign is 0)"
               "when we parse the %d th slots."
-              "Maybe something wrong around this slot",
+              "Maybe something wrong around this slot"
               "\nWe detect the feasign number of this slot is %d, "
               "which is illegal.",
               str, i, num));
@@ -921,7 +921,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
               "characters.\nplease check this error line: %s, \n Specifically, "
               "something wrong happened(the length of this slot's feasign is 0)"
               "when we parse the %d th slots."
-              "Maybe something wrong around this slot",
+              "Maybe something wrong around this slot"
               "\nWe detect the feasign number of this slot is %d, "
               "which is illegal.",
               str, i, num));
@@ -991,7 +991,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
               "characters.\nplease check this error line: %s, \n Specifically, "
               "something wrong happened(the length of this slot's feasign is 0)"
               "when we parse the %d th slots."
-              "Maybe something wrong around this slot",
+              "Maybe something wrong around this slot"
               "\nWe detect the feasign number of this slot is %d, "
               "which is illegal.",
               str, i, num));
diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc
index 9fe28bddd1f04a80c1ede7466bf6c881c0f6c817..d16699babaea76a03fc2d2a334717d2f7e0d1345 100644
--- a/paddle/fluid/framework/dist_multi_trainer.cc
+++ b/paddle/fluid/framework/dist_multi_trainer.cc
@@ -44,6 +44,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
   mpi_rank_ = trainer_desc.mpi_rank();
   mpi_size_ = trainer_desc.mpi_size();
   dump_file_num_ = trainer_desc.dump_file_num();
+  user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
 
   const std::vector<paddle::framework::DataFeed *> readers =
       dataset->GetReaders();
diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc
index 0faf96195403faeead00c56353cd5ad965269e13..9bb7c4c874081d7a4e4b6c6bc491f8204342e9f6 100644
--- a/paddle/fluid/framework/multi_trainer.cc
+++ b/paddle/fluid/framework/multi_trainer.cc
@@ -83,6 +83,10 @@ void MultiTrainer::DumpWork(int tid) {
   int err_no = 0;
   std::string path = string::format_string(
       "%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
+  if (user_define_dump_filename_ != "") {
+    path = string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(),
+                                 user_define_dump_filename_.c_str(), tid);
+  }
 
   std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
   while (1) {
diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h
index e22d659a367df8d1c6daf24b989cf5420b5609d3..5af372f7a6cd8d612a76e6117553a378635b640f 100644
--- a/paddle/fluid/framework/trainer.h
+++ b/paddle/fluid/framework/trainer.h
@@ -56,6 +56,19 @@ class TrainerBase {
   Scope* root_scope_;
   bool debug_;
   Dataset* dataset_ptr_;
+  TrainerDesc trainer_desc_;
+
+  // For dump param or field
+  bool need_dump_field_ = false;
+  std::string user_define_dump_filename_;
+  bool need_dump_param_ = false;
+  std::string dump_fields_path_;
+  std::string dump_converter_;
+  std::vector<std::string> dump_param_;
+  std::vector<std::string> dump_fields_;
+  int dump_thread_num_;
+  std::vector<std::thread> dump_thread_;
+  std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
 };
 
 // general trainer for async execution
diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto
index f442063313f03321931112ed293ccdf8ebabeb89..b38572681dcadc40fc26f6197e5eba592e7e9633 100644
--- a/paddle/fluid/framework/trainer_desc.proto
+++ b/paddle/fluid/framework/trainer_desc.proto
@@ -50,6 +50,8 @@ message TrainerDesc {
   optional bool thread_barrier = 22;
   repeated string loss_names = 23;
 
+  optional string user_define_dump_filename = 24;
+
   // device worker parameters
   optional HogwildWorkerParameter hogwild_param = 101;
   optional DownpourWorkerParameter downpour_param = 103;
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py
index c0be2ca66caf23c75c982f6d7d964f2babd271bb..8447d0f96089dbba5b951fef91ef34feafd28887 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py
@@ -527,6 +527,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
         opt_info["dump_converter"] = ""
         opt_info["dump_fields"] = strategy.get("dump_fields", [])
         opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
+        opt_info["user_define_dump_filename"] = strategy.get(
+            "user_define_dump_filename", "")
         opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
         opt_info["dump_param"] = strategy.get("dump_param", [])
         if server._server.downpour_server_param.downpour_table_param[
diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py
index c70d4f7b73145c38c70b138b98a1e689ae3ee492..2dda17072bfcdcc435a3335efa35c30e393780a2 100644
--- a/python/paddle/fluid/trainer_desc.py
+++ b/python/paddle/fluid/trainer_desc.py
@@ -106,6 +106,9 @@ class TrainerDesc(object):
     def _set_dump_file_num(self, dump_file_num):
         self.proto_desc.dump_file_num = dump_file_num
 
+    def _set_user_define_dump_filename(self, user_define_dump_filename):
+        self.proto_desc.user_define_dump_filename = user_define_dump_filename
+
     def _set_dump_converter(self, converter):
         self.proto_desc.dump_converter = converter