未验证 提交 3789a699 编写于 作者: T Thunderbrook 提交者: GitHub

Fix bug in heter mode (#31531)

* heter bug

* format

* format
上级 6148b87f
......@@ -168,6 +168,7 @@ class DeviceWorker {
// Receives the main ProgramDesc; the default implementation does nothing.
virtual void CacheProgram(const ProgramDesc& main_program) {}
// Default no-op hooks; derived workers can override the ones they need.
virtual void ProduceTasks() {}
virtual void GetXpuOpIndex() {}
// Receives a task id; the default implementation ignores it.
virtual void Schedule(int taskid) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Declared only in CUDA or HIP builds; both default to no-ops.
virtual void SetStream(const gpuStream_t stream) {}
virtual void SetEvent(const gpuEvent_t event) {}
......
......@@ -62,9 +62,8 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
// Registers the heter callback on the FleetWrapper singleton.
// NOTE(review): this text comes from a diff rendered without +/- markers.
// The first registration below looks like the pre-change code and the
// second like its replacement - confirm against the real file.
void DistMultiTrainer::RegisterHeterCallback() {
auto fleet_ptr = FleetWrapper::GetInstance();
// Callback with an empty body: the Schedule call is commented out.
fleet_ptr->RegisterHeterCallback([this](int worker, int taskid) {
// workers_[worker]->Schedule(taskid);
});
// Callback that forwards taskid to Schedule() on workers_[worker].
// `worker` indexes workers_ directly; no bounds check is done here.
fleet_ptr->RegisterHeterCallback(
[this](int worker, int taskid) { workers_[worker]->Schedule(taskid); });
}
void DistMultiTrainer::InitDumpEnv() {
......
......@@ -193,7 +193,6 @@ void FleetWrapper::HeterPullSparseVars(
for (auto& t : fea_values) {
pull_result_ptr.push_back(t.data());
}
/*
auto status = pslib_ptr_->_worker_ptr->heter_pull_sparse(
workerid, pull_result_ptr.data(), table_id, fea_keys.data(),
fea_keys.size(), task->taskid_);
......@@ -207,7 +206,6 @@ void FleetWrapper::HeterPullSparseVars(
exit(-1);
}
}
*/
}
void FleetWrapper::HeterPushSparseVars(
......
......@@ -1039,11 +1039,17 @@ class HeterRoleMaker(GeneralRoleMaker):
self._node_type = 1
self._cur_endpoint = worker_endpoints[current_id]
gloo = fluid.core.Gloo()
gloo.init(current_id,
len(worker_endpoints),
self._hdfs_path.rstrip("/") + "/trainer",
self._hdfs_name, self._hdfs_ugi, self._iface,
self._prefix)
gloo.set_rank(current_id)
gloo.set_size(len(worker_endpoints))
gloo.set_prefix(self._prefix)
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
gloo.set_hdfs_store(
self._hdfs_path.rstrip("/") + "/trainer", self._hdfs_name,
self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
elif training_role == "XPU":
role = Role.XPU
......@@ -1051,10 +1057,17 @@ class HeterRoleMaker(GeneralRoleMaker):
self._node_type = 2
self._cur_endpoint = xpu_endpoints[current_id]
gloo = fluid.core.Gloo()
gloo.init(current_id,
len(xpu_endpoints),
self._hdfs_path.rstrip("/") + "/xpu", self._hdfs_name,
self._hdfs_ugi, self._iface, self._prefix)
gloo.set_rank(current_id)
gloo.set_size(len(xpu_endpoints))
gloo.set_prefix(self._prefix)
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
gloo.set_hdfs_store(
self._hdfs_path.rstrip("/") + "/xpu", self._hdfs_name,
self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
elif training_role == "PSERVER":
role = Role.SERVER
......@@ -1070,30 +1083,47 @@ class HeterRoleMaker(GeneralRoleMaker):
self._node_type = 0
self._cur_endpoint = cur_endpoint
gloo = fluid.core.Gloo()
gloo.init(current_id,
len(eplist),
self._hdfs_path.rstrip("/") + "/pserver",
self._hdfs_name, self._hdfs_ugi, self._iface,
self._prefix)
gloo.set_rank(current_id)
gloo.set_size(len(eplist))
gloo.set_prefix(self._prefix)
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
gloo.set_hdfs_store(
self._hdfs_path.rstrip("/") + "/pserver", self._hdfs_name,
self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
if training_role == "TRAINER" or training_role == "XPU":
gloo = fluid.core.Gloo()
heter_list = worker_endpoints + xpu_endpoints
gloo.init(
heter_list.index(self._cur_endpoint),
len(heter_list),
gloo.set_rank(heter_list.index(self._cur_endpoint))
gloo.set_size(len(heter_list))
gloo.set_prefix(self._prefix)
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
gloo.set_hdfs_store(
self._hdfs_path.rstrip("/") + "/heter", self._hdfs_name,
self._hdfs_ugi, self._iface, self._prefix)
self._hdfs_ugi)
gloo.init()
self._heter_comm = gloo
gloo = fluid.core.Gloo()
all_list = worker_endpoints + eplist + xpu_endpoints
gloo.init(
all_list.index(self._cur_endpoint),
len(all_list),
gloo.set_rank(all_list.index(self._cur_endpoint))
gloo.set_size(len(all_list))
gloo.set_prefix(self._prefix)
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
gloo.set_hdfs_store(
self._hdfs_path.rstrip("/") + "/all", self._hdfs_name,
self._hdfs_ugi, self._iface, self._prefix)
self._hdfs_ugi)
gloo.init()
self._all_comm = gloo
self._trainers_num = trainers_num
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册