Unverified · Commit 54d2626a, authored by zmx, committed by GitHub

[heterps] Refactor heterogeneous worker (#37244)

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* refactor heter trainer. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop
Parent 0057c12d
@@ -192,13 +192,24 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
   virtual ~RequestSendAndRecvHandler() {}
-  // void SetMiniScopes(SharedMiniScope mini_scopes) {
-  //  mini_scopes_ = mini_scopes;
-  //  num_minibatch_ = mini_scopes_->size();
-  //}
+  void SetMiniScopes(SharedMiniScope mini_scopes) {
+    mini_scopes_ = mini_scopes;
+    num_minibatch_ = mini_scopes_->size();
+  }
   void SetMicroScopes(SharedMicroScope micro_scopes) {
     micro_scopes_ = micro_scopes;
-    num_microbatch_ = micro_scopes_->size();
+    for (auto& scope_pair : (*micro_scopes_)) {
+      // auto mini_idx = scope_pair.first;
+      auto& micro_scopes = scope_pair.second;
+      num_microbatch_ = micro_scopes->size();
+      break;
+    }
+  }
+
+  int GetThreadNum() {
+    std::unique_lock<std::mutex> lk(scope_mutex_);
+    return (*task_queue_).size();
   }
 
   void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
@@ -235,25 +246,43 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
     int minibatch_index = micro_id / 10;
     int microbatch_index = micro_id % 10;
 
-    // PADDLE_ENFORCE_EQ(
-    //    (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1,
-    //    platform::errors::InvalidArgument(
-    //        "minibatch index should in current trainer"));
-    PADDLE_ENFORCE_EQ(
-        (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
-        platform::errors::InvalidArgument(
-            "minibatch index should in current trainer"));
+    // check minibatch_index is in mini_scopes_
+    std::unique_lock<std::mutex> lk(scope_mutex_);
+    if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
+      lk.unlock();
+      // PADDLE_ENFORCE_EQ(
+      //    (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1,
+      //    platform::errors::InvalidArgument(
+      //        "minibatch index should in current trainer"));
+      PADDLE_ENFORCE_EQ(
+          (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
+          platform::errors::InvalidArgument(
+              "minibatch index should in current trainer"));
+    } else {
+      // create mini scope & micro scopes
+      auto* minibatch_scope = &(scope_->NewScope());
+      (*mini_scopes_)[minibatch_index] = minibatch_scope;
+      (*micro_scopes_)[minibatch_index].reset(
+          new std::vector<paddle::framework::Scope*>{});
+      for (int i = 0; i < num_microbatch_; i++) {
+        auto* micro_scope = &(minibatch_scope->NewScope());
+        (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
+      }
+      (*task_queue_)[minibatch_index].reset(
+          new ::paddle::framework::BlockingQueue<
+              std::pair<std::string, int>>());
+      lk.unlock();
+    }
 
     auto* micro_scope =
         (*((*micro_scopes_)[minibatch_index]))[microbatch_index];
     distributed::DeserializeFromMultiVarMsgAndIOBuf(
         *request, &request_io_buffer, *dev_ctx_, micro_scope);
     // blocking queue handles multi thread
     (*task_queue_)[minibatch_index]->Push(
         std::make_pair(message_name, microbatch_index));
 
     auto response_var_nums = request->recv_var_names_size();
     std::vector<std::string> response_var_names(response_var_nums),
         empty_var_names{};
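Editor's note: the handler now materializes the per-minibatch scope, its microbatch scopes, and the matching task queue lazily, the first time a request for that minibatch arrives, under scope_mutex_. Below is a minimal standalone C++ sketch of this lazy, mutex-guarded creation pattern; the MiniBatchState/LazyMiniBatchTable names are hypothetical stand-ins for Paddle's Scope and BlockingQueue, not the framework's real classes.

// Standalone sketch: lazily create per-minibatch state under a mutex.
// "MiniBatchState" is an illustrative stand-in, not a Paddle type.
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct MiniBatchState {
  std::vector<std::string> micro_scopes;               // one entry per microbatch
  std::deque<std::pair<std::string, int>> task_queue;  // pending (message, micro id)
};

class LazyMiniBatchTable {
 public:
  explicit LazyMiniBatchTable(int num_microbatch)
      : num_microbatch_(num_microbatch) {}

  // Returns the state for `minibatch_index`, creating it on first use.
  MiniBatchState* GetOrCreate(int minibatch_index) {
    std::unique_lock<std::mutex> lk(mutex_);
    auto it = states_.find(minibatch_index);
    if (it != states_.end()) return it->second.get();
    // First request for this minibatch: build its microbatch slots and queue.
    auto state = std::make_unique<MiniBatchState>();
    for (int i = 0; i < num_microbatch_; ++i) {
      state->micro_scopes.push_back("micro_scope_" + std::to_string(i));
    }
    auto* raw = state.get();
    states_[minibatch_index] = std::move(state);
    return raw;
  }

  // Mirrors GetThreadNum(): how many minibatches have been materialized.
  size_t Size() {
    std::unique_lock<std::mutex> lk(mutex_);
    return states_.size();
  }

 private:
  int num_microbatch_;
  std::mutex mutex_;
  std::unordered_map<int, std::unique_ptr<MiniBatchState>> states_;
};

int main() {
  LazyMiniBatchTable table(/*num_microbatch=*/4);
  // Two requests with micro_id 0 and 13 map to minibatches 0 and 1.
  table.GetOrCreate(0 / 10)->task_queue.push_back({"forward", 0 % 10});
  table.GetOrCreate(13 / 10)->task_queue.push_back({"forward", 13 % 10});
  std::cout << "materialized minibatches: " << table.Size() << "\n";  // 2
  return 0;
}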
@@ -269,11 +298,12 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
  private:
   // share with HeterPipelineTrainer
-  // SharedMiniScope mini_scopes_{nullptr};
+  SharedMiniScope mini_scopes_{nullptr};
   SharedMicroScope micro_scopes_{nullptr};
   int num_microbatch_;
   int num_minibatch_;
+  std::mutex scope_mutex_;
   bool is_first_stage_ = false;
   bool is_last_stage_ = false;
@@ -321,14 +351,16 @@ class HeterServer {
     request_handler_ = request_handler;
   }
 
-  // void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
-  //  request_handler_->SetMiniScopes(mini_scopes);
-  //}
+  void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
+    request_handler_->SetMiniScopes(mini_scopes);
+  }
 
   void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
     request_handler_->SetMicroScopes(micro_scopes);
   }
 
+  int GetThreadNum() { return request_handler_->GetThreadNum(); }
+
   void SetTaskQueue(SharedTaskQueue task_queue) {
     request_handler_->SetTaskQueue(task_queue);
   }
......
@@ -631,10 +631,17 @@ class HeterSectionWorker : public DeviceWorker {
   std::shared_ptr<std::vector<Scope*>> GetMicrobatchScopes() {
     return microbatch_scopes_;
   }
+  void SetMicrobatchScopes(
+      std::shared_ptr<std::vector<Scope*>> microbatch_scopes) {
+    microbatch_scopes_ = microbatch_scopes;
+  }
   using SHARED_THREAD_QUEUE = std::shared_ptr<
       ::paddle::framework::BlockingQueue<std::pair<std::string, int>>>;
   SHARED_THREAD_QUEUE GetThreadQueue() { return thread_queue_; }
+  void SetThreadQueue(SHARED_THREAD_QUEUE thread_queue) {
+    thread_queue_ = thread_queue;
+  }
   void CopyParameters(int microbatch_id, const ProgramDesc& program,
                       const platform::Place& place);
   void SetMinibatchScope(Scope* scope) { minibatch_scope_ = scope; }
......
@@ -77,34 +77,51 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
     trainers_.push_back(trainer_num);
   }
   int cpu_trainer_num = trainers_[0];
-  int cur_stage_trainer_num = trainers_[pipeline_stage_];
-  int global_thread_num = cpu_trainer_num * thread_num_;
-  int previous_trainers = 0;
-  for (int i = 0; i < pipeline_stage_; i++) previous_trainers += trainers_[i];
-  int stage_trainer_id =
-      trainer_id_ - previous_trainers;  // trainer id in current stage
-  int cnt = -1;
-  for (int i = stage_trainer_id; i < global_thread_num;
-       i += cur_stage_trainer_num) {
-    cnt++;
-    workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
+  // int cur_stage_trainer_num = trainers_[pipeline_stage_];
+  // int global_thread_num = cpu_trainer_num * thread_num_;
+  // int previous_trainers = 0;
+  // for (int i = 0; i < pipeline_stage_; i++) previous_trainers +=
+  // trainers_[i];
+  // int stage_trainer_id =
+  //    trainer_id_ - previous_trainers;  // trainer id in current stage
+
+  if (pipeline_stage_ == 0) {  // for cpu trainer
+    int cnt = -1;
+    int real_thread_id = trainer_id_;
+    for (int i = 0; i < thread_num_; i++) {
+      cnt++;
+      workers_[real_thread_id] = DeviceWorkerFactory::CreateDeviceWorker(
+          trainer_desc.device_worker_name());
+      auto this_worker =
+          std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
+              workers_[real_thread_id]);
+      this_worker->SetDebug(debug_);
+      this_worker->SetNeedDumpField(need_dump_field_);
+      this_worker->SetNeedDumpParam(need_dump_param_);
+      this_worker->SetDumpFieldVector(dump_fields_);
+      this_worker->SetDumpParamVector(dump_param_);
+      this_worker->InitRandomDumpConfig(trainer_desc);
+      this_worker->SetDeviceIndex(real_thread_id);
+      real_thread_id += cpu_trainer_num;
+      // if (pipeline_stage_ == 0) {
+      this_worker->SetDataFeed(readers[cnt]);
+      //}
+      this_worker->SetMicrobatchNum(num_microbatches_);
+      this_worker->SetPipelineStageNum(num_pipeline_stages_);
+      this_worker->SetPipelineStage(pipeline_stage_);
+    }
+  } else {  // for heter_trainer
+    // heter trainer with thread_id == -1 is not for
+    // real training
+    workers_[-1] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
     auto this_worker =
         std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
-            workers_[i]);
-    this_worker->SetDebug(debug_);
-    this_worker->SetNeedDumpField(need_dump_field_);
-    this_worker->SetNeedDumpParam(need_dump_param_);
-    this_worker->SetDumpFieldVector(dump_fields_);
-    this_worker->SetDumpParamVector(dump_param_);
-    this_worker->InitRandomDumpConfig(trainer_desc);
-    this_worker->SetDeviceIndex(i);
-    if (pipeline_stage_ == 0) {
-      this_worker->SetDataFeed(readers[cnt]);
-    }
+            workers_[-1]);
     this_worker->SetMicrobatchNum(num_microbatches_);
     this_worker->SetPipelineStageNum(num_pipeline_stages_);
     this_worker->SetPipelineStage(pipeline_stage_);
+    this_worker->SetDeviceIndex(-1);
   }
 }
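Editor's note: after this hunk, only the CPU trainer (pipeline stage 0) creates one worker per data-feeding thread, and its keys in `workers_` are trainer_id_, trainer_id_ + cpu_trainer_num, trainer_id_ + 2 * cpu_trainer_num, and so on, while a heter trainer keeps a single placeholder worker under key -1. A small standalone sketch of that key layout (function and variable names here are illustrative, not Paddle's):

// Sketch of the worker-key layout: CPU trainer `trainer_id` owns keys
// trainer_id, trainer_id + cpu_trainer_num, ... (one per local thread);
// a heter trainer only registers the placeholder key -1.
#include <iostream>
#include <vector>

std::vector<int> WorkerKeys(int pipeline_stage, int trainer_id,
                            int cpu_trainer_num, int thread_num) {
  std::vector<int> keys;
  if (pipeline_stage == 0) {  // CPU trainer: real training threads
    int key = trainer_id;
    for (int i = 0; i < thread_num; ++i) {
      keys.push_back(key);
      key += cpu_trainer_num;  // interleave keys across CPU trainers
    }
  } else {  // heter trainer: placeholder only
    keys.push_back(-1);
  }
  return keys;
}

int main() {
  // 2 CPU trainers, 3 threads each: trainer 0 -> {0, 2, 4}, trainer 1 -> {1, 3, 5}.
  for (int key : WorkerKeys(/*pipeline_stage=*/0, /*trainer_id=*/1,
                            /*cpu_trainer_num=*/2, /*thread_num=*/3)) {
    std::cout << key << " ";
  }
  std::cout << "\n";
  return 0;
}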
@@ -177,7 +194,7 @@ void HeterPipelineTrainer::Run() {
   }
   auto heter_server = paddle::distributed::HeterServer::GetInstance();
   heter_server->WaitServerReady();
-  // heter_server->SetMiniBatchScopes(mini_scopes_);
+  heter_server->SetMiniBatchScopes(mini_scopes_);
   heter_server->SetMicroBatchScopes(micro_scopes_);
   heter_server->SetTaskQueue(task_queue_);
   // main training logic
@@ -193,6 +210,7 @@ void HeterPipelineTrainer::Run() {
       }
     }
   } else {  // for heter worker
+    // start thread_worker with thread_id = -1
     for (auto& worker_pair : workers_) {
       auto device_worker = worker_pair.second;
       if (!debug_) {
@@ -203,6 +221,60 @@ void HeterPipelineTrainer::Run() {
                                        device_worker.get()));
       }
     }
+    bool epoch_finish = false;
+    auto heter_server = paddle::distributed::HeterServer::GetInstance();
+    while (!epoch_finish) {
+      if (heter_server->IsStop()) {
+        epoch_finish = true;
+        continue;
+      }
+      // create new thread_worker
+      // size_t thread_num = (*micro_scopes_).size();
+      // size_t thread_num = (*task_queue_).size();
+      size_t thread_num = heter_server->GetThreadNum();
+      while (thread_num > threads_.size()) {
+        for (auto& worker_pair : (*micro_scopes_)) {
+          auto worker_index = worker_pair.first;
+          if (workers_.find(worker_index) != workers_.end()) continue;
+          workers_[worker_index] = DeviceWorkerFactory::CreateDeviceWorker(
+              trainer_desc_.device_worker_name());
+          auto this_worker =
+              std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
                  workers_[worker_index]);
+          this_worker->SetDebug(debug_);
+          this_worker->SetNeedDumpField(need_dump_field_);
+          this_worker->SetNeedDumpParam(need_dump_param_);
+          this_worker->SetDumpFieldVector(dump_fields_);
+          this_worker->SetDumpParamVector(dump_param_);
+          this_worker->InitRandomDumpConfig(trainer_desc_);
+          this_worker->SetDeviceIndex(worker_index);
+          this_worker->SetMicrobatchNum(num_microbatches_);
+          this_worker->SetPipelineStageNum(num_pipeline_stages_);
+          this_worker->SetPipelineStage(pipeline_stage_);
+          this_worker->SetPlace(place_);
+          this_worker->Initialize(trainer_desc_);
+          this_worker->SetRootScope(root_scope_);
+          // generate mini_batch scope for every worker
+          // auto* minibatch_scope = &root_scope_->NewScope();
+          auto* minibatch_scope = (*mini_scopes_)[worker_index];
+          // (*mini_scopes_)[worker_index] = minibatch_scope;
+          this_worker->SetMinibatchScope(minibatch_scope);
+          // after set micro num & mini batch scope
+          this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]);
+          this_worker->CreateMicrobatchScopes();
+          // this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]);
+          this_worker->SetThreadQueue((*task_queue_)[worker_index]);
+          if (!debug_) {
+            threads_.push_back(
+                std::thread(&DeviceWorker::TrainFiles, this_worker.get()));
+          } else {
+            threads_.push_back(std::thread(
+                &DeviceWorker::TrainFilesWithProfiler, this_worker.get()));
+          }
+        }
+      }
+    }
   }
   for (auto& th : threads_) {
     th.join();
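Editor's note: with this change, the heter side of Run() polls the server for how many task queues have been materialized (GetThreadNum) and spawns one worker thread for every minibatch index it has not yet seen, reusing the shared mini/micro scopes and the queue registered by the RPC handler. A reduced standalone sketch of that poll-and-spawn loop follows; FakeServer and the lambdas are hypothetical stand-ins, not Paddle types.

// Standalone sketch: poll how many minibatch slots exist and start one
// worker thread per slot that does not have a thread yet.
#include <atomic>
#include <chrono>
#include <iostream>
#include <map>
#include <mutex>
#include <thread>
#include <vector>

struct FakeServer {
  std::atomic<bool> stop{false};
  std::mutex mu;
  std::map<int, int> task_queues;  // minibatch index -> (ignored) payload
  size_t ThreadNum() {
    std::lock_guard<std::mutex> lk(mu);
    return task_queues.size();
  }
};

int main() {
  FakeServer server;
  std::map<int, bool> known_workers;  // minibatch index -> already spawned?
  std::vector<std::thread> threads;

  // Simulate remote trainers registering minibatches 0 and 1, then stopping.
  std::thread producer([&] {
    for (int idx = 0; idx < 2; ++idx) {
      std::this_thread::sleep_for(std::chrono::milliseconds(20));
      std::lock_guard<std::mutex> lk(server.mu);
      server.task_queues[idx] = 0;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    server.stop = true;
  });

  while (!server.stop) {
    size_t thread_num = server.ThreadNum();
    while (thread_num > threads.size()) {
      std::lock_guard<std::mutex> lk(server.mu);
      for (auto& kv : server.task_queues) {
        int worker_index = kv.first;
        if (known_workers.count(worker_index)) continue;  // already running
        known_workers[worker_index] = true;
        threads.emplace_back([worker_index] {
          std::cout << "worker " << worker_index << " training\n";
        });
      }
    }
  }
  for (auto& th : threads) th.join();
  producer.join();
  return 0;
}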
@@ -228,7 +300,11 @@ void HeterPipelineTrainer::Finalize() {
 }
 
 Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) {
-  return workers_[thread_id]->GetThreadScope();
+  if (workers_.find(thread_id) != workers_.end()) {
+    return workers_[thread_id]->GetThreadScope();
+  } else {
+    return nullptr;
+  }
 }
 
 }  // end namespace framework
......
@@ -218,15 +218,21 @@ void HeterSectionWorker::CreateMicrobatchScopes() {
       minibatch_scope_,
       platform::errors::InvalidArgument(
           "minibatch_scope_ can not be nullptr when create MicroBatch Scopes"));
-  microbatch_scopes_.reset(new std::vector<paddle::framework::Scope*>{});
-  (*microbatch_scopes_).resize(num_microbatches_);
-  VLOG(3) << "Create microbatch scopes...";
-  std::shared_ptr<framework::ProgramDesc> program;
-  program.reset(new ProgramDesc(
-      trainer_desc_.heter_section_param().section_config().program_desc()));
-  for (int j = 0; j < num_microbatches_; ++j) {
-    (*microbatch_scopes_)[j] = &minibatch_scope_->NewScope();
-    CopyParameters(j, *program, place_);
+  if (microbatch_scopes_.get() == nullptr) {
+    microbatch_scopes_.reset(new std::vector<paddle::framework::Scope*>{});
+    (*microbatch_scopes_).resize(num_microbatches_);
+    VLOG(3) << "Create microbatch scopes...";
+    for (int j = 0; j < num_microbatches_; ++j) {
+      (*microbatch_scopes_)[j] = &minibatch_scope_->NewScope();
+    }
+  }
+  if (thread_id_ >= 0) {
+    std::shared_ptr<framework::ProgramDesc> program;
+    program.reset(new ProgramDesc(
+        trainer_desc_.heter_section_param().section_config().program_desc()));
+    for (int j = 0; j < num_microbatches_; ++j) {
+      CopyParameters(j, *program, place_);
+    }
   }
 }
@@ -258,6 +264,8 @@ void HeterSectionWorker::CopyParameters(int microbatch_id,
       VLOG(5) << "Create persistable var: " << var->Name()
               << ", which pointer is " << ptr;
     } else if (!var->Persistable()) {
+      if ((*microbatch_scopes_)[microbatch_id]->FindVar(var->Name()) != nullptr)
+        continue;
       auto* ptr = (*microbatch_scopes_)[microbatch_id]->Var(var->Name());
       VLOG(5) << "Create variable " << var->Name() << " for microbatch "
               << microbatch_id << ", which pointer is " << ptr;
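Editor's note: CreateMicrobatchScopes and CopyParameters are now re-entrant: the scope vector is only allocated when it has not been injected via SetMicrobatchScopes, parameters are only copied for real workers (thread_id_ >= 0), and a variable that already exists in a microbatch scope is skipped rather than recreated. A compact standalone sketch of that "create only if missing" idiom, using a toy Scope type that is not paddle::framework::Scope:

// Sketch of idempotent scope/variable creation: allocate the container only
// when it is absent, and skip variables that already exist in a scope.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Scope {
  std::unordered_map<std::string, int> vars;
  bool Has(const std::string& name) const { return vars.count(name) > 0; }
  void Create(const std::string& name) { vars.emplace(name, 0); }
};

void EnsureMicrobatchScopes(std::shared_ptr<std::vector<Scope>>* scopes,
                            int num_microbatches) {
  if (*scopes == nullptr) {  // only allocate when not injected from outside
    scopes->reset(new std::vector<Scope>(num_microbatches));
  }
}

void CopyParametersOnce(Scope* scope, const std::vector<std::string>& vars) {
  for (const auto& name : vars) {
    if (scope->Has(name)) continue;  // already created: leave it untouched
    scope->Create(name);
  }
}

int main() {
  std::shared_ptr<std::vector<Scope>> scopes;
  EnsureMicrobatchScopes(&scopes, 2);
  EnsureMicrobatchScopes(&scopes, 2);  // second call is a no-op
  CopyParametersOnce(&(*scopes)[0], {"w", "b"});
  CopyParametersOnce(&(*scopes)[0], {"w", "b"});  // skips existing vars
  std::cout << "vars in scope 0: " << (*scopes)[0].vars.size() << "\n";  // 2
  return 0;
}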
@@ -359,22 +367,25 @@ void HeterSectionWorker::BatchPostProcess() {
 }
 
 void HeterSectionWorker::TrainFiles() {
-  total_ins_num_ = 0;
-  batch_num_ = 0;
-  platform::SetNumThreads(1);
-  timeline_.Start();
-  VLOG(3) << "begin section_worker TrainFiles";
-  epoch_finish_ = false;
-  if (pipeline_stage_ == 0) {
-    device_reader_->Start();
-  }
-  while (!epoch_finish_) {
-    Run();
-    dev_ctx_->Wait();
+  if (thread_id_ >= 0) {
+    total_ins_num_ = 0;
+    batch_num_ = 0;
+    platform::SetNumThreads(1);
+    timeline_.Start();
+    VLOG(3) << "begin section_worker TrainFiles";
+    epoch_finish_ = false;
+    if (pipeline_stage_ == 0) {
+      device_reader_->Start();
+    }
+    while (!epoch_finish_) {
+      Run();
+      dev_ctx_->Wait();
+    }
+    timeline_.Pause();
+    VLOG(3) << "worker " << thread_id_ << " train cost "
+            << timeline_.ElapsedSec()
+            << " seconds, ins_num: " << total_ins_num_;
   }
-  timeline_.Pause();
-  VLOG(3) << "worker " << thread_id_ << " train cost " << timeline_.ElapsedSec()
-          << " seconds, ins_num: " << total_ins_num_;
 }
 
 void HeterSectionWorker::PrintFetchVars() {
@@ -406,22 +417,24 @@ void HeterSectionWorker::PrintFetchVars() {
 }
 
 void HeterSectionWorker::TrainFilesWithProfiler() {
-  VLOG(3) << "begin section_worker TrainFilesWithProfiler";
-  batch_num_ = 0;
-  epoch_finish_ = false;
-  total_ins_num_ = 0;
-  op_name_.clear();
-  op_total_time_.clear();
-  if (pipeline_stage_ == 0) {
-    device_reader_->Start();
-  }
-  while (!epoch_finish_) {
-    Run();
-    dev_ctx_->Wait();
-    if (epoch_finish_) {
-      // dump param for debug
-      if (need_dump_field_ || need_dump_param_) {
-        writer_.Flush();
+  if (thread_id_ >= 0) {
+    VLOG(3) << "begin section_worker TrainFilesWithProfiler";
+    batch_num_ = 0;
+    epoch_finish_ = false;
+    total_ins_num_ = 0;
+    op_name_.clear();
+    op_total_time_.clear();
+    if (pipeline_stage_ == 0) {
+      device_reader_->Start();
+    }
+    while (!epoch_finish_) {
+      Run();
+      dev_ctx_->Wait();
+      if (epoch_finish_) {
+        // dump param for debug
+        if (need_dump_field_ || need_dump_param_) {
+          writer_.Flush();
+        }
       }
     }
   }
......
@@ -161,13 +161,20 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
   b_rpc_service->WaitServerReady();
   using MicroScope =
       std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
+  using MiniScope = std::unordered_map<int, framework::Scope*>;
+  std::shared_ptr<MiniScope> mini_scopes(new MiniScope{});
   std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
   std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
       new std::vector<framework::Scope*>{});
-  (*micro_scope).push_back(new framework::Scope());
-  (*micro_scope).push_back(new framework::Scope());
+  auto* mini_scope = new framework::Scope();
+  (*mini_scopes)[0] = mini_scope;
+  auto* micro_scope_0 = &(mini_scope->NewScope());
+  auto* micro_scope_1 = &(mini_scope->NewScope());
+  (*micro_scope).push_back(micro_scope_0);
+  (*micro_scope).push_back(micro_scope_1);
   (*micro_scopes)[0] = micro_scope;
   b_rpc_service->SetMicroBatchScopes(micro_scopes);
+  b_rpc_service->SetMiniBatchScopes(mini_scopes);
   using TaskQueue =
       std::unordered_map<int,
......
@@ -194,13 +194,20 @@ TEST(SENDANDRECV, CPU) {
   b_rpc_service->WaitServerReady();
   using MicroScope =
       std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
+  using MiniScope = std::unordered_map<int, framework::Scope*>;
+  std::shared_ptr<MiniScope> mini_scopes(new MiniScope{});
   std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
   std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
       new std::vector<framework::Scope*>{});
-  (*micro_scope).push_back(new framework::Scope());
-  (*micro_scope).push_back(new framework::Scope());
+  auto* mini_scope = new framework::Scope();
+  (*mini_scopes)[0] = mini_scope;
+  auto* micro_scope_0 = &(mini_scope->NewScope());
+  auto* micro_scope_1 = &(mini_scope->NewScope());
+  (*micro_scope).push_back(micro_scope_0);
+  (*micro_scope).push_back(micro_scope_1);
   (*micro_scopes)[0] = micro_scope;
   b_rpc_service->SetMicroBatchScopes(micro_scopes);
+  b_rpc_service->SetMiniBatchScopes(mini_scopes);
   using TaskQueue =
       std::unordered_map<int,
......
@@ -167,12 +167,18 @@ TEST(SENDANDRECV, CPU) {
   b_rpc_service->WaitServerReady();
   using MicroScope =
      std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
+  using MiniScope = std::unordered_map<int, framework::Scope*>;
+  std::shared_ptr<MiniScope> mini_scopes(new MiniScope{});
   std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
+  auto* mini_scope = new framework::Scope();
+  (*mini_scopes)[0] = mini_scope;
   std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
       new std::vector<framework::Scope*>{});
-  (*micro_scope).push_back(new framework::Scope());
+  auto* micro_scope_0 = &(mini_scope->NewScope());
+  (*micro_scope).push_back(micro_scope_0);
   (*micro_scopes)[0] = micro_scope;
   b_rpc_service->SetMicroBatchScopes(micro_scopes);
+  b_rpc_service->SetMiniBatchScopes(mini_scopes);
   using TaskQueue =
       std::unordered_map<int,
......
@@ -187,12 +187,18 @@ TEST(SENDANDRECV, GPU) {
   b_rpc_service2->WaitServerReady();
   using MicroScope =
      std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
+  using MiniScope = std::unordered_map<int, framework::Scope*>;
+  std::shared_ptr<MiniScope> mini_scopes(new MiniScope{});
   std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
+  auto* mini_scope = new framework::Scope();
+  (*mini_scopes)[0] = mini_scope;
   std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
       new std::vector<framework::Scope*>{});
-  (*micro_scope).push_back(new framework::Scope());
+  auto* micro_scope_0 = &(mini_scope->NewScope());
+  (*micro_scope).push_back(micro_scope_0);
   (*micro_scopes)[0] = micro_scope;
   b_rpc_service2->SetMicroBatchScopes(micro_scopes);
+  b_rpc_service2->SetMiniBatchScopes(mini_scopes);
   using TaskQueue =
       std::unordered_map<int,
......
@@ -898,9 +898,10 @@ class TheOnePSRuntime(RuntimeBase):
                           print_period=100,
                           fetch_handler=None):
         executor = self._get_executor()
+        # dataset is not needed for heter worker
        executor.train_from_dataset(
             program=fluid.default_main_program(),
-            dataset=dataset,
+            dataset=None,
             debug=debug,
             fetch_list=fetch_list,
             fetch_info=fetch_info,
......
@@ -1672,6 +1672,33 @@ class Executor(object):
             dataset.set_thread(1)
             dataset.set_filelist(['None'])
             dataset.set_use_var(data_vars)
+        elif program._heter_pipeline_opt is not None:
+            stage_id = program._heter_pipeline_opt["pipeline_stage"]
+            if stage_id != 0:
+                import paddle
+                if dataset is not None:
+                    raise RuntimeError(
+                        "dataset should be None for heter pipeline mode")
+                # The following fake dataset is created to call
+                # the _prepare_trainer api, and it is meaningless.
+                data_vars = []
+                for var in program.global_block().vars.values():
+                    if var.is_data:
+                        data_vars.append(var)
+                if core.is_compiled_with_npu():
+                    dataset = paddle.fluid.DatasetFactory().create_dataset(
+                        'InMemoryDataset')
+                else:
+                    dataset = paddle.fluid.DatasetFactory().create_dataset(
+                        'FileInstantDataset')
+                dataset.set_batch_size(1)
+                dataset.set_thread(1)
+                dataset.set_filelist(['None'])
+                dataset.set_use_var(data_vars)
+            else:
+                if dataset is None:
+                    raise RuntimeError(
+                        "dataset is need and should be initialized")
         else:
             if dataset is None:
                 raise RuntimeError("dataset is need and should be initialized")
......
@@ -156,12 +156,7 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
         thread_num = int(os.getenv("CPU_NUM", 2))
         batch_size = 128
-        block_size = len(train_file_list) // fleet.worker_num()
-        worker_id = fleet.worker_index()
-        filelist = train_file_list[worker_id * block_size:(worker_id + 1) *
-                                   block_size]
-        #filelist = fleet.util.get_file_shard(train_file_list)
+        filelist = fleet.util.get_file_shard(train_file_list)
         print("filelist: {}".format(filelist))
 
         # config dataset
@@ -195,31 +190,12 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
             "section_program"]
         print(real_program)
 
-        train_file_list = ctr_dataset_reader.prepare_fake_data()
-        #exe = fluid.Executor(fluid.CPUPlace())
-        #exe.run(fluid.default_startup_program())
-        #fleet.init_worker()
-
         thread_num = int(os.getenv("CPU_NUM", 2))
         batch_size = 128
-        #filelist = fleet.util.get_file_shard(train_file_list)
-        block_size = len(train_file_list) // fleet.worker_num()
-        filelist = train_file_list[0:block_size]
-        print("filelist: {}".format(filelist))
-
-        # config dataset
-        dataset = fluid.DatasetFactory().create_dataset()
-        dataset.set_batch_size(batch_size)
-        dataset.set_use_var(self.feeds)
-        pipe_command = 'python3 ctr_dataset_reader.py'
-        dataset.set_pipe_command(pipe_command)
-        dataset.set_filelist(filelist)
-        dataset.set_thread(thread_num)
-
-        fleet.run_heter_worker(dataset)
+
+        pass_start = time.time()
+        fleet.run_heter_worker(dataset=None)
+        pass_time = time.time() - pass_start
 
         print("do_dataset_heter_training done. using time {}".format(pass_time))

         #for epoch_id in range(1):
......