未验证 提交 ddf80c4b 编写于 作者: T Thunderbrook 提交者: GitHub

dump fix dov vec file num (#20539) (#20605)

* support dump multi file
test=develop

* dump fix num file
test=develop
上级 30dda5a7
...@@ -41,6 +41,8 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, ...@@ -41,6 +41,8 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
} }
} }
mpi_rank_ = trainer_desc.mpi_rank() / 2; mpi_rank_ = trainer_desc.mpi_rank() / 2;
mpi_size_ = trainer_desc.mpi_size() / 2;
dump_file_num_ = trainer_desc.dump_file_num();
const std::vector<paddle::framework::DataFeed *> readers = const std::vector<paddle::framework::DataFeed *> readers =
dataset->GetReaders(); dataset->GetReaders();
...@@ -68,20 +70,25 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, ...@@ -68,20 +70,25 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
SetDebug(trainer_desc.debug()); SetDebug(trainer_desc.debug());
} }
void DistMultiTrainer::DumpWork() { void DistMultiTrainer::DumpWork(int tid) {
#ifdef _LINUX #ifdef _LINUX
int err_no = 0;
std::string path = string::format_string(
"%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
while (1) { while (1) {
std::string out_str; std::string out_str;
if (!queue_->Get(out_str)) { if (!queue_->Get(out_str)) {
break; break;
} }
size_t write_count = size_t write_count =
fwrite_unlocked(out_str.data(), 1, out_str.length(), fp_.get()); fwrite_unlocked(out_str.data(), 1, out_str.length(), fp.get());
if (write_count != out_str.length()) { if (write_count != out_str.length()) {
VLOG(3) << "dump text failed"; VLOG(3) << "dump text failed";
continue; continue;
} }
write_count = fwrite_unlocked("\n", 1, 1, fp_.get()); write_count = fwrite_unlocked("\n", 1, 1, fp.get());
if (write_count != 1) { if (write_count != 1) {
VLOG(3) << "dump text failed"; VLOG(3) << "dump text failed";
continue; continue;
...@@ -92,20 +99,27 @@ void DistMultiTrainer::DumpWork() { ...@@ -92,20 +99,27 @@ void DistMultiTrainer::DumpWork() {
void DistMultiTrainer::InitDumpEnv() { void DistMultiTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>(); queue_ = paddle::framework::MakeChannel<std::string>();
int err_no = 0;
std::string path = string::format_string(
"%s/part-%03d", dump_fields_path_.c_str(), mpi_rank_);
fp_ = fs_open_write(path, &err_no, dump_converter_);
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetChannelWriter(queue_.get()); workers_[i]->SetChannelWriter(queue_.get());
} }
dump_thread_ = std::thread(&DistMultiTrainer::DumpWork, this); dump_thread_num_ = 1;
if (dump_file_num_ > mpi_size_) {
dump_thread_num_ = dump_file_num_ / mpi_size_;
if (dump_file_num_ % mpi_size_ > mpi_rank_) {
dump_thread_num_ += 1;
}
}
for (int i = 0; i < dump_thread_num_; i++) {
dump_thread_.push_back(
std::thread(std::bind(&DistMultiTrainer::DumpWork, this, i)));
}
} }
void DistMultiTrainer::FinalizeDumpEnv() { void DistMultiTrainer::FinalizeDumpEnv() {
queue_->Close(); queue_->Close();
dump_thread_.join(); for (auto &th : dump_thread_) {
th.join();
}
queue_.reset(); queue_.reset();
} }
......
...@@ -93,13 +93,13 @@ class DistMultiTrainer : public MultiTrainer { ...@@ -93,13 +93,13 @@ class DistMultiTrainer : public MultiTrainer {
void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor); void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor);
virtual void FinalizeDumpEnv(); virtual void FinalizeDumpEnv();
virtual void InitDumpEnv(); virtual void InitDumpEnv();
virtual void DumpWork(); virtual void DumpWork(int tid);
virtual Scope* GetWorkerScope(int thread_id) { return root_scope_; } virtual Scope* GetWorkerScope(int thread_id) { return root_scope_; }
protected: protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_; std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
std::thread dump_thread_; std::vector<std::thread> dump_thread_;
std::shared_ptr<FILE> fp_; int dump_thread_num_;
std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_; std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
bool need_dump_field_; bool need_dump_field_;
...@@ -107,6 +107,8 @@ class DistMultiTrainer : public MultiTrainer { ...@@ -107,6 +107,8 @@ class DistMultiTrainer : public MultiTrainer {
std::string dump_converter_; std::string dump_converter_;
std::vector<std::string> dump_fields_; std::vector<std::string> dump_fields_;
int mpi_rank_; int mpi_rank_;
int mpi_size_;
int dump_file_num_;
}; };
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
...@@ -40,6 +40,9 @@ message TrainerDesc { ...@@ -40,6 +40,9 @@ message TrainerDesc {
repeated string dump_fields = 13; repeated string dump_fields = 13;
optional string dump_converter = 14; optional string dump_converter = 14;
optional int32 mpi_size = 16 [ default = -1 ];
optional int32 dump_file_num = 17 [ default = 16 ];
// device worker parameters // device worker parameters
optional HogwildWorkerParameter hogwild_param = 101; optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103; optional DownpourWorkerParameter downpour_param = 103;
......
...@@ -588,6 +588,7 @@ class DownpourOptimizer(DistributedOptimizer): ...@@ -588,6 +588,7 @@ class DownpourOptimizer(DistributedOptimizer):
no_grad_set, no_grad_set,
self._strategy) self._strategy)
opt_info["mpi_rank"] = fleet._role_maker._get_rank() opt_info["mpi_rank"] = fleet._role_maker._get_rank()
opt_info["mpi_size"] = fleet._role_maker._get_size()
fleet._set_opt_info(opt_info) fleet._set_opt_info(opt_info)
programs = [loss.block.program for loss in losses] programs = [loss.block.program for loss in losses]
......
...@@ -251,6 +251,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -251,6 +251,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_slot"] = False opt_info["dump_slot"] = False
opt_info["dump_converter"] = "" opt_info["dump_converter"] = ""
opt_info["dump_fields"] = strategy.get("dump_fields", []) opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
if server._server.downpour_server_param.downpour_table_param[ if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class == "DownpourCtrAccessor": 0].accessor.accessor_class == "DownpourCtrAccessor":
......
...@@ -84,6 +84,9 @@ class TrainerDesc(object): ...@@ -84,6 +84,9 @@ class TrainerDesc(object):
def _set_mpi_rank(self, mpi_rank): def _set_mpi_rank(self, mpi_rank):
self.proto_desc.mpi_rank = mpi_rank self.proto_desc.mpi_rank = mpi_rank
def _set_mpi_size(self, mpi_size):
self.proto_desc.mpi_size = mpi_size
def _set_dump_fields(self, dump_fields): def _set_dump_fields(self, dump_fields):
for field in dump_fields: for field in dump_fields:
self.proto_desc.dump_fields.append(field) self.proto_desc.dump_fields.append(field)
...@@ -91,6 +94,9 @@ class TrainerDesc(object): ...@@ -91,6 +94,9 @@ class TrainerDesc(object):
def _set_dump_fields_path(self, path): def _set_dump_fields_path(self, path):
self.proto_desc.dump_fields_path = path self.proto_desc.dump_fields_path = path
def _set_dump_file_num(self, dump_file_num):
self.proto_desc.dump_file_num = dump_file_num
def _set_dump_converter(self, converter): def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter self.proto_desc.dump_converter = converter
......
...@@ -47,8 +47,10 @@ class TrainerFactory(object): ...@@ -47,8 +47,10 @@ class TrainerFactory(object):
trainer._set_scale_datanorm(opt_info["scale_datanorm"]) trainer._set_scale_datanorm(opt_info["scale_datanorm"])
trainer._set_dump_slot(opt_info["dump_slot"]) trainer._set_dump_slot(opt_info["dump_slot"])
trainer._set_mpi_rank(opt_info["mpi_rank"]) trainer._set_mpi_rank(opt_info["mpi_rank"])
trainer._set_mpi_size(opt_info["mpi_size"])
trainer._set_dump_fields(opt_info["dump_fields"]) trainer._set_dump_fields(opt_info["dump_fields"])
trainer._set_dump_fields_path(opt_info["dump_fields_path"]) trainer._set_dump_fields_path(opt_info["dump_fields_path"])
trainer._set_dump_file_num(opt_info["dump_file_num"])
trainer._set_dump_converter(opt_info["dump_converter"]) trainer._set_dump_converter(opt_info["dump_converter"])
trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"]) trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
trainer._set_device_worker(device_worker) trainer._set_device_worker(device_worker)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册