Unverified commit 4b41b8e9, authored by zmx, committed by GitHub

[cherry-pick 2.2 heterps]bug fix for launch_utils.py (#37521) (#37570)

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* [heterps]bug fix for _run_from_dataset

* fix heter_server.cc

* fix launch_utils.py

* fix heter_section_worker.cc

* fix. test=develop

* fix. test=develop
Parent ca8b8586
@@ -43,11 +43,11 @@ void HeterServer::StartHeterService() {
   {
     std::lock_guard<std::mutex> lock(this->mutex_ready_);
+    stoped_ = false;
     ready_ = 1;
   }
   condition_ready_.notify_all();
   std::unique_lock<std::mutex> running_lock(mutex_);
-  stoped_ = false;
   cv_.wait(running_lock, [&] {
     VLOG(1) << "Heter Server is Stop? " << stoped_;
     return stoped_;
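Note on the hunk above: the stop flag is now cleared inside the ready-mutex critical section, before readiness is broadcast, rather than afterwards. Below is a minimal Python sketch (illustrative only, not part of the patch) of the same condition-variable handshake; all names in it (ServiceSketch, start_service, wait_server_ready, stop) are hypothetical.

import threading

class ServiceSketch:
    def __init__(self):
        self._lock = threading.Lock()
        self._ready_cv = threading.Condition(self._lock)
        self._run_cv = threading.Condition(self._lock)
        self._ready = False
        self._stopped = False

    def start_service(self):
        # Clear the stop flag before announcing readiness, so a thread woken
        # by "ready" can never observe a stale "stopped" state.
        with self._ready_cv:
            self._stopped = False
            self._ready = True
            self._ready_cv.notify_all()
        # Block until stop() flips the flag.
        with self._run_cv:
            self._run_cv.wait_for(lambda: self._stopped)

    def wait_server_ready(self):
        # Callers block here until start_service has signalled readiness.
        with self._ready_cv:
            self._ready_cv.wait_for(lambda: self._ready)

    def stop(self):
        with self._run_cv:
            self._stopped = True
            self._run_cv.notify_all()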
@@ -350,7 +350,7 @@ void HeterSectionWorker::BatchPostProcess() {
     DumpParam(*((*microbatch_scopes_)[0]), batch_num_);
   }
   // print each op time
-  if (thread_id_ == 0) {
+  if (debug_ && thread_id_ == 0) {
     size_t total_ops_size = forward_ops_.size() + backward_ops_.size();
     if (batch_num_ > 0 && batch_num_ % 100 == 0) {
       for (size_t i = 0; i < total_ops_size; ++i) {
@@ -979,8 +979,6 @@ class ParameterServerLauncher(object):
             heter_worker_endpoints_list = args.heter_workers.split(";")
             self.heter_worker_endpoints = ""
             for i in range(len(heter_worker_endpoints_list)):
-                if self.heter_worker_endpoints != "":
-                    self.heter_worker_endpoints += ","
                 heter_worker_endpoints = heter_worker_endpoints_list[
                     i].split(",")
                 self.stage_heter_trainer_num.append(
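The hunk above drops the incremental comma bookkeeping inside the loop; the commit message does not say why. The sketch below only illustrates the endpoint string format the surrounding code assumes (stages separated by ";", endpoints within a stage by ","); the sample value is made up.

# Illustrative only: assumed structure of args.heter_workers.
heter_workers = "10.0.0.1:6070,10.0.0.2:6070;10.0.0.3:6080"

# One endpoint list per pipeline stage.
stage_endpoint_lists = [s.split(",") for s in heter_workers.split(";")]
stage_heter_trainer_num = [len(stage) for stage in stage_endpoint_lists]

# A single join after parsing yields the flat comma-separated list.
heter_worker_endpoints = ",".join(ep for stage in stage_endpoint_lists
                                  for ep in stage)

print(stage_heter_trainer_num)   # [2, 1]
print(heter_worker_endpoints)    # 10.0.0.1:6070,10.0.0.2:6070,10.0.0.3:6080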
@@ -1073,15 +1071,18 @@ class ParameterServerLauncher(object):
             _, self.current_node_ip = get_host_name_ip()
         else:
             self.current_node_ip = pod_ip
-        assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
-            % (self.current_node_ip, self.node_ips)
-        self.node_rank = self.node_ips.index(self.current_node_ip)
-        logger.debug(
-            "parsed from args: node_ips:{} current_node_ip:{} node_rank:{}".
-            format(self.node_ips, self.current_node_ip, self.node_rank))
+        if not self.distribute_mode == DistributeMode.PS_HETER:
+            assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
+                % (self.current_node_ip, self.node_ips)
+        if self.current_node_ip in self.node_ips:
+            self.node_rank = self.node_ips.index(self.current_node_ip)
+            logger.debug(
+                "parsed from args: node_ips:{} current_node_ip:{} node_rank:{}".
+                format(self.node_ips, self.current_node_ip, self.node_rank))

     def start_ps(self):
+        if not self.current_node_ip in self.node_ips:
+            return
         cluster = Cluster(hdfs=None)
         server_rank = 0
         worker_rank = 0
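The membership guard added before the rank lookup matters because list.index raises ValueError for a missing element; the sketch below (sample values made up) shows the failure mode it avoids on a host whose IP is not listed in args.servers / args.workers, e.g. a dedicated heter-worker machine.

# Illustrative only; values are made up.
node_ips = ["10.0.0.1", "10.0.0.2"]
current_node_ip = "10.0.0.9"   # e.g. a separate heter-worker host

if current_node_ip in node_ips:
    node_rank = node_ips.index(current_node_ip)
else:
    # An unguarded node_ips.index(current_node_ip) would raise ValueError here.
    node_rank = None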
@@ -1673,12 +1673,8 @@ class Executor(object):
             for var in program.global_block().vars.values():
                 if var.is_data:
                     data_vars.append(var)
-            if core.is_compiled_with_npu():
-                dataset = paddle.fluid.DatasetFactory().create_dataset(
-                    'InMemoryDataset')
-            else:
-                dataset = paddle.fluid.DatasetFactory().create_dataset(
-                    'FileInstantDataset')
+            dataset = paddle.fluid.DatasetFactory().create_dataset(
+                'InMemoryDataset')
             dataset.set_batch_size(1)
             dataset.set_thread(1)
             dataset.set_filelist(['None'])
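For context, the snippet below sketches how the placeholder dataset created in this hunk is configured, assuming the fluid DatasetFactory API of the 2.2 branch; the final set_use_var call is an assumption, since it is not visible in the truncated hunk.

# Illustrative sketch only, not the patch itself.
import paddle

dataset = paddle.fluid.DatasetFactory().create_dataset('InMemoryDataset')
dataset.set_batch_size(1)        # single-instance batches
dataset.set_thread(1)            # one reader thread
dataset.set_filelist(['None'])   # placeholder file list
# dataset.set_use_var(data_vars) # assumed: bind the program's data variables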