Unverified commit ddf80c4b, authored by Thunderbrook, committed by GitHub

dump fix dov vec file num (#20539) (#20605)

* support dump multi file
test=develop

* dump fix num file
test=develop
Parent: 30dda5a7
......@@ -41,6 +41,8 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
}
}
mpi_rank_ = trainer_desc.mpi_rank() / 2;
mpi_size_ = trainer_desc.mpi_size() / 2;
dump_file_num_ = trainer_desc.dump_file_num();
const std::vector<paddle::framework::DataFeed *> readers =
dataset->GetReaders();
......@@ -68,20 +70,25 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
SetDebug(trainer_desc.debug());
}
void DistMultiTrainer::DumpWork() {
void DistMultiTrainer::DumpWork(int tid) {
#ifdef _LINUX
int err_no = 0;
std::string path = string::format_string(
"%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
while (1) {
std::string out_str;
if (!queue_->Get(out_str)) {
break;
}
size_t write_count =
fwrite_unlocked(out_str.data(), 1, out_str.length(), fp_.get());
fwrite_unlocked(out_str.data(), 1, out_str.length(), fp.get());
if (write_count != out_str.length()) {
VLOG(3) << "dump text failed";
continue;
}
write_count = fwrite_unlocked("\n", 1, 1, fp_.get());
write_count = fwrite_unlocked("\n", 1, 1, fp.get());
if (write_count != 1) {
VLOG(3) << "dump text failed";
continue;
......@@ -92,20 +99,27 @@ void DistMultiTrainer::DumpWork() {
void DistMultiTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>();
int err_no = 0;
std::string path = string::format_string(
"%s/part-%03d", dump_fields_path_.c_str(), mpi_rank_);
fp_ = fs_open_write(path, &err_no, dump_converter_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetChannelWriter(queue_.get());
}
dump_thread_ = std::thread(&DistMultiTrainer::DumpWork, this);
dump_thread_num_ = 1;
if (dump_file_num_ > mpi_size_) {
dump_thread_num_ = dump_file_num_ / mpi_size_;
if (dump_file_num_ % mpi_size_ > mpi_rank_) {
dump_thread_num_ += 1;
}
}
for (int i = 0; i < dump_thread_num_; i++) {
dump_thread_.push_back(
std::thread(std::bind(&DistMultiTrainer::DumpWork, this, i)));
}
}
void DistMultiTrainer::FinalizeDumpEnv() {
queue_->Close();
dump_thread_.join();
for (auto &th : dump_thread_) {
th.join();
}
queue_.reset();
}
......
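Editor's note: the sketch below mirrors, in plain Python, the file-distribution arithmetic of DistMultiTrainer::InitDumpEnv and the per-thread file naming of DistMultiTrainer::DumpWork shown above. The function names and parameters are illustrative stand-ins, not part of the Paddle API.

# Illustrative sketch only (not part of the commit).
def dump_thread_count(dump_file_num, mpi_size, mpi_rank):
    # Every rank writes at least one file.
    if dump_file_num <= mpi_size:
        return 1
    # Otherwise files are divided evenly and the first
    # (dump_file_num % mpi_size) ranks take one extra file each.
    count = dump_file_num // mpi_size
    if dump_file_num % mpi_size > mpi_rank:
        count += 1
    return count

def dump_file_names(dump_fields_path, mpi_rank, thread_count):
    # Mirrors the "%s/part-%03d-%05d" format used in DumpWork.
    return [
        "%s/part-%03d-%05d" % (dump_fields_path, mpi_rank, tid)
        for tid in range(thread_count)
    ]

# Example: 10 files over 4 ranks -> ranks 0 and 1 write 3 files, ranks 2 and 3 write 2.
for rank in range(4):
    n = dump_thread_count(dump_file_num=10, mpi_size=4, mpi_rank=rank)
    print(rank, n, dump_file_names("hdfs:/output", rank, n))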
......@@ -93,13 +93,13 @@ class DistMultiTrainer : public MultiTrainer {
void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor);
virtual void FinalizeDumpEnv();
virtual void InitDumpEnv();
virtual void DumpWork();
virtual void DumpWork(int tid);
virtual Scope* GetWorkerScope(int thread_id) { return root_scope_; }
protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
std::thread dump_thread_;
std::shared_ptr<FILE> fp_;
std::vector<std::thread> dump_thread_;
int dump_thread_num_;
std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
bool need_dump_field_;
......@@ -107,6 +107,8 @@ class DistMultiTrainer : public MultiTrainer {
std::string dump_converter_;
std::vector<std::string> dump_fields_;
int mpi_rank_;
int mpi_size_;
int dump_file_num_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
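Editor's note: the header changes above replace the single dump_thread_/fp_ pair with a vector of dump threads draining one shared channel. A rough Python analogue of that pattern follows, with queue.Queue standing in for paddle::framework::ChannelObject and a None sentinel standing in for closing the channel; it is a sketch of the concurrency pattern, not Paddle code.

import queue
import threading

def dump_worker(q, path):
    # Consumer: mirrors DumpWork(tid) -- pop records until the channel closes,
    # writing one line per record to this thread's own part file.
    with open(path, "w") as fp:
        while True:
            record = q.get()
            if record is None:  # stand-in for ChannelObject::Get returning false
                break
            fp.write(record + "\n")

q = queue.Queue()
paths = ["part-000-%05d" % tid for tid in range(2)]
threads = [threading.Thread(target=dump_worker, args=(q, p)) for p in paths]
for t in threads:
    t.start()

# Producers: in the patch, the device-worker threads share this channel
# via SetChannelWriter and push formatted dump strings into it.
for i in range(10):
    q.put("record_%d" % i)

# Shutdown: mirrors FinalizeDumpEnv -- close the channel, then join every thread.
for _ in threads:
    q.put(None)  # one sentinel per consumer
for t in threads:
    t.join()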
......@@ -40,6 +40,9 @@ message TrainerDesc {
repeated string dump_fields = 13;
optional string dump_converter = 14;
optional int32 mpi_size = 16 [ default = -1 ];
optional int32 dump_file_num = 17 [ default = 16 ];
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103;
......
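Editor's note, a quick worked example of the new field's default: with dump_file_num = 16 and mpi_size = 6, the trainer computes 16 / 6 = 2 with remainder 4, so ranks 0-3 start 3 dump threads each and ranks 4-5 start 2, producing exactly 16 part files in total. When dump_file_num <= mpi_size, every rank falls back to a single dump thread, so at least mpi_size files are always written.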
......@@ -588,6 +588,7 @@ class DownpourOptimizer(DistributedOptimizer):
no_grad_set,
self._strategy)
opt_info["mpi_rank"] = fleet._role_maker._get_rank()
opt_info["mpi_size"] = fleet._role_maker._get_size()
fleet._set_opt_info(opt_info)
programs = [loss.block.program for loss in losses]
......
......@@ -251,6 +251,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_slot"] = False
opt_info["dump_converter"] = ""
opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class == "DownpourCtrAccessor":
......
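Editor's note: on the Python side the new knob comes from the fleet strategy dict read above. A hypothetical user-side configuration is sketched below; the keys mirror the strategy.get(...) calls in DistributedAdam, while the path and field names are made up and the commented fleet call is an assumption about the pslib fleet interface rather than a confirmed call site.

# Illustrative only.
strategy = {
    "dump_fields_path": "hdfs:/user/dump/example",   # hypothetical output dir
    "dump_fields": ["click", "predict_prob"],        # hypothetical field names
    "dump_file_num": 32,                             # new in this patch; default is 16
}

# Assumed usage with the pslib fleet API; the exact call site may differ.
# optimizer = fleet.distributed_optimizer(adam_optimizer, strategy=strategy)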
......@@ -84,6 +84,9 @@ class TrainerDesc(object):
def _set_mpi_rank(self, mpi_rank):
self.proto_desc.mpi_rank = mpi_rank
def _set_mpi_size(self, mpi_size):
self.proto_desc.mpi_size = mpi_size
def _set_dump_fields(self, dump_fields):
for field in dump_fields:
self.proto_desc.dump_fields.append(field)
......@@ -91,6 +94,9 @@ class TrainerDesc(object):
def _set_dump_fields_path(self, path):
self.proto_desc.dump_fields_path = path
def _set_dump_file_num(self, dump_file_num):
self.proto_desc.dump_file_num = dump_file_num
def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter
......
......@@ -47,8 +47,10 @@ class TrainerFactory(object):
trainer._set_scale_datanorm(opt_info["scale_datanorm"])
trainer._set_dump_slot(opt_info["dump_slot"])
trainer._set_mpi_rank(opt_info["mpi_rank"])
trainer._set_mpi_size(opt_info["mpi_size"])
trainer._set_dump_fields(opt_info["dump_fields"])
trainer._set_dump_fields_path(opt_info["dump_fields_path"])
trainer._set_dump_file_num(opt_info["dump_file_num"])
trainer._set_dump_converter(opt_info["dump_converter"])
trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
trainer._set_device_worker(device_worker)
......