diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index 53b02334cf9e02de1742f2da9cb25f4f27a2f18f..30a928cc9f86a842695bb90a7838d68ac291d9d8 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -21,7 +21,6 @@ import signal import copy import sys import subprocess -import threading from contextlib import closing import socket @@ -332,7 +331,9 @@ class TrainerProc(object): def __init__(self): self.proc = None self.log_fn = None + self.log_offset = None self.rank = None + self.local_rank = None self.cmd = None @@ -371,36 +372,16 @@ def start_local_trainers(cluster, if log_dir is not None: os.system("mkdir -p {}".format(log_dir)) fn = open("%s/workerlog.%d" % (log_dir, idx), "a") - if idx == 0: - proc = subprocess.Popen( - cmd, - env=current_env, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - - def shell_tee(proc, fn): - BUF_SIZE = 512 - while True: - buf = proc.stdout.read(BUF_SIZE) - if len(buf) == 0: - break - - sys.stdout.buffer.write(buf) - fn.buffer.write(buf) - sys.stdout.flush() - fn.flush() - - threading.Thread(target=shell_tee, args=(proc, fn)).start() - else: - proc = subprocess.Popen( - cmd, env=current_env, stdout=fn, stderr=fn) + proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn) else: proc = subprocess.Popen(cmd, env=current_env) tp = TrainerProc() tp.proc = proc tp.rank = t.rank + tp.local_rank = idx tp.log_fn = fn + tp.log_offset = 0 if fn else None tp.cmd = cmd procs.append(tp) @@ -408,6 +389,21 @@ def start_local_trainers(cluster, return procs +def pull_worker_log(tp): + if tp.log_fn: + with open(tp.log_fn.name, 'r') as fin: + fin.seek(tp.log_offset, 0) + for line in fin: + try: + sys.stdout.write(line) + except UnicodeEncodeError: + sys.stdout.write( + 'UnicodeEncodeError occurs at this line. ' + 'Please refer to the original log file "%s"\n' % + tp.log_fn.name) + tp.log_offset = fin.tell() + + def watch_local_trainers(procs, nranks): try: error = False @@ -415,6 +411,9 @@ def watch_local_trainers(procs, nranks): # wait all process finish or one error alive = False for p in procs: + if p.log_fn and p.local_rank == 0: + pull_worker_log(p) + ret = p.proc.poll() if ret is None: alive = True