未验证 提交 e31c96ac 编写于 作者: C chenjian 提交者: GitHub

Fix the bug that the child thread crashes causing the main thread to deadlock (#1013)

上级 4943b793
......@@ -59,7 +59,8 @@ class FileFactory(object):
if not HDFS_ENABLED:
raise RuntimeError('Please install module named "hdfs".')
try:
default_file_factory.register_filesystem("hdfs", HDFileSystem())
default_file_factory.register_filesystem(
"hdfs", HDFileSystem())
except hdfs.util.HdfsError:
raise RuntimeError(
"Please initialize `~/.hdfscli.cfg` for HDFS.")
......@@ -182,8 +183,9 @@ class HDFileSystem(object):
encoding = None if binary_mode else "utf-8"
try:
with self.cli.read(hdfs_path=filename[7:], offset=offset,
encoding=encoding) as reader:
with self.cli.read(
hdfs_path=filename[7:], offset=offset,
encoding=encoding) as reader:
data = reader.read()
continue_from_token = {"last_offset": offset + len(data)}
return data, continue_from_token
......@@ -214,7 +216,8 @@ class BosConfigClient(object):
def __init__(self, bos_ak, bos_sk, bos_sts, bos_host="bj.bcebos.com"):
self.config = BceClientConfiguration(
credentials=BceCredentials(bos_ak, bos_sk),
endpoint=bos_host, security_token=bos_sts)
endpoint=bos_host,
security_token=bos_sts)
self.bos_client = BosClient(self.config)
def exists(self, path):
......@@ -234,11 +237,12 @@ class BosConfigClient(object):
if not object_key.endswith('/'):
object_key += '/'
init_data = b''
self.bos_client.append_object(bucket_name=bucket_name,
key=object_key,
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
self.bos_client.append_object(
bucket_name=bucket_name,
key=object_key,
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
@staticmethod
def join(path, *paths):
......@@ -255,9 +259,8 @@ class BosConfigClient(object):
# if not object_key.endswith('/'):
# object_key += '/'
print('Uploading file `%s`' % filename)
self.bos_client.put_object_from_file(bucket=bucket_name,
key=object_key,
file_name=filename)
self.bos_client.put_object_from_file(
bucket=bucket_name, key=object_key, file_name=filename)
class BosFileSystem(object):
......@@ -288,14 +291,36 @@ class BosFileSystem(object):
bos_sts = os.getenv("BOS_STS")
self.config = BceClientConfiguration(
credentials=BceCredentials(access_key_id, secret_access_key),
endpoint=bos_host, security_token=bos_sts)
endpoint=bos_host,
security_token=bos_sts)
def set_bos_config(self, bos_ak, bos_sk, bos_sts, bos_host="bj.bcebos.com"):
def set_bos_config(self, bos_ak, bos_sk, bos_sts,
bos_host="bj.bcebos.com"):
self.config = BceClientConfiguration(
credentials=BceCredentials(bos_ak, bos_sk),
endpoint=bos_host, security_token=bos_sts)
endpoint=bos_host,
security_token=bos_sts)
self.bos_client = BosClient(self.config)
def renew_bos_client_from_server(self):
import requests
import json
from visualdl.utils.dir import CONFIG_PATH
with open(CONFIG_PATH, 'r') as fp:
server_url = json.load(fp)['server_url']
url = server_url + '/sts/'
res = requests.post(url=url).json()
err_code = res.get('code')
msg = res.get('msg')
if '000000' == err_code:
sts_ak = msg.get('sts_ak')
sts_sk = msg.get('sts_sk')
sts_token = msg.get('token')
self.set_bos_config(sts_ak, sts_sk, sts_token)
else:
print('Renew bos client error. Error msg: {}'.format(msg))
return
def isfile(self, filename):
return exists(filename)
......@@ -324,11 +349,12 @@ class BosFileSystem(object):
if not object_key.endswith('/'):
object_key += '/'
init_data = b''
self.bos_client.append_object(bucket_name=bucket_name,
key=object_key,
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
self.bos_client.append_object(
bucket_name=bucket_name,
key=object_key,
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
@staticmethod
def join(path, *paths):
......@@ -344,10 +370,10 @@ class BosFileSystem(object):
length = int(
self.get_meta(bucket_name, object_key).metadata.content_length)
if offset < length:
data = self.bos_client.get_object_as_string(bucket_name=bucket_name,
key=object_key,
range=[offset,
length - 1])
data = self.bos_client.get_object_as_string(
bucket_name=bucket_name,
key=object_key,
range=[offset, length - 1])
else:
data = b''
......@@ -371,29 +397,45 @@ class BosFileSystem(object):
bucket_name, object_key = get_object_info(filename)
if not self.exists(filename):
init_data = b''
self.bos_client.append_object(bucket_name=bucket_name,
key=object_key,
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
try:
self.bos_client.append_object(
bucket_name=bucket_name,
key=object_key,
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
except (exception.BceServerError, exception.BceHttpClientError):
self.renew_bos_client_from_server()
self.bos_client.append_object(
bucket_name=bucket_name,
key=object_key,
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
return
content_length = len(file_content)
try:
offset = self.get_meta(bucket_name,
object_key).metadata.content_length
self.bos_client.append_object(bucket_name=bucket_name,
key=object_key,
data=file_content,
content_md5=content_md5(file_content),
content_length=content_length,
offset=offset)
self.bos_client.append_object(
bucket_name=bucket_name,
key=object_key,
data=file_content,
content_md5=content_md5(file_content),
content_length=content_length,
offset=offset)
except (exception.BceServerError, exception.BceHttpClientError):
init_data = b''
self.bos_client.append_object(bucket_name=bucket_name,
key=object_key,
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
self.renew_bos_client_from_server()
offset = self.get_meta(bucket_name,
object_key).metadata.content_length
self.bos_client.append_object(
bucket_name=bucket_name,
key=object_key,
data=file_content,
content_md5=content_md5(file_content),
content_length=content_length,
offset=offset)
self._file_contents_to_add = b''
self._file_contents_count = 0
......@@ -435,9 +477,10 @@ class BosFileSystem(object):
contents_map[key] = [value]
temp_walk = []
for key, value in contents_map.items():
temp_walk.append(
[BosFileSystem.join('bos://' + self.bucket, key), [],
value])
temp_walk.append([
BosFileSystem.join('bos://' + self.bucket, key), [],
value
])
self.length = len(temp_walk)
self.contents = temp_walk
......@@ -458,8 +501,7 @@ class BosFileSystem(object):
else:
prefix = object_key if object_key.endswith(
'/') else object_key + '/'
response = self.bos_client.list_objects(bucket_name,
prefix=prefix)
response = self.bos_client.list_objects(bucket_name, prefix=prefix)
contents = [content.key for content in response.contents]
return WalkGenerator(bucket_name, contents)
......@@ -633,7 +675,8 @@ class BFile(object):
def close(self):
if isinstance(self.fs, BosFileSystem):
try:
self.fs.append(self._filename, b'', self.binary_mode, force=True)
self.fs.append(
self._filename, b'', self.binary_mode, force=True)
except Exception:
pass
self.flush()
......
......@@ -30,6 +30,7 @@ if isinstance(QUEUE_TIMEOUT, str):
class RecordWriter(object):
"""Package data with crc32 or not.
"""
def __init__(self, writer):
self._writer = writer
......@@ -77,8 +78,13 @@ class RecordFileWriter(object):
directory and asynchronously writes `Record` protocol buffers to this
file.
"""
def __init__(self, logdir, max_queue_size=10, flush_secs=120,
filename_suffix='', filename=''):
def __init__(self,
logdir,
max_queue_size=10,
flush_secs=120,
filename_suffix='',
filename=''):
self._logdir = logdir
if not bfile.exists(logdir):
bfile.makedirs(logdir)
......@@ -93,16 +99,19 @@ class RecordFileWriter(object):
else:
fn = "vdlrecords.%010d.log%s" % (time.time(), filename_suffix)
self._file_name = bfile.join(logdir, fn)
print(
'Since the log filename should contain `vdlrecords`, the filename is invalid and `{}` will replace `{}`'.format( # noqa: E501
fn, filename))
print('Since the log filename should contain `vdlrecords`, '
'the filename is invalid and `{}` will replace `{}`'.
format( # noqa: E501
fn, filename))
else:
self._file_name = bfile.join(logdir, "vdlrecords.%010d.log%s" % (
time.time(), filename_suffix))
self._file_name = bfile.join(
logdir,
"vdlrecords.%010d.log%s" % (time.time(), filename_suffix))
self._general_file_writer = bfile.BFile(self._file_name, "wb")
self._async_writer = _AsyncWriter(RecordWriter(
self._general_file_writer), max_queue_size, flush_secs)
self._async_writer = _AsyncWriter(
RecordWriter(self._general_file_writer), max_queue_size,
flush_secs)
# TODO(shenyuhan) Maybe file_version in future.
# _record = record_pb2.Record()
# self.add_record(_record)
......@@ -140,8 +149,7 @@ class _AsyncWriter(object):
self._closed = False
self._bytes_queue = queue.Queue(max_queue_size)
self._worker = _AsyncWriterThread(self._bytes_queue,
self._record_writer,
flush_secs)
self._record_writer, flush_secs)
self._lock = threading.Lock()
self._worker.start()
......@@ -188,6 +196,7 @@ class _AsyncWriterThread(threading.Thread):
self.join()
def run(self):
has_unresolved_bug = False
while True:
now = time.time()
queue_wait_duration = self._next_flush_time - now
......@@ -205,6 +214,14 @@ class _AsyncWriterThread(threading.Thread):
self._has_pending_data = True
except queue.Empty:
pass
except Exception as e:
# prevent the main thread from deadlock due to writing error.
if not has_unresolved_bug:
print('Warning: Writing data Error, Due to unresolved Exception {}'.format(e))
print('Warning: Writing data to FileSystem failed since {}.'.format(
time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime())))
has_unresolved_bug = True
pass
finally:
if data:
self._queue.task_done()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册