Unverified commit 54d2626a authored by zmx, committed by GitHub

[heterps] Refactor heterogeneous worker (#37244)

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* refactor heter trainer. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop
Parent 0057c12d
@@ -192,13 +192,24 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
virtual ~RequestSendAndRecvHandler() {}
// void SetMiniScopes(SharedMiniScope mini_scopes) {
// mini_scopes_ = mini_scopes;
// num_minibatch_ = mini_scopes_->size();
//}
void SetMiniScopes(SharedMiniScope mini_scopes) {
mini_scopes_ = mini_scopes;
num_minibatch_ = mini_scopes_->size();
}
void SetMicroScopes(SharedMicroScope micro_scopes) {
micro_scopes_ = micro_scopes;
num_microbatch_ = micro_scopes_->size();
for (auto& scope_pair : (*micro_scopes_)) {
// auto mini_idx = scope_pair.first;
auto& micro_scopes = scope_pair.second;
num_microbatch_ = micro_scopes->size();
break;
}
}
int GetThreadNum() {
std::unique_lock<std::mutex> lk(scope_mutex_);
return (*task_queue_).size();
}
void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
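
For orientation, the scope containers wired through these setters have the same shape as the MiniScope / MicroScope aliases declared in the unit tests further down in this diff. A minimal sketch, assuming SharedMiniScope and SharedMicroScope are shared_ptr wrappers around those maps (Scope is a stand-in), showing how the new SetMicroScopes derives the microbatch count from the first minibatch entry rather than from the map size:

```cpp
#include <memory>
#include <unordered_map>
#include <vector>

struct Scope {};  // stand-in for paddle::framework::Scope

// Shapes follow the MiniScope / MicroScope typedefs used in the tests below.
using MiniScope = std::unordered_map<int, Scope*>;  // minibatch id -> scope
using MicroScope =
    std::unordered_map<int, std::shared_ptr<std::vector<Scope*>>>;
using SharedMiniScope = std::shared_ptr<MiniScope>;
using SharedMicroScope = std::shared_ptr<MicroScope>;

// Mirrors the new SetMicroScopes(): the number of microbatches is the size
// of any one minibatch's scope vector, not the number of minibatches.
int CountMicrobatches(const SharedMicroScope& micro_scopes) {
  for (auto& scope_pair : *micro_scopes) {
    return static_cast<int>(scope_pair.second->size());
  }
  return 0;
}
```
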
@@ -235,25 +246,43 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
int minibatch_index = micro_id / 10;
int microbatch_index = micro_id % 10;
// PADDLE_ENFORCE_EQ(
// (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1,
// platform::errors::InvalidArgument(
// "minibatch index should in current trainer"));
PADDLE_ENFORCE_EQ(
(*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
platform::errors::InvalidArgument(
"minibatch index should in current trainer"));
// check minibatch_index is in mini_scopes_
std::unique_lock<std::mutex> lk(scope_mutex_);
if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
lk.unlock();
// PADDLE_ENFORCE_EQ(
// (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1,
// platform::errors::InvalidArgument(
// "minibatch index should in current trainer"));
PADDLE_ENFORCE_EQ(
(*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
platform::errors::InvalidArgument(
"minibatch index should in current trainer"));
} else {
// create mini scope & micro scopes
auto* minibatch_scope = &(scope_->NewScope());
(*mini_scopes_)[minibatch_index] = minibatch_scope;
(*micro_scopes_)[minibatch_index].reset(
new std::vector<paddle::framework::Scope*>{});
for (int i = 0; i < num_microbatch_; i++) {
auto* micro_scope = &(minibatch_scope->NewScope());
(*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
}
(*task_queue_)[minibatch_index].reset(
new ::paddle::framework::BlockingQueue<
std::pair<std::string, int>>());
lk.unlock();
}
auto* micro_scope =
(*((*micro_scopes_)[minibatch_index]))[microbatch_index];
distributed::DeserializeFromMultiVarMsgAndIOBuf(
*request, &request_io_buffer, *dev_ctx_, micro_scope);
// blocking queue handles multi thread
(*task_queue_)[minibatch_index]->Push(
std::make_pair(message_name, microbatch_index));
auto response_var_nums = request->recv_var_names_size();
std::vector<std::string> response_var_names(response_var_nums),
empty_var_names{};
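
Two details of this handler change are worth spelling out: micro_id packs both indices (minibatch_index * 10 + microbatch_index), and the mini scope, its microbatch scopes, and the per-minibatch task queue are now created lazily under scope_mutex_ the first time a minibatch is seen. A hedged sketch of that decode-and-create pattern with simplified container types (not the real framework classes):

```cpp
#include <iostream>
#include <mutex>
#include <unordered_map>
#include <vector>

struct Scope {};  // stand-in for paddle::framework::Scope

int main() {
  // micro_id = minibatch_index * 10 + microbatch_index, as decoded above.
  int micro_id = 23;
  int minibatch_index = micro_id / 10;   // 2
  int microbatch_index = micro_id % 10;  // 3

  std::unordered_map<int, Scope*> mini_scopes;
  std::unordered_map<int, std::vector<Scope*>> micro_scopes;
  std::mutex scope_mutex;
  const int num_microbatch = 4;

  // Lazy creation guarded by the mutex: only the first request touching a
  // new minibatch builds its scopes (and, in the real handler, its blocking
  // task queue); later requests for the same minibatch reuse them.
  std::unique_lock<std::mutex> lk(scope_mutex);
  if (mini_scopes.find(minibatch_index) == mini_scopes.end()) {
    mini_scopes[minibatch_index] = new Scope();
    for (int i = 0; i < num_microbatch; ++i) {
      micro_scopes[minibatch_index].push_back(new Scope());
    }
  }
  lk.unlock();

  std::cout << "minibatch " << minibatch_index << ", microbatch "
            << microbatch_index << "\n";
  return 0;
}
```
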
@@ -269,11 +298,12 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
private:
// share with HeterPipelineTrainer
// SharedMiniScope mini_scopes_{nullptr};
SharedMiniScope mini_scopes_{nullptr};
SharedMicroScope micro_scopes_{nullptr};
int num_microbatch_;
int num_minibatch_;
std::mutex scope_mutex_;
bool is_first_stage_ = false;
bool is_last_stage_ = false;
@@ -321,14 +351,16 @@ class HeterServer {
request_handler_ = request_handler;
}
// void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
// request_handler_->SetMiniScopes(mini_scopes);
//}
void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
request_handler_->SetMiniScopes(mini_scopes);
}
void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
request_handler_->SetMicroScopes(micro_scopes);
}
int GetThreadNum() { return request_handler_->GetThreadNum(); }
void SetTaskQueue(SharedTaskQueue task_queue) {
request_handler_->SetTaskQueue(task_queue);
}
......
@@ -631,10 +631,17 @@ class HeterSectionWorker : public DeviceWorker {
std::shared_ptr<std::vector<Scope*>> GetMicrobatchScopes() {
return microbatch_scopes_;
}
void SetMicrobatchScopes(
std::shared_ptr<std::vector<Scope*>> microbatch_scopes) {
microbatch_scopes_ = microbatch_scopes;
}
using SHARED_THREAD_QUEUE = std::shared_ptr<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>;
SHARED_THREAD_QUEUE GetThreadQueue() { return thread_queue_; }
void SetThreadQueue(SHARED_THREAD_QUEUE thread_queue) {
thread_queue_ = thread_queue;
}
void CopyParameters(int microbatch_id, const ProgramDesc& program,
const platform::Place& place);
void SetMinibatchScope(Scope* scope) { minibatch_scope_ = scope; }
......
@@ -77,34 +77,51 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
trainers_.push_back(trainer_num);
}
int cpu_trainer_num = trainers_[0];
int cur_stage_trainer_num = trainers_[pipeline_stage_];
int global_thread_num = cpu_trainer_num * thread_num_;
int previous_trainers = 0;
for (int i = 0; i < pipeline_stage_; i++) previous_trainers += trainers_[i];
int stage_trainer_id =
trainer_id_ - previous_trainers; // trainer id in current stage
int cnt = -1;
for (int i = stage_trainer_id; i < global_thread_num;
i += cur_stage_trainer_num) {
cnt++;
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
// int cur_stage_trainer_num = trainers_[pipeline_stage_];
// int global_thread_num = cpu_trainer_num * thread_num_;
// int previous_trainers = 0;
// for (int i = 0; i < pipeline_stage_; i++) previous_trainers +=
// trainers_[i];
// int stage_trainer_id =
// trainer_id_ - previous_trainers; // trainer id in current stage
if (pipeline_stage_ == 0) { // for cpu trainer
int cnt = -1;
int real_thread_id = trainer_id_;
for (int i = 0; i < thread_num_; i++) {
cnt++;
workers_[real_thread_id] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
workers_[real_thread_id]);
this_worker->SetDebug(debug_);
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
this_worker->InitRandomDumpConfig(trainer_desc);
this_worker->SetDeviceIndex(real_thread_id);
real_thread_id += cpu_trainer_num;
// if (pipeline_stage_ == 0) {
this_worker->SetDataFeed(readers[cnt]);
//}
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
}
} else { // for heter_trainer
// heter trainer with thread_id == -1 is not for
// real training
workers_[-1] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
workers_[i]);
this_worker->SetDebug(debug_);
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
this_worker->InitRandomDumpConfig(trainer_desc);
this_worker->SetDeviceIndex(i);
if (pipeline_stage_ == 0) {
this_worker->SetDataFeed(readers[cnt]);
}
workers_[-1]);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetDeviceIndex(-1);
}
}
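
The Initialize() split means a CPU trainer (pipeline_stage_ == 0) owns thread_num_ workers whose device indices start at trainer_id_ and advance by cpu_trainer_num, so CPU trainers interleave across the global thread space, while a heter trainer registers only the placeholder worker keyed by -1. A small sketch of that index assignment, using a hypothetical helper name:

```cpp
#include <iostream>
#include <vector>

// Hypothetical helper reproducing the worker-index assignment implied by
// HeterPipelineTrainer::Initialize() above.
std::vector<int> WorkerIndices(int pipeline_stage, int trainer_id,
                               int cpu_trainer_num, int thread_num) {
  std::vector<int> indices;
  if (pipeline_stage == 0) {           // CPU trainer
    int real_thread_id = trainer_id;
    for (int i = 0; i < thread_num; ++i) {
      indices.push_back(real_thread_id);
      real_thread_id += cpu_trainer_num;
    }
  } else {                             // heter trainer
    indices.push_back(-1);             // placeholder, not for real training
  }
  return indices;
}

int main() {
  // e.g. 2 CPU trainers with 3 threads each: trainer 1 owns workers 1, 3, 5.
  for (int id : WorkerIndices(0, /*trainer_id=*/1, /*cpu_trainer_num=*/2,
                              /*thread_num=*/3)) {
    std::cout << id << " ";
  }
  std::cout << "\n";
  return 0;
}
```
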
@@ -177,7 +194,7 @@ void HeterPipelineTrainer::Run() {
}
auto heter_server = paddle::distributed::HeterServer::GetInstance();
heter_server->WaitServerReady();
// heter_server->SetMiniBatchScopes(mini_scopes_);
heter_server->SetMiniBatchScopes(mini_scopes_);
heter_server->SetMicroBatchScopes(micro_scopes_);
heter_server->SetTaskQueue(task_queue_);
// main training logic
@@ -193,6 +210,7 @@ void HeterPipelineTrainer::Run() {
}
}
} else { // for heter worker
// start thread_worker with thread_id = -1
for (auto& worker_pair : workers_) {
auto device_worker = worker_pair.second;
if (!debug_) {
@@ -203,6 +221,60 @@ void HeterPipelineTrainer::Run() {
device_worker.get()));
}
}
bool epoch_finish = false;
auto heter_server = paddle::distributed::HeterServer::GetInstance();
while (!epoch_finish) {
if (heter_server->IsStop()) {
epoch_finish = true;
continue;
}
// create new thread_worker
// size_t thread_num = (*micro_scopes_).size();
// size_t thread_num = (*task_queue_).size();
size_t thread_num = heter_server->GetThreadNum();
while (thread_num > threads_.size()) {
for (auto& worker_pair : (*micro_scopes_)) {
auto worker_index = worker_pair.first;
if (workers_.find(worker_index) != workers_.end()) continue;
workers_[worker_index] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc_.device_worker_name());
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
workers_[worker_index]);
this_worker->SetDebug(debug_);
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
this_worker->InitRandomDumpConfig(trainer_desc_);
this_worker->SetDeviceIndex(worker_index);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc_);
this_worker->SetRootScope(root_scope_);
// generate mini_batch scope for every worker
// auto* minibatch_scope = &root_scope_->NewScope();
auto* minibatch_scope = (*mini_scopes_)[worker_index];
// (*mini_scopes_)[worker_index] = minibatch_scope;
this_worker->SetMinibatchScope(minibatch_scope);
// after set micro num & mini batch scope
this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]);
this_worker->CreateMicrobatchScopes();
// this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]);
this_worker->SetThreadQueue((*task_queue_)[worker_index]);
if (!debug_) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, this_worker.get()));
} else {
threads_.push_back(std::thread(
&DeviceWorker::TrainFilesWithProfiler, this_worker.get()));
}
}
}
}
}
for (auto& th : threads_) {
th.join();
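
On the heter side, Run() now grows its worker pool on demand: it polls heter_server->GetThreadNum() (the number of minibatch task queues the request handler has created so far) and spawns a new HeterSectionWorker thread for each newly seen minibatch index until the server reports stop. A minimal sketch of that polling pattern, with a stand-in server instead of the real HeterServer:

```cpp
#include <atomic>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Stand-in exposing the two calls the trainer loop above relies on.
struct FakeServer {
  std::atomic<bool> stop{false};
  std::atomic<int> registered{0};  // minibatch task queues created so far
  bool IsStop() const { return stop.load(); }
  int GetThreadNum() const { return registered.load(); }
};

// Grow-on-demand loop mirroring HeterPipelineTrainer::Run() for heter
// workers: one training thread per newly registered minibatch index.
void RunHeterWorkers(FakeServer& server, std::function<void(int)> train_fn) {
  std::vector<std::thread> threads;
  while (!server.IsStop()) {
    auto want = static_cast<std::size_t>(server.GetThreadNum());
    while (threads.size() < want) {
      int worker_index = static_cast<int>(threads.size());
      threads.emplace_back(train_fn, worker_index);
    }
  }
  for (auto& th : threads) th.join();
}
```

In the real loop each new worker index is also handed its minibatch scope, microbatch scopes, and task queue from the shared maps before its thread starts.
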
@@ -228,7 +300,11 @@ void HeterPipelineTrainer::Finalize() {
}
Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) {
return workers_[thread_id]->GetThreadScope();
if (workers_.find(thread_id) != workers_.end()) {
return workers_[thread_id]->GetThreadScope();
} else {
return nullptr;
}
}
} // end namespace framework
......
@@ -218,15 +218,21 @@ void HeterSectionWorker::CreateMicrobatchScopes() {
minibatch_scope_,
platform::errors::InvalidArgument(
"minibatch_scope_ can not be nullptr when create MicroBatch Scopes"));
microbatch_scopes_.reset(new std::vector<paddle::framework::Scope*>{});
(*microbatch_scopes_).resize(num_microbatches_);
VLOG(3) << "Create microbatch scopes...";
std::shared_ptr<framework::ProgramDesc> program;
program.reset(new ProgramDesc(
trainer_desc_.heter_section_param().section_config().program_desc()));
for (int j = 0; j < num_microbatches_; ++j) {
(*microbatch_scopes_)[j] = &minibatch_scope_->NewScope();
CopyParameters(j, *program, place_);
if (microbatch_scopes_.get() == nullptr) {
microbatch_scopes_.reset(new std::vector<paddle::framework::Scope*>{});
(*microbatch_scopes_).resize(num_microbatches_);
VLOG(3) << "Create microbatch scopes...";
for (int j = 0; j < num_microbatches_; ++j) {
(*microbatch_scopes_)[j] = &minibatch_scope_->NewScope();
}
}
if (thread_id_ >= 0) {
std::shared_ptr<framework::ProgramDesc> program;
program.reset(new ProgramDesc(
trainer_desc_.heter_section_param().section_config().program_desc()));
for (int j = 0; j < num_microbatches_; ++j) {
CopyParameters(j, *program, place_);
}
}
}
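
CreateMicrobatchScopes() is now safe to call on a worker whose scopes were already shared in by the trainer via SetMicrobatchScopes: it only allocates the vector when it is still null, and only copies program parameters for real workers (thread_id_ >= 0), so the stage's placeholder worker skips the copy. A simplified sketch of that guard (stand-in types, not the real framework API):

```cpp
#include <memory>
#include <vector>

struct Scope {};  // stand-in for paddle::framework::Scope

// Mirrors the guarded initialization above: reuse trainer-shared scopes,
// allocate only when absent, and copy parameters only for real workers.
void CreateMicrobatchScopes(
    std::shared_ptr<std::vector<Scope*>>& microbatch_scopes,
    int num_microbatches, int thread_id) {
  if (microbatch_scopes == nullptr) {
    microbatch_scopes = std::make_shared<std::vector<Scope*>>();
    microbatch_scopes->resize(num_microbatches);
    for (int j = 0; j < num_microbatches; ++j) {
      (*microbatch_scopes)[j] = new Scope();  // one scope per microbatch
    }
  }
  if (thread_id >= 0) {
    // real worker: CopyParameters(j, program, place) runs here per microbatch
  }
}
```
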
@@ -258,6 +264,8 @@ void HeterSectionWorker::CopyParameters(int microbatch_id,
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
if ((*microbatch_scopes_)[microbatch_id]->FindVar(var->Name()) != nullptr)
continue;
auto* ptr = (*microbatch_scopes_)[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr;
@@ -359,22 +367,25 @@ void HeterSectionWorker::BatchPostProcess() {
}
void HeterSectionWorker::TrainFiles() {
total_ins_num_ = 0;
batch_num_ = 0;
platform::SetNumThreads(1);
timeline_.Start();
VLOG(3) << "begin section_worker TrainFiles";
epoch_finish_ = false;
if (pipeline_stage_ == 0) {
device_reader_->Start();
}
while (!epoch_finish_) {
Run();
dev_ctx_->Wait();
if (thread_id_ >= 0) {
total_ins_num_ = 0;
batch_num_ = 0;
platform::SetNumThreads(1);
timeline_.Start();
VLOG(3) << "begin section_worker TrainFiles";
epoch_finish_ = false;
if (pipeline_stage_ == 0) {
device_reader_->Start();
}
while (!epoch_finish_) {
Run();
dev_ctx_->Wait();
}
timeline_.Pause();
VLOG(3) << "worker " << thread_id_ << " train cost "
<< timeline_.ElapsedSec()
<< " seconds, ins_num: " << total_ins_num_;
}
timeline_.Pause();
VLOG(3) << "worker " << thread_id_ << " train cost " << timeline_.ElapsedSec()
<< " seconds, ins_num: " << total_ins_num_;
}
void HeterSectionWorker::PrintFetchVars() {
@@ -406,22 +417,24 @@ void HeterSectionWorker::PrintFetchVars() {
}
void HeterSectionWorker::TrainFilesWithProfiler() {
VLOG(3) << "begin section_worker TrainFilesWithProfiler";
batch_num_ = 0;
epoch_finish_ = false;
total_ins_num_ = 0;
op_name_.clear();
op_total_time_.clear();
if (pipeline_stage_ == 0) {
device_reader_->Start();
}
while (!epoch_finish_) {
Run();
dev_ctx_->Wait();
if (epoch_finish_) {
// dump param for debug
if (need_dump_field_ || need_dump_param_) {
writer_.Flush();
if (thread_id_ >= 0) {
VLOG(3) << "begin section_worker TrainFilesWithProfiler";
batch_num_ = 0;
epoch_finish_ = false;
total_ins_num_ = 0;
op_name_.clear();
op_total_time_.clear();
if (pipeline_stage_ == 0) {
device_reader_->Start();
}
while (!epoch_finish_) {
Run();
dev_ctx_->Wait();
if (epoch_finish_) {
// dump param for debug
if (need_dump_field_ || need_dump_param_) {
writer_.Flush();
}
}
}
}
......
@@ -161,13 +161,20 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
b_rpc_service->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using MiniScope = std::unordered_map<int, framework::Scope*>;
std::shared_ptr<MiniScope> mini_scopes(new MiniScope{});
std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
new std::vector<framework::Scope*>{});
(*micro_scope).push_back(new framework::Scope());
(*micro_scope).push_back(new framework::Scope());
auto* mini_scope = new framework::Scope();
(*mini_scopes)[0] = mini_scope;
auto* micro_scope_0 = &(mini_scope->NewScope());
auto* micro_scope_1 = &(mini_scope->NewScope());
(*micro_scope).push_back(micro_scope_0);
(*micro_scope).push_back(micro_scope_1);
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
b_rpc_service->SetMiniBatchScopes(mini_scopes);
using TaskQueue =
std::unordered_map<int,
......
@@ -194,13 +194,20 @@ TEST(SENDANDRECV, CPU) {
b_rpc_service->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using MiniScope = std::unordered_map<int, framework::Scope*>;
std::shared_ptr<MiniScope> mini_scopes(new MiniScope{});
std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
new std::vector<framework::Scope*>{});
(*micro_scope).push_back(new framework::Scope());
(*micro_scope).push_back(new framework::Scope());
auto* mini_scope = new framework::Scope();
(*mini_scopes)[0] = mini_scope;
auto* micro_scope_0 = &(mini_scope->NewScope());
auto* micro_scope_1 = &(mini_scope->NewScope());
(*micro_scope).push_back(micro_scope_0);
(*micro_scope).push_back(micro_scope_1);
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
b_rpc_service->SetMiniBatchScopes(mini_scopes);
using TaskQueue =
std::unordered_map<int,
......
@@ -167,12 +167,18 @@ TEST(SENDANDRECV, CPU) {
b_rpc_service->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using MiniScope = std::unordered_map<int, framework::Scope*>;
std::shared_ptr<MiniScope> mini_scopes(new MiniScope{});
std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
auto* mini_scope = new framework::Scope();
(*mini_scopes)[0] = mini_scope;
std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
new std::vector<framework::Scope*>{});
(*micro_scope).push_back(new framework::Scope());
auto* micro_scope_0 = &(mini_scope->NewScope());
(*micro_scope).push_back(micro_scope_0);
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
b_rpc_service->SetMiniBatchScopes(mini_scopes);
using TaskQueue =
std::unordered_map<int,
......
@@ -187,12 +187,18 @@ TEST(SENDANDRECV, GPU) {
b_rpc_service2->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using MiniScope = std::unordered_map<int, framework::Scope*>;
std::shared_ptr<MiniScope> mini_scopes(new MiniScope{});
std::shared_ptr<MicroScope> micro_scopes(new MicroScope{});
auto* mini_scope = new framework::Scope();
(*mini_scopes)[0] = mini_scope;
std::shared_ptr<std::vector<framework::Scope*>> micro_scope(
new std::vector<framework::Scope*>{});
(*micro_scope).push_back(new framework::Scope());
auto* micro_scope_0 = &(mini_scope->NewScope());
(*micro_scope).push_back(micro_scope_0);
(*micro_scopes)[0] = micro_scope;
b_rpc_service2->SetMicroBatchScopes(micro_scopes);
b_rpc_service2->SetMiniBatchScopes(mini_scopes);
using TaskQueue =
std::unordered_map<int,
......
@@ -898,9 +898,10 @@ class TheOnePSRuntime(RuntimeBase):
print_period=100,
fetch_handler=None):
executor = self._get_executor()
# dataset is not needed for heter worker
executor.train_from_dataset(
program=fluid.default_main_program(),
dataset=dataset,
dataset=None,
debug=debug,
fetch_list=fetch_list,
fetch_info=fetch_info,
......
@@ -1672,6 +1672,33 @@ class Executor(object):
dataset.set_thread(1)
dataset.set_filelist(['None'])
dataset.set_use_var(data_vars)
elif program._heter_pipeline_opt is not None:
stage_id = program._heter_pipeline_opt["pipeline_stage"]
if stage_id != 0:
import paddle
if dataset is not None:
raise RuntimeError(
"dataset should be None for heter pipeline mode")
# The following fake dataset is created to call
# the _prepare_trainer api, and it is meaningless.
data_vars = []
for var in program.global_block().vars.values():
if var.is_data:
data_vars.append(var)
if core.is_compiled_with_npu():
dataset = paddle.fluid.DatasetFactory().create_dataset(
'InMemoryDataset')
else:
dataset = paddle.fluid.DatasetFactory().create_dataset(
'FileInstantDataset')
dataset.set_batch_size(1)
dataset.set_thread(1)
dataset.set_filelist(['None'])
dataset.set_use_var(data_vars)
else:
if dataset is None:
raise RuntimeError(
"dataset is need and should be initialized")
else:
if dataset is None:
raise RuntimeError("dataset is need and should be initialized")
......
@@ -156,12 +156,7 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
thread_num = int(os.getenv("CPU_NUM", 2))
batch_size = 128
block_size = len(train_file_list) // fleet.worker_num()
worker_id = fleet.worker_index()
filelist = train_file_list[worker_id * block_size:(worker_id + 1) *
block_size]
#filelist = fleet.util.get_file_shard(train_file_list)
filelist = fleet.util.get_file_shard(train_file_list)
print("filelist: {}".format(filelist))
# config dataset
@@ -195,31 +190,12 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
"section_program"]
print(real_program)
train_file_list = ctr_dataset_reader.prepare_fake_data()
#exe = fluid.Executor(fluid.CPUPlace())
#exe.run(fluid.default_startup_program())
#fleet.init_worker()
thread_num = int(os.getenv("CPU_NUM", 2))
batch_size = 128
#filelist = fleet.util.get_file_shard(train_file_list)
block_size = len(train_file_list) // fleet.worker_num()
filelist = train_file_list[0:block_size]
print("filelist: {}".format(filelist))
# config dataset
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(batch_size)
dataset.set_use_var(self.feeds)
pipe_command = 'python3 ctr_dataset_reader.py'
dataset.set_pipe_command(pipe_command)
dataset.set_filelist(filelist)
dataset.set_thread(thread_num)
fleet.run_heter_worker(dataset)
pass_start = time.time()
fleet.run_heter_worker(dataset=None)
pass_time = time.time() - pass_start
print("do_dataset_heter_training done. using time {}".format(pass_time))
#for epoch_id in range(1):
......