Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
eac4f3b2
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eac4f3b2
编写于
2月 17, 2020
作者:
Y
Yibing Liu
提交者:
GitHub
2月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix offline file close in pantheon (#114)
上级
a4f4298d
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
22 addition
and
14 deletion
+22
-14
paddleslim/pantheon/student.py
paddleslim/pantheon/student.py
+4
-4
paddleslim/pantheon/teacher.py
paddleslim/pantheon/teacher.py
+18
-10
未找到文件。
paddleslim/pantheon/student.py
浏览文件 @
eac4f3b2
...
@@ -222,7 +222,7 @@ class Student(object):
...
@@ -222,7 +222,7 @@ class Student(object):
self
.
_started
=
True
self
.
_started
=
True
def
_merge_knowledge
(
self
,
knowledge
):
def
_merge_knowledge
(
self
,
knowledge
):
for
k
,
tensors
in
knowledge
.
items
(
):
for
k
,
tensors
in
list
(
knowledge
.
items
()
):
if
len
(
tensors
)
==
0
:
if
len
(
tensors
)
==
0
:
del
knowledge
[
k
]
del
knowledge
[
k
]
elif
len
(
tensors
)
==
1
:
elif
len
(
tensors
)
==
1
:
...
@@ -308,7 +308,7 @@ class Student(object):
...
@@ -308,7 +308,7 @@ class Student(object):
print
(
"Knowledge merging strategy: {}"
.
format
(
print
(
"Knowledge merging strategy: {}"
.
format
(
self
.
_merge_strategy
))
self
.
_merge_strategy
))
print
(
"Knowledge description after merging:"
)
print
(
"Knowledge description after merging:"
)
for
schema
,
desc
in
knowledge_desc
.
items
(
):
for
schema
,
desc
in
list
(
knowledge_desc
.
items
()
):
print
(
"{}: {}"
.
format
(
schema
,
desc
))
print
(
"{}: {}"
.
format
(
schema
,
desc
))
self
.
_knowledge_desc
=
knowledge_desc
self
.
_knowledge_desc
=
knowledge_desc
...
@@ -426,13 +426,13 @@ class Student(object):
...
@@ -426,13 +426,13 @@ class Student(object):
end_received
=
[
0
]
*
len
(
queues
)
end_received
=
[
0
]
*
len
(
queues
)
while
True
:
while
True
:
knowledge
=
OrderedDict
(
knowledge
=
OrderedDict
(
[(
k
,
[])
for
k
,
v
in
self
.
_knowledge_desc
.
items
(
)])
[(
k
,
[])
for
k
,
v
in
list
(
self
.
_knowledge_desc
.
items
()
)])
for
idx
,
receiver
in
enumerate
(
data_receivers
):
for
idx
,
receiver
in
enumerate
(
data_receivers
):
if
not
end_received
[
idx
]:
if
not
end_received
[
idx
]:
batch_samples
=
receiver
.
next
(
batch_samples
=
receiver
.
next
(
)
if
six
.
PY2
else
receiver
.
__next__
()
)
if
six
.
PY2
else
receiver
.
__next__
()
if
not
isinstance
(
batch_samples
,
EndSignal
):
if
not
isinstance
(
batch_samples
,
EndSignal
):
for
k
,
v
in
batch_samples
.
items
(
):
for
k
,
v
in
list
(
batch_samples
.
items
()
):
knowledge
[
k
].
append
(
v
)
knowledge
[
k
].
append
(
v
)
else
:
else
:
end_received
[
idx
]
=
1
end_received
[
idx
]
=
1
...
...
paddleslim/pantheon/teacher.py
浏览文件 @
eac4f3b2
...
@@ -231,7 +231,7 @@ class Teacher(object):
...
@@ -231,7 +231,7 @@ class Teacher(object):
"The knowledge data should be a dict or OrderedDict!"
)
"The knowledge data should be a dict or OrderedDict!"
)
knowledge_desc
=
{}
knowledge_desc
=
{}
for
name
,
value
in
knowledge
.
items
(
):
for
name
,
value
in
list
(
knowledge
.
items
()
):
knowledge_desc
[
name
]
=
{
knowledge_desc
[
name
]
=
{
"shape"
:
[
-
1
]
+
list
(
value
.
shape
[
1
:]),
"shape"
:
[
-
1
]
+
list
(
value
.
shape
[
1
:]),
"dtype"
:
str
(
value
.
dtype
),
"dtype"
:
str
(
value
.
dtype
),
...
@@ -294,7 +294,8 @@ class Teacher(object):
...
@@ -294,7 +294,8 @@ class Teacher(object):
times (int): The maximum repeated serving times. Default 1. Whenever
times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student
the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one,
object called once, the serving times will be added one,
until reaching the maximum and ending the service.
until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
"""
"""
if
not
self
.
_started
:
if
not
self
.
_started
:
raise
ValueError
(
"The method start() should be called first!"
)
raise
ValueError
(
"The method start() should be called first!"
)
...
@@ -339,9 +340,12 @@ class Teacher(object):
...
@@ -339,9 +340,12 @@ class Teacher(object):
if
not
times
>
0
:
if
not
times
>
0
:
raise
ValueError
(
"Repeated serving times should be positive!"
)
raise
ValueError
(
"Repeated serving times should be positive!"
)
self
.
_times
=
times
self
.
_times
=
times
if
self
.
_times
>
1
and
self
.
_out_file
:
self
.
_times
=
1
print
(
"WARNING: args 'times' will be ignored in offline mode"
)
desc
=
{}
desc
=
{}
for
name
,
var
in
schema
.
items
(
):
for
name
,
var
in
list
(
schema
.
items
()
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Variable
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Variable
):
raise
ValueError
(
raise
ValueError
(
"The member of schema must be fluid Variable."
)
"The member of schema must be fluid Variable."
)
...
@@ -412,10 +416,14 @@ class Teacher(object):
...
@@ -412,10 +416,14 @@ class Teacher(object):
else
:
else
:
if
self
.
_knowledge_queue
:
if
self
.
_knowledge_queue
:
self
.
_knowledge_queue
.
put
(
EndSignal
())
self
.
_knowledge_queue
.
put
(
EndSignal
())
# should close file in child thread to wait for all
# writing finished
if
self
.
_out_file
:
self
.
_out_file
.
close
()
# Asynchronous output
# Asynchronous output
out_buf_queue
=
Queue
.
Queue
(
self
.
_buf_size
)
out_buf_queue
=
Queue
.
Queue
(
self
.
_buf_size
)
schema_keys
,
schema_vars
=
zip
(
*
self
.
_schema
.
items
(
))
schema_keys
,
schema_vars
=
zip
(
*
list
(
self
.
_schema
.
items
()
))
out_thread
=
Thread
(
target
=
writer
,
args
=
(
out_buf_queue
,
schema_keys
))
out_thread
=
Thread
(
target
=
writer
,
args
=
(
out_buf_queue
,
schema_keys
))
out_thread
.
daemon
=
True
out_thread
.
daemon
=
True
out_thread
.
start
()
out_thread
.
start
()
...
@@ -424,7 +432,8 @@ class Teacher(object):
...
@@ -424,7 +432,8 @@ class Teacher(object):
self
.
_program
).
with_data_parallel
()
self
.
_program
).
with_data_parallel
()
print
(
"Knowledge description {}"
.
format
(
self
.
_knowledge_desc
))
print
(
"Knowledge description {}"
.
format
(
self
.
_knowledge_desc
))
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher begins to serve ..."
)
" Teacher begins to serve ..."
)
# For offline dump, write the knowledge description to the head of file
# For offline dump, write the knowledge description to the head of file
if
self
.
_out_file
:
if
self
.
_out_file
:
...
@@ -491,11 +500,10 @@ class Teacher(object):
...
@@ -491,11 +500,10 @@ class Teacher(object):
if
self
.
_knowledge_queue
:
if
self
.
_knowledge_queue
:
self
.
_knowledge_queue
.
join
()
self
.
_knowledge_queue
.
join
()
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher ends serving."
)
" Teacher ends serving."
)
def
__del__
(
self
):
def
__del__
(
self
):
if
self
.
_manager
:
if
self
.
_manager
:
self
.
_manager
.
shutdown
()
self
.
_manager
.
shutdown
()
if
self
.
_out_file
:
self
.
_out_file
.
close
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录