未验证 提交 8bb1038c 编写于 作者: Z zmx 提交者: GitHub

[cherry-pick 2.2 heterps]bug fix for launch_utils.py (#37521)

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* [heterps]bug fix for _run_from_dataset

* fix heter_server.cc

* fix launch_utils.py

* fix heter_section_worker.cc

* fix. test=develop

* fix. test=develop
上级 b9c464c3
......@@ -43,11 +43,11 @@ void HeterServer::StartHeterService() {
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
stoped_ = false;
ready_ = 1;
}
condition_ready_.notify_all();
std::unique_lock<std::mutex> running_lock(mutex_);
stoped_ = false;
cv_.wait(running_lock, [&] {
VLOG(1) << "Heter Server is Stop? " << stoped_;
return stoped_;
......
......@@ -350,7 +350,7 @@ void HeterSectionWorker::BatchPostProcess() {
DumpParam(*((*microbatch_scopes_)[0]), batch_num_);
}
// print each op time
if (thread_id_ == 0) {
if (debug_ && thread_id_ == 0) {
size_t total_ops_size = forward_ops_.size() + backward_ops_.size();
if (batch_num_ > 0 && batch_num_ % 100 == 0) {
for (size_t i = 0; i < total_ops_size; ++i) {
......
......@@ -1088,8 +1088,6 @@ class ParameterServerLauncher(object):
heter_worker_endpoints_list = args.heter_workers.split(";")
self.heter_worker_endpoints = ""
for i in range(len(heter_worker_endpoints_list)):
if self.heter_worker_endpoints != "":
self.heter_worker_endpoints += ","
heter_worker_endpoints = heter_worker_endpoints_list[
i].split(",")
self.stage_heter_trainer_num.append(
......@@ -1182,15 +1180,18 @@ class ParameterServerLauncher(object):
_, self.current_node_ip = get_host_name_ip()
else:
self.current_node_ip = pod_ip
assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
% (self.current_node_ip, self.node_ips)
self.node_rank = self.node_ips.index(self.current_node_ip)
logger.debug(
"parsed from args: node_ips:{} current_node_ip:{} node_rank:{}".
format(self.node_ips, self.current_node_ip, self.node_rank))
if not self.distribute_mode == DistributeMode.PS_HETER:
assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
% (self.current_node_ip, self.node_ips)
if self.current_node_ip in self.node_ips:
self.node_rank = self.node_ips.index(self.current_node_ip)
logger.debug(
"parsed from args: node_ips:{} current_node_ip:{} node_rank:{}".
format(self.node_ips, self.current_node_ip, self.node_rank))
def start_ps(self):
if not self.current_node_ip in self.node_ips:
return
cluster = Cluster(hdfs=None)
server_rank = 0
worker_rank = 0
......
......@@ -1737,12 +1737,8 @@ class Executor(object):
for var in program.global_block().vars.values():
if var.is_data:
data_vars.append(var)
if core.is_compiled_with_npu():
dataset = paddle.fluid.DatasetFactory().create_dataset(
'InMemoryDataset')
else:
dataset = paddle.fluid.DatasetFactory().create_dataset(
'FileInstantDataset')
dataset = paddle.fluid.DatasetFactory().create_dataset(
'InMemoryDataset')
dataset.set_batch_size(1)
dataset.set_thread(1)
dataset.set_filelist(['None'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册