Unverified · Commit 3789a699 · authored by Thunderbrook, committed by GitHub

solve bug in heter mode (#31531)

* heter bug

* format

* format
Parent 6148b87f
@@ -168,6 +168,7 @@ class DeviceWorker {
   virtual void CacheProgram(const ProgramDesc& main_program) {}
   virtual void ProduceTasks() {}
   virtual void GetXpuOpIndex() {}
+  virtual void Schedule(int taskid) {}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   virtual void SetStream(const gpuStream_t stream) {}
   virtual void SetEvent(const gpuEvent_t event) {}
...
@@ -62,9 +62,8 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
 
 void DistMultiTrainer::RegisterHeterCallback() {
   auto fleet_ptr = FleetWrapper::GetInstance();
-  fleet_ptr->RegisterHeterCallback([this](int worker, int taskid) {
-    // workers_[worker]->Schedule(taskid);
-  });
+  fleet_ptr->RegisterHeterCallback(
+      [this](int worker, int taskid) { workers_[worker]->Schedule(taskid); });
 }
 
 void DistMultiTrainer::InitDumpEnv() {
...
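Reading the two C++ hunks together: the base class now declares a `Schedule(int taskid)` hook, so the callback that `DistMultiTrainer` registers with `FleetWrapper` can forward each completed heter task to the worker that owns it; before this change the lambda body was commented out, so completions were silently dropped. Below is a toy Python model of that dispatch pattern — not PaddlePaddle source; all names (`ToyWorker`, `HeterDispatcher`) are invented for illustration.

    # Toy model of the heter-task dispatch chain this commit completes.
    class ToyWorker(object):
        """Stands in for a DeviceWorker that owns heter tasks."""

        def __init__(self, wid):
            self.wid = wid

        def schedule(self, taskid):
            # Models DeviceWorker::Schedule(taskid): hand the finished
            # remote task back to the worker that issued it.
            print("worker %d resumes task %d" % (self.wid, taskid))


    class HeterDispatcher(object):
        """Stands in for FleetWrapper's heter callback registry."""

        def __init__(self):
            self._callback = None

        def register_heter_callback(self, cb):
            self._callback = cb

        def on_task_done(self, worker, taskid):
            # The service side reports a completion; in the pre-fix state
            # the registered callback had an empty body, so this went nowhere.
            self._callback(worker, taskid)


    workers = [ToyWorker(i) for i in range(4)]
    dispatcher = HeterDispatcher()
    dispatcher.register_heter_callback(
        lambda worker, taskid: workers[worker].schedule(taskid))
    dispatcher.on_task_done(2, 7)  # -> "worker 2 resumes task 7"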
@@ -193,7 +193,6 @@ void FleetWrapper::HeterPullSparseVars(
   for (auto& t : fea_values) {
     pull_result_ptr.push_back(t.data());
   }
-  /*
   auto status = pslib_ptr_->_worker_ptr->heter_pull_sparse(
       workerid, pull_result_ptr.data(), table_id, fea_keys.data(),
       fea_keys.size(), task->taskid_);
@@ -207,7 +206,6 @@ void FleetWrapper::HeterPullSparseVars(
       exit(-1);
     }
   }
-  */
 }
 
 void FleetWrapper::HeterPushSparseVars(
...
@@ -1039,11 +1039,17 @@ class HeterRoleMaker(GeneralRoleMaker):
             self._node_type = 1
             self._cur_endpoint = worker_endpoints[current_id]
             gloo = fluid.core.Gloo()
-            gloo.init(current_id,
-                      len(worker_endpoints),
-                      self._hdfs_path.rstrip("/") + "/trainer",
-                      self._hdfs_name, self._hdfs_ugi, self._iface,
-                      self._prefix)
+
+            gloo.set_rank(current_id)
+            gloo.set_size(len(worker_endpoints))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            gloo.set_hdfs_store(
+                self._hdfs_path.rstrip("/") + "/trainer", self._hdfs_name,
+                self._hdfs_ugi)
+            gloo.init()
             self._node_type_comm = gloo
         elif training_role == "XPU":
             role = Role.XPU
@@ -1051,10 +1057,17 @@ class HeterRoleMaker(GeneralRoleMaker):
             self._node_type = 2
             self._cur_endpoint = xpu_endpoints[current_id]
             gloo = fluid.core.Gloo()
-            gloo.init(current_id,
-                      len(xpu_endpoints),
-                      self._hdfs_path.rstrip("/") + "/xpu", self._hdfs_name,
-                      self._hdfs_ugi, self._iface, self._prefix)
+
+            gloo.set_rank(current_id)
+            gloo.set_size(len(xpu_endpoints))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            gloo.set_hdfs_store(
+                self._hdfs_path.rstrip("/") + "/xpu", self._hdfs_name,
+                self._hdfs_ugi)
+            gloo.init()
             self._node_type_comm = gloo
         elif training_role == "PSERVER":
             role = Role.SERVER
@@ -1070,30 +1083,47 @@ class HeterRoleMaker(GeneralRoleMaker):
             self._node_type = 0
             self._cur_endpoint = cur_endpoint
             gloo = fluid.core.Gloo()
-            gloo.init(current_id,
-                      len(eplist),
-                      self._hdfs_path.rstrip("/") + "/pserver",
-                      self._hdfs_name, self._hdfs_ugi, self._iface,
-                      self._prefix)
+            gloo.set_rank(current_id)
+            gloo.set_size(len(eplist))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            gloo.set_hdfs_store(
+                self._hdfs_path.rstrip("/") + "/pserver", self._hdfs_name,
+                self._hdfs_ugi)
+            gloo.init()
             self._node_type_comm = gloo
 
         if training_role == "TRAINER" or training_role == "XPU":
             gloo = fluid.core.Gloo()
             heter_list = worker_endpoints + xpu_endpoints
-            gloo.init(
-                heter_list.index(self._cur_endpoint),
-                len(heter_list),
-                self._hdfs_path.rstrip("/") + "/heter", self._hdfs_name,
-                self._hdfs_ugi, self._iface, self._prefix)
+
+            gloo.set_rank(heter_list.index(self._cur_endpoint))
+            gloo.set_size(len(heter_list))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            gloo.set_hdfs_store(
+                self._hdfs_path.rstrip("/") + "/heter", self._hdfs_name,
+                self._hdfs_ugi)
+            gloo.init()
             self._heter_comm = gloo
 
         gloo = fluid.core.Gloo()
         all_list = worker_endpoints + eplist + xpu_endpoints
-        gloo.init(
-            all_list.index(self._cur_endpoint),
-            len(all_list),
-            self._hdfs_path.rstrip("/") + "/all", self._hdfs_name,
-            self._hdfs_ugi, self._iface, self._prefix)
+
+        gloo.set_rank(all_list.index(self._cur_endpoint))
+        gloo.set_size(len(all_list))
+        gloo.set_prefix(self._prefix)
+        gloo.set_iface(self._iface)
+        gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                 self._run_timeout_seconds)
+        gloo.set_hdfs_store(
+            self._hdfs_path.rstrip("/") + "/all", self._hdfs_name,
+            self._hdfs_ugi)
+        gloo.init()
         self._all_comm = gloo
 
         self._trainers_num = trainers_num
...
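The Python hunks all make the same mechanical change: the old positional `gloo.init(rank, size, path, name, ugi, iface, prefix)` call is split into per-field setters followed by an argument-less `init()`. A minimal sketch of the new sequence, assuming a PaddlePaddle build whose `fluid.core.Gloo` binding exposes these setters; the endpoint list, HDFS settings, interface name, and timeout values below are placeholders, not values from the commit.

    # Sketch only: mirrors the setter-then-init() pattern from the diff above.
    import paddle.fluid as fluid


    def build_gloo_comm(rank, endpoints, sub_dir,
                        hdfs_path="hdfs://nameservice/tmp/gloo",  # placeholder
                        hdfs_name="", hdfs_ugi="",                # placeholders
                        iface="eth0", prefix="",
                        init_timeout=3600, run_timeout=9999999):
        gloo = fluid.core.Gloo()
        gloo.set_rank(rank)                # this process's index in the group
        gloo.set_size(len(endpoints))      # total group size
        gloo.set_prefix(prefix)            # key namespace for rendezvous
        gloo.set_iface(iface)              # NIC used for gloo transport
        gloo.set_timeout_seconds(init_timeout, run_timeout)
        # Rendezvous through a shared HDFS directory, one subdirectory per
        # role, as the role maker does with "/trainer", "/xpu", "/pserver".
        gloo.set_hdfs_store(hdfs_path.rstrip("/") + "/" + sub_dir,
                            hdfs_name, hdfs_ugi)
        gloo.init()  # all configuration now happens before init()
        return gloo


    # e.g. trainer 0 of a two-trainer job (endpoints are placeholders):
    # comm = build_gloo_comm(0, ["127.0.0.1:6170", "127.0.0.1:6171"], "trainer")

One consequence of the split worth noting: each communicator (trainer, xpu, pserver, heter, all) differs only in rank, size, and HDFS subdirectory, so the shared setter sequence no longer has to thread seven positional arguments through every call site.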