未验证 提交 b94edba7 编写于 作者: Y yaoxuefeng 提交者: GitHub

add user_define_dump (#28596) (#29162)

上级 e9961bc3
......@@ -644,7 +644,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
"characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot",
"Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
......@@ -700,7 +700,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
"characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot",
"Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
......@@ -921,7 +921,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot",
"Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
......@@ -991,7 +991,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
"characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot",
"Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
......
......@@ -44,6 +44,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
mpi_rank_ = trainer_desc.mpi_rank();
mpi_size_ = trainer_desc.mpi_size();
dump_file_num_ = trainer_desc.dump_file_num();
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
const std::vector<paddle::framework::DataFeed *> readers =
dataset->GetReaders();
......
......@@ -83,6 +83,10 @@ void MultiTrainer::DumpWork(int tid) {
int err_no = 0;
std::string path = string::format_string(
"%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
if (user_define_dump_filename_ != "") {
path = string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(),
user_define_dump_filename_.c_str(), tid);
}
std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
while (1) {
......
......@@ -56,6 +56,19 @@ class TrainerBase {
Scope* root_scope_;
bool debug_;
Dataset* dataset_ptr_;
TrainerDesc trainer_desc_;
// For dump param or field
bool need_dump_field_ = false;
std::string user_define_dump_filename_;
bool need_dump_param_ = false;
std::string dump_fields_path_;
std::string dump_converter_;
std::vector<std::string> dump_param_;
std::vector<std::string> dump_fields_;
int dump_thread_num_;
std::vector<std::thread> dump_thread_;
std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
};
// general trainer for async execution
......
......@@ -50,6 +50,8 @@ message TrainerDesc {
optional bool thread_barrier = 22;
repeated string loss_names = 23;
optional string user_define_dump_filename = 24;
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103;
......
......@@ -527,6 +527,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_converter"] = ""
opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
opt_info["user_define_dump_filename"] = strategy.get(
"user_define_dump_filename", "")
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["dump_param"] = strategy.get("dump_param", [])
if server._server.downpour_server_param.downpour_table_param[
......
......@@ -106,6 +106,9 @@ class TrainerDesc(object):
def _set_dump_file_num(self, dump_file_num):
self.proto_desc.dump_file_num = dump_file_num
def _set_user_define_dump_filename(self, user_define_dump_filename):
self.proto_desc.user_define_dump_filename = user_define_dump_filename
def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册