提交 8c2c8dc6 编写于 作者: W WangXi 提交者: gongweibao

distribute.launch use poll to query subprocess (#19853)

distribute.launch use poll to query subprocess
上级 8e927327
...@@ -36,16 +36,25 @@ launch a process on each of the given gpu card. ...@@ -36,16 +36,25 @@ launch a process on each of the given gpu card.
""" """
from __future__ import print_function from __future__ import print_function
import logging
import sys import sys
from sys import version from sys import version
import subprocess import subprocess
import os import os
import warnings import time
import six import six
import copy import copy
from argparse import ArgumentParser, REMAINDER from argparse import ArgumentParser, REMAINDER
import paddle.fluid as fluid import paddle.fluid as fluid
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s')
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
def _print_arguments(args): def _print_arguments(args):
print("----------- Configuration Arguments -----------") print("----------- Configuration Arguments -----------")
...@@ -129,6 +138,12 @@ POD_IP (current node ip address, not needed for local training) ...@@ -129,6 +138,12 @@ POD_IP (current node ip address, not needed for local training)
return parser.parse_args() return parser.parse_args()
def terminate_procs(procs):
for p in procs:
if p.poll() is None:
p.terminate()
def start_procs(args): def start_procs(args):
""" """
""" """
...@@ -154,14 +169,14 @@ def start_procs(args): ...@@ -154,14 +169,14 @@ def start_procs(args):
node_id = int(node_id) node_id = int(node_id)
if args.node_ip != "127.0.0.1" and current_node_ip != args.node_ip: if args.node_ip != "127.0.0.1" and current_node_ip != args.node_ip:
warnings.warn( logger.warning(
"Please NOTE: When using paddlecloud, current_node_ip is \ "Please NOTE: When using paddlecloud, current_node_ip is \
automatically got from POD_IP. Your input node_ip: {} doesn't equals to \ automatically got from POD_IP. Your input node_ip: {} doesn't equals to \
current_node_ip: {} from paddlecloud environment." current_node_ip: {} from paddlecloud environment."
.format(args.node_ip, current_node_ip)) .format(args.node_ip, current_node_ip))
if args.cluster_node_ips != "127.0.0.1" and args.cluster_node_ips != ",".join( if args.cluster_node_ips != "127.0.0.1" and args.cluster_node_ips != ",".join(
node_ips): node_ips):
warnings.warn( logger.warning(
"Please NOTE: When using paddlecloud, cluster_node_ips is \ "Please NOTE: When using paddlecloud, cluster_node_ips is \
automatically got from PADDLE_TRAINERS(multi nodes) or POD_IP(single node).\ automatically got from PADDLE_TRAINERS(multi nodes) or POD_IP(single node).\
Your input cluster_node_ips: {} doesn't equals to IPs: {} from \ Your input cluster_node_ips: {} doesn't equals to IPs: {} from \
...@@ -228,16 +243,39 @@ paddlecloud environment.".format(args.cluster_node_ips, node_ips)) ...@@ -228,16 +243,39 @@ paddlecloud environment.".format(args.cluster_node_ips, node_ips))
procs.append(proc) procs.append(proc)
for i in range(0, len(procs)): try:
proc = procs[i] alive = True
error = False
proc.wait() # wait all process finish or one error
if len(log_fns) > 0: while alive and not error:
log_fns[i].close() alive = False
for p in procs:
if proc.returncode != 0: ret = p.poll()
raise subprocess.CalledProcessError( if ret is None:
returncode=procs[i].returncode, cmd=cmds[i]) alive = True
elif ret != 0:
error = True
time.sleep(1)
if error:
terminate_procs(procs)
exit(1)
except KeyboardInterrupt:
logger.warning("KeyboardInterrupt, exit")
terminate_procs(procs)
raise
except SystemExit:
logger.error("One trainer process abort, exit")
terminate_procs(procs)
raise
except:
logger.error("Trainer process abort, exit")
terminate_procs(procs)
raise
finally:
for fn in log_fns:
fn.close()
def launch(): def launch():
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import os import os
import sys
import time
def train(): def train():
...@@ -31,5 +33,39 @@ def train(): ...@@ -31,5 +33,39 @@ def train():
f.write(name) f.write(name)
def train_abort():
selected_gpus = os.getenv("FLAGS_selected_gpus")
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
worker_endpoints = worker_endpoints_env
trainers_num = len(worker_endpoints.split(','))
if trainer_id == 0:
try:
# train abort
exit(1)
except SystemExit:
name = "abort>>> selected_gpus:{} worker_endpoints:{} trainers_num:{} current_endpoint:{} trainer_id:{}"\
.format(selected_gpus, worker_endpoints, trainers_num, current_endpoint,trainer_id)
print(name)
with open("multi_process.check_{}.log".format(trainer_id),
"w") as f:
f.write(name)
raise
else:
# sleep 30s to make sure paddle.distributed.launch will terminate this process
time.sleep(30)
name = "selected_gpus:{} worker_endpoints:{} trainers_num:{} current_endpoint:{} trainer_id:{}"\
.format(selected_gpus, worker_endpoints, trainers_num, current_endpoint,trainer_id)
print(name)
with open("multi_process.check_{}.log".format(trainer_id), "w") as f:
f.write(name)
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) == 2 and sys.argv[1] == "abort":
train_abort()
else:
train() train()
#!/bin/bash #!/bin/bash
set -ex set -e
# use default values # use default values
python -m paddle.distributed.launch multi_process.py python -m paddle.distributed.launch multi_process.py
...@@ -33,3 +33,33 @@ else ...@@ -33,3 +33,33 @@ else
echo "not find trainer 1" echo "not find trainer 1"
exit -1 exit -1
fi fi
# test async poll process
if [ -f $file_0 ]; then
rm $file_0
fi
if [ -f $file_1 ]; then
rm $file_1
fi
echo ""
echo "paddle.distributed.launch async poll process test"
if ! python -m paddle.distributed.launch ${distributed_args} multi_process.py abort; then
echo "train abort as planned"
fi
abort_str1="abort>>> selected_gpus:0 worker_endpoints:127.0.0.1:6170,127.0.0.1:6171,127.0.0.2:6170,127.0.0.2:6171 trainers_num:4 current_endpoint:127.0.0.1:6170 trainer_id:0"
if grep -q "$abort_str1" "$file_0"; then
echo "trainer 0 abort as planned"
else
echo "trainer 0 not abort as planned"
exit -1
fi
if [ ! -f $file_1 ]; then
echo "trainer 1 terminate as planned"
else
echo "trainer 1 not terminate as planned"
exit -1
fi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册