Unverified commit 45d7a3ea, authored by danleifeng, committed by GitHub

[GPUPS]fix gpups pscore (#42967)

Parent b6859054
@@ -219,6 +219,10 @@ void HogwildWorker::TrainFiles() {
   device_reader_->Start();
   int cur_batch;
   int batch_cnt = 0;
+#if defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_CUDA)
+  platform::SetDeviceId(thread_id_);
+#endif
   while ((cur_batch = device_reader_->Next()) > 0) {
     for (auto &op : ops_) {
       bool need_skip = false;
@@ -244,9 +248,12 @@ void HogwildWorker::TrainFiles() {
     ++batch_cnt;
     PrintFetchVars();
     thread_scope_->DropKids();
+#ifdef PADDLE_WITH_HETERPS
+    dev_ctx_->Wait();
+#endif
   }
   timeline.Pause();
-  VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
+  VLOG(1) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
           << " seconds, ins_num: " << total_ins_num;
   if (need_dump_field_ || need_dump_param_) {
......
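The two hunks above bind each HogwildWorker thread to its own device before the batch loop (`platform::SetDeviceId(thread_id_)`) and synchronize the device context at the end of every batch iteration (`dev_ctx_->Wait()`). A minimal standalone sketch of that pattern, assuming one worker thread per device id; `SetDeviceId` and `DeviceWait` are stand-ins for the Paddle calls, not real APIs:

```cpp
// Minimal standalone sketch (not Paddle code) of the worker pattern above:
// each training thread binds to its own device before the batch loop and
// waits for the device at the end of every batch. SetDeviceId and DeviceWait
// are stand-ins for platform::SetDeviceId and dev_ctx_->Wait().
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

namespace {
std::mutex g_io_mu;

void SetDeviceId(int dev_id) {  // stand-in: bind the calling thread to dev_id
  std::lock_guard<std::mutex> lk(g_io_mu);
  std::cout << "thread bound to device " << dev_id << "\n";
}

void DeviceWait(int dev_id) {  // stand-in: drain outstanding work on dev_id
  std::lock_guard<std::mutex> lk(g_io_mu);
  std::cout << "device " << dev_id << " synchronized\n";
}

void TrainFiles(int thread_id, int num_batches) {
  SetDeviceId(thread_id);  // one worker thread <-> one device id
  for (int batch = 0; batch < num_batches; ++batch) {
    // ... run the ops for this batch ...
    DeviceWait(thread_id);  // end-of-batch synchronization, as in the hunk above
  }
}
}  // namespace

int main() {
  const int num_workers = 2;
  std::vector<std::thread> workers;
  for (int i = 0; i < num_workers; ++i) {
    workers.emplace_back(TrainFiles, i, /*num_batches=*/3);
  }
  for (auto& t : workers) t.join();
  return 0;
}
```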
@@ -148,6 +148,17 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
     }
   }
 #endif
+  for (auto& var : main_program.Block(0).AllVars()) {
+    if (var->Persistable()) {
+      auto it = std::find(need_merge_var_names_.begin(),
+                          need_merge_var_names_.end(), var->Name());
+      if (it == need_merge_var_names_.end() &&
+          var->GetType() != proto::VarType::SELECTED_ROWS) {
+        VLOG(2) << "train param: " << var->Name();
+        trainable_param_.push_back(var->Name());
+      }
+    }
+  }
 }
 
 void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
@@ -192,18 +203,30 @@ void MultiTrainer::Run() {
 #ifdef PADDLE_WITH_HETERPS
 void MultiTrainer::MergeDenseParam() {
-#ifdef PADDLE_WTIH_PSCORE
+#ifdef PADDLE_WITH_PSCORE
   auto communicator = paddle::distributed::Communicator::GetInstance();
-  auto& recv_ctx = communicator->GetRecvCtxMap();
-  Scope* thread_scope = workers_[0]->GetThreadScope();
-  for (auto& iter : recv_ctx) {
-    auto& varnames = iter.second;
-    for (auto& name : varnames) {
+  auto thread_scope = workers_[0]->GetThreadScope();
+  if (communicator == nullptr) {
+    for (auto& name : trainable_param_) {
+      VLOG(2) << "merge var " << name << " to root scope";
       Variable* root_var = root_scope_->FindVar(name);
       LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
       Variable* var = thread_scope->FindVar(name);
       LoDTensor* tensor = var->GetMutable<LoDTensor>();
-      TensorCopy((*tensor), root_tensor->place(), root_tensor);
+      TensorCopySync((*tensor), root_tensor->place(), root_tensor);
+    }
+  } else {
+    auto& recv_ctx = communicator->GetRecvCtxMap();
+    for (auto& iter : recv_ctx) {
+      auto& varnames = iter.second;
+      for (auto& name : varnames) {
+        VLOG(2) << "merge var " << name << " to root scope";
+        Variable* root_var = root_scope_->FindVar(name);
+        LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
+        Variable* var = thread_scope->FindVar(name);
+        LoDTensor* tensor = var->GetMutable<LoDTensor>();
+        TensorCopySync((*tensor), root_tensor->place(), root_tensor);
+      }
     }
   }
 #endif
@@ -236,11 +259,7 @@ void MultiTrainer::Finalize() {
     }
     LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
-#ifdef PADDLE_WITH_HETERPS
-    for (size_t j = 0; j < places_.size(); j++) {
-#else
     for (int j = 1; j < thread_num_; j++) {
-#endif
       Scope* cur_thread_scope = workers_[j]->GetThreadScope();
       Variable* thread_var =
           cur_thread_scope->FindVar(need_merge_var_names_[i]);
......
@@ -129,6 +129,7 @@ class MultiTrainer : public TrainerBase {
   std::vector<DataFeed*> readers_;
   std::vector<std::shared_ptr<DeviceWorker>> workers_;
   std::vector<std::string> need_merge_var_names_;
+  std::vector<std::string> trainable_param_;
 #ifdef PADDLE_WITH_HETERPS
   std::vector<platform::Place> places_;
 #endif
......
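Together, the multi_trainer changes collect the dense trainable parameters into `trainable_param_` during `InitTrainerEnv` and use them in `MergeDenseParam` when no communicator singleton exists (the GPUPS case), copying each variable synchronously from worker 0's scope back into the root scope. A standalone sketch of that control flow, with scopes modeled as plain maps and map assignment standing in for `TensorCopySync` (not Paddle code):

```cpp
// Standalone sketch (not Paddle code) of the MergeDenseParam control flow
// above: with no communicator (the GPUPS case), merge every name collected
// in trainable_param_ from the worker-0 scope into the root scope; otherwise
// walk the communicator's recv-context var lists.
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Scope = std::map<std::string, std::vector<float>>;

struct Communicator {  // stand-in for the distributed communicator singleton
  std::map<int, std::vector<std::string>> recv_ctx;  // table id -> var names
};

void MergeDenseParam(Scope* root_scope, Scope* thread_scope,
                     const std::vector<std::string>& trainable_param,
                     const Communicator* communicator) {
  if (communicator == nullptr) {
    // GPUPS path: no communicator, merge the params collected in InitTrainerEnv.
    for (const auto& name : trainable_param) {
      (*root_scope)[name] = (*thread_scope)[name];  // "TensorCopySync"
    }
  } else {
    for (const auto& iter : communicator->recv_ctx) {
      for (const auto& name : iter.second) {
        (*root_scope)[name] = (*thread_scope)[name];
      }
    }
  }
}

int main() {
  Scope root, worker0;
  worker0["fc_0.w_0"] = {0.1f, 0.2f};  // illustrative parameter names
  worker0["fc_0.b_0"] = {0.0f};
  MergeDenseParam(&root, &worker0, {"fc_0.w_0", "fc_0.b_0"},
                  /*communicator=*/nullptr);
  std::cout << "merged " << root.size() << " dense params into root scope\n";
  return 0;
}
```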
@@ -614,15 +614,24 @@ class PsGpuPass(PassBase):
         return True
 
     def _add_push_box_sparse_op(self, program):
+        insert_index = -1
+        for idx, op in list(enumerate(program.global_block().ops)):
+            if op.type == "lookup_table_grad":
+                insert_index = idx
         for op in program.global_block().ops:
-            if op.type != "pull_box_sparse":
+            if op.type != "pull_box_sparse" and op.type != "pull_gpups_sparse":
                 continue
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(set()), [])
             for op_desc in grad_op_desc:
-                new_op_desc = program.global_block().desc.append_op()
+                new_op_desc = program.global_block().desc._insert_op(
+                    insert_index + 1)
                 new_op_desc.copy_from(op_desc)
                 new_op_desc._set_attr(op_role_attr_name, backward)
+                new_op = paddle.fluid.framework.Operator(program.global_block(),
+                                                         new_op_desc)
+                program.global_block().ops.insert(insert_index + 1, new_op)
+        program.global_block()._sync_with_cpp()
 
     def _remove_optimizer_var(self, program):
         embedding_w = {}
......
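The `_add_push_box_sparse_op` change above records the index of the last `lookup_table_grad` op, generates the backward op for `pull_box_sparse` / `pull_gpups_sparse`, and inserts it right after that index instead of appending it at the end of the block, then re-syncs the Python op list with the underlying desc. A toy sketch of just the insertion-position logic, on a plain list of op type names (the op names are illustrative, not the real generated desc):

```cpp
// Toy sketch (not the actual pass): find the last "lookup_table_grad" op in a
// block and insert the generated backward op right after it, instead of
// appending it at the end, so it executes in the intended position.
#include <iostream>
#include <string>
#include <vector>

int main() {
  // A toy "block": op types in execution order (illustrative only).
  std::vector<std::string> ops = {"pull_gpups_sparse", "mul", "mul_grad",
                                  "lookup_table_grad", "adam"};

  // Same scan as the pass: remember the index of the last lookup_table_grad.
  int insert_index = -1;
  for (int idx = 0; idx < static_cast<int>(ops.size()); ++idx) {
    if (ops[idx] == "lookup_table_grad") insert_index = idx;
  }

  // Insert the generated backward op directly after that index.
  const std::string new_grad_op = "push_box_sparse";  // illustrative grad op
  ops.insert(ops.begin() + insert_index + 1, new_grad_op);

  for (const auto& op : ops) std::cout << op << "\n";  // grad op lands before adam
  return 0;
}
```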
@@ -1013,12 +1013,13 @@ class TheOnePSRuntime(RuntimeBase):
         if self.context['ps_mode'] == DistributedMode.GEO:
             self._communicator.init_params(init_params)
         else:
-            if role_id == 0:
-                self._init_all_params(scopes, send_ctx, dense_map)
+            if not self.context['use_ps_gpu']:
+                if role_id == 0:
+                    self._init_all_params(scopes, send_ctx, dense_map)
 
         fleet.util.barrier()
         if not self.context['use_ps_gpu']:
             self._pull_all_dense(scopes, send_ctx, dense_map)
         fleet.util.barrier()
 
         if self.context['ps_mode'] == DistributedMode.GEO:
......