Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
VisualDL
提交
e31c96ac
V
VisualDL
项目概览
PaddlePaddle
/
VisualDL
大约 1 年 前同步成功
通知
88
Star
4655
Fork
642
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
2
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
V
VisualDL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
2
合并请求
2
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e31c96ac
编写于
11月 02, 2021
作者:
C
chenjian
提交者:
GitHub
11月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the bug that the child thread crashes causing the main thread to deadlock (#1013)
上级
4943b793
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
118 addition
and
58 deletion
+118
-58
visualdl/io/bfile.py
visualdl/io/bfile.py
+90
-47
visualdl/writer/record_writer.py
visualdl/writer/record_writer.py
+28
-11
未找到文件。
visualdl/io/bfile.py
浏览文件 @
e31c96ac
...
...
@@ -59,7 +59,8 @@ class FileFactory(object):
if
not
HDFS_ENABLED
:
raise
RuntimeError
(
'Please install module named "hdfs".'
)
try
:
default_file_factory
.
register_filesystem
(
"hdfs"
,
HDFileSystem
())
default_file_factory
.
register_filesystem
(
"hdfs"
,
HDFileSystem
())
except
hdfs
.
util
.
HdfsError
:
raise
RuntimeError
(
"Please initialize `~/.hdfscli.cfg` for HDFS."
)
...
...
@@ -182,8 +183,9 @@ class HDFileSystem(object):
encoding
=
None
if
binary_mode
else
"utf-8"
try
:
with
self
.
cli
.
read
(
hdfs_path
=
filename
[
7
:],
offset
=
offset
,
encoding
=
encoding
)
as
reader
:
with
self
.
cli
.
read
(
hdfs_path
=
filename
[
7
:],
offset
=
offset
,
encoding
=
encoding
)
as
reader
:
data
=
reader
.
read
()
continue_from_token
=
{
"last_offset"
:
offset
+
len
(
data
)}
return
data
,
continue_from_token
...
...
@@ -214,7 +216,8 @@ class BosConfigClient(object):
def
__init__
(
self
,
bos_ak
,
bos_sk
,
bos_sts
,
bos_host
=
"bj.bcebos.com"
):
self
.
config
=
BceClientConfiguration
(
credentials
=
BceCredentials
(
bos_ak
,
bos_sk
),
endpoint
=
bos_host
,
security_token
=
bos_sts
)
endpoint
=
bos_host
,
security_token
=
bos_sts
)
self
.
bos_client
=
BosClient
(
self
.
config
)
def
exists
(
self
,
path
):
...
...
@@ -234,11 +237,12 @@ class BosConfigClient(object):
if
not
object_key
.
endswith
(
'/'
):
object_key
+=
'/'
init_data
=
b
''
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
init_data
,
content_md5
=
content_md5
(
init_data
),
content_length
=
len
(
init_data
))
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
init_data
,
content_md5
=
content_md5
(
init_data
),
content_length
=
len
(
init_data
))
@
staticmethod
def
join
(
path
,
*
paths
):
...
...
@@ -255,9 +259,8 @@ class BosConfigClient(object):
# if not object_key.endswith('/'):
# object_key += '/'
print
(
'Uploading file `%s`'
%
filename
)
self
.
bos_client
.
put_object_from_file
(
bucket
=
bucket_name
,
key
=
object_key
,
file_name
=
filename
)
self
.
bos_client
.
put_object_from_file
(
bucket
=
bucket_name
,
key
=
object_key
,
file_name
=
filename
)
class
BosFileSystem
(
object
):
...
...
@@ -288,14 +291,36 @@ class BosFileSystem(object):
bos_sts
=
os
.
getenv
(
"BOS_STS"
)
self
.
config
=
BceClientConfiguration
(
credentials
=
BceCredentials
(
access_key_id
,
secret_access_key
),
endpoint
=
bos_host
,
security_token
=
bos_sts
)
endpoint
=
bos_host
,
security_token
=
bos_sts
)
def
set_bos_config
(
self
,
bos_ak
,
bos_sk
,
bos_sts
,
bos_host
=
"bj.bcebos.com"
):
def
set_bos_config
(
self
,
bos_ak
,
bos_sk
,
bos_sts
,
bos_host
=
"bj.bcebos.com"
):
self
.
config
=
BceClientConfiguration
(
credentials
=
BceCredentials
(
bos_ak
,
bos_sk
),
endpoint
=
bos_host
,
security_token
=
bos_sts
)
endpoint
=
bos_host
,
security_token
=
bos_sts
)
self
.
bos_client
=
BosClient
(
self
.
config
)
def
renew_bos_client_from_server
(
self
):
import
requests
import
json
from
visualdl.utils.dir
import
CONFIG_PATH
with
open
(
CONFIG_PATH
,
'r'
)
as
fp
:
server_url
=
json
.
load
(
fp
)[
'server_url'
]
url
=
server_url
+
'/sts/'
res
=
requests
.
post
(
url
=
url
).
json
()
err_code
=
res
.
get
(
'code'
)
msg
=
res
.
get
(
'msg'
)
if
'000000'
==
err_code
:
sts_ak
=
msg
.
get
(
'sts_ak'
)
sts_sk
=
msg
.
get
(
'sts_sk'
)
sts_token
=
msg
.
get
(
'token'
)
self
.
set_bos_config
(
sts_ak
,
sts_sk
,
sts_token
)
else
:
print
(
'Renew bos client error. Error msg: {}'
.
format
(
msg
))
return
def
isfile
(
self
,
filename
):
return
exists
(
filename
)
...
...
@@ -324,11 +349,12 @@ class BosFileSystem(object):
if
not
object_key
.
endswith
(
'/'
):
object_key
+=
'/'
init_data
=
b
''
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
init_data
,
content_md5
=
content_md5
(
init_data
),
content_length
=
len
(
init_data
))
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
init_data
,
content_md5
=
content_md5
(
init_data
),
content_length
=
len
(
init_data
))
@
staticmethod
def
join
(
path
,
*
paths
):
...
...
@@ -344,10 +370,10 @@ class BosFileSystem(object):
length
=
int
(
self
.
get_meta
(
bucket_name
,
object_key
).
metadata
.
content_length
)
if
offset
<
length
:
data
=
self
.
bos_client
.
get_object_as_string
(
bucket_name
=
bucket_name
,
key
=
object_key
,
range
=
[
offset
,
length
-
1
])
data
=
self
.
bos_client
.
get_object_as_string
(
bucket_name
=
bucket_name
,
key
=
object_key
,
range
=
[
offset
,
length
-
1
])
else
:
data
=
b
''
...
...
@@ -371,29 +397,45 @@ class BosFileSystem(object):
bucket_name
,
object_key
=
get_object_info
(
filename
)
if
not
self
.
exists
(
filename
):
init_data
=
b
''
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
init_data
,
content_md5
=
content_md5
(
init_data
),
content_length
=
len
(
init_data
))
try
:
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
init_data
,
content_md5
=
content_md5
(
init_data
),
content_length
=
len
(
init_data
))
except
(
exception
.
BceServerError
,
exception
.
BceHttpClientError
):
self
.
renew_bos_client_from_server
()
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
init_data
,
content_md5
=
content_md5
(
init_data
),
content_length
=
len
(
init_data
))
return
content_length
=
len
(
file_content
)
try
:
offset
=
self
.
get_meta
(
bucket_name
,
object_key
).
metadata
.
content_length
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
file_content
,
content_md5
=
content_md5
(
file_content
),
content_length
=
content_length
,
offset
=
offset
)
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
file_content
,
content_md5
=
content_md5
(
file_content
),
content_length
=
content_length
,
offset
=
offset
)
except
(
exception
.
BceServerError
,
exception
.
BceHttpClientError
):
init_data
=
b
''
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
init_data
,
content_md5
=
content_md5
(
init_data
),
content_length
=
len
(
init_data
))
self
.
renew_bos_client_from_server
()
offset
=
self
.
get_meta
(
bucket_name
,
object_key
).
metadata
.
content_length
self
.
bos_client
.
append_object
(
bucket_name
=
bucket_name
,
key
=
object_key
,
data
=
file_content
,
content_md5
=
content_md5
(
file_content
),
content_length
=
content_length
,
offset
=
offset
)
self
.
_file_contents_to_add
=
b
''
self
.
_file_contents_count
=
0
...
...
@@ -435,9 +477,10 @@ class BosFileSystem(object):
contents_map
[
key
]
=
[
value
]
temp_walk
=
[]
for
key
,
value
in
contents_map
.
items
():
temp_walk
.
append
(
[
BosFileSystem
.
join
(
'bos://'
+
self
.
bucket
,
key
),
[],
value
])
temp_walk
.
append
([
BosFileSystem
.
join
(
'bos://'
+
self
.
bucket
,
key
),
[],
value
])
self
.
length
=
len
(
temp_walk
)
self
.
contents
=
temp_walk
...
...
@@ -458,8 +501,7 @@ class BosFileSystem(object):
else
:
prefix
=
object_key
if
object_key
.
endswith
(
'/'
)
else
object_key
+
'/'
response
=
self
.
bos_client
.
list_objects
(
bucket_name
,
prefix
=
prefix
)
response
=
self
.
bos_client
.
list_objects
(
bucket_name
,
prefix
=
prefix
)
contents
=
[
content
.
key
for
content
in
response
.
contents
]
return
WalkGenerator
(
bucket_name
,
contents
)
...
...
@@ -633,7 +675,8 @@ class BFile(object):
def
close
(
self
):
if
isinstance
(
self
.
fs
,
BosFileSystem
):
try
:
self
.
fs
.
append
(
self
.
_filename
,
b
''
,
self
.
binary_mode
,
force
=
True
)
self
.
fs
.
append
(
self
.
_filename
,
b
''
,
self
.
binary_mode
,
force
=
True
)
except
Exception
:
pass
self
.
flush
()
...
...
visualdl/writer/record_writer.py
浏览文件 @
e31c96ac
...
...
@@ -30,6 +30,7 @@ if isinstance(QUEUE_TIMEOUT, str):
class
RecordWriter
(
object
):
"""Package data with crc32 or not.
"""
def
__init__
(
self
,
writer
):
self
.
_writer
=
writer
...
...
@@ -77,8 +78,13 @@ class RecordFileWriter(object):
directory and asynchronously writes `Record` protocol buffers to this
file.
"""
def
__init__
(
self
,
logdir
,
max_queue_size
=
10
,
flush_secs
=
120
,
filename_suffix
=
''
,
filename
=
''
):
def
__init__
(
self
,
logdir
,
max_queue_size
=
10
,
flush_secs
=
120
,
filename_suffix
=
''
,
filename
=
''
):
self
.
_logdir
=
logdir
if
not
bfile
.
exists
(
logdir
):
bfile
.
makedirs
(
logdir
)
...
...
@@ -93,16 +99,19 @@ class RecordFileWriter(object):
else
:
fn
=
"vdlrecords.%010d.log%s"
%
(
time
.
time
(),
filename_suffix
)
self
.
_file_name
=
bfile
.
join
(
logdir
,
fn
)
print
(
'Since the log filename should contain `vdlrecords`, the filename is invalid and `{}` will replace `{}`'
.
format
(
# noqa: E501
fn
,
filename
))
print
(
'Since the log filename should contain `vdlrecords`, '
'the filename is invalid and `{}` will replace `{}`'
.
format
(
# noqa: E501
fn
,
filename
))
else
:
self
.
_file_name
=
bfile
.
join
(
logdir
,
"vdlrecords.%010d.log%s"
%
(
time
.
time
(),
filename_suffix
))
self
.
_file_name
=
bfile
.
join
(
logdir
,
"vdlrecords.%010d.log%s"
%
(
time
.
time
(),
filename_suffix
))
self
.
_general_file_writer
=
bfile
.
BFile
(
self
.
_file_name
,
"wb"
)
self
.
_async_writer
=
_AsyncWriter
(
RecordWriter
(
self
.
_general_file_writer
),
max_queue_size
,
flush_secs
)
self
.
_async_writer
=
_AsyncWriter
(
RecordWriter
(
self
.
_general_file_writer
),
max_queue_size
,
flush_secs
)
# TODO(shenyuhan) Maybe file_version in future.
# _record = record_pb2.Record()
# self.add_record(_record)
...
...
@@ -140,8 +149,7 @@ class _AsyncWriter(object):
self
.
_closed
=
False
self
.
_bytes_queue
=
queue
.
Queue
(
max_queue_size
)
self
.
_worker
=
_AsyncWriterThread
(
self
.
_bytes_queue
,
self
.
_record_writer
,
flush_secs
)
self
.
_record_writer
,
flush_secs
)
self
.
_lock
=
threading
.
Lock
()
self
.
_worker
.
start
()
...
...
@@ -188,6 +196,7 @@ class _AsyncWriterThread(threading.Thread):
self
.
join
()
def
run
(
self
):
has_unresolved_bug
=
False
while
True
:
now
=
time
.
time
()
queue_wait_duration
=
self
.
_next_flush_time
-
now
...
...
@@ -205,6 +214,14 @@ class _AsyncWriterThread(threading.Thread):
self
.
_has_pending_data
=
True
except
queue
.
Empty
:
pass
except
Exception
as
e
:
# prevent the main thread from deadlock due to writing error.
if
not
has_unresolved_bug
:
print
(
'Warning: Writing data Error, Due to unresolved Exception {}'
.
format
(
e
))
print
(
'Warning: Writing data to FileSystem failed since {}.'
.
format
(
time
.
strftime
(
"%a, %d %b %Y %H:%M:%S +0000"
,
time
.
gmtime
())))
has_unresolved_bug
=
True
pass
finally
:
if
data
:
self
.
_queue
.
task_done
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录