Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
91a3b648
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
接近 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
91a3b648
编写于
2月 19, 2020
作者:
Y
Yibing Liu
提交者:
GitHub
2月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update pantheon in release 1.0.0 (#124)
上级
42cded62
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
32 addition
and
22 deletion
+32
-22
demo/pantheon/run_student.py
demo/pantheon/run_student.py
+2
-2
docs/zh_cn/api_cn/pantheon_api.md
docs/zh_cn/api_cn/pantheon_api.md
+3
-2
paddleslim/__init__.py
paddleslim/__init__.py
+2
-1
paddleslim/pantheon/README.md
paddleslim/pantheon/README.md
+1
-1
paddleslim/pantheon/student.py
paddleslim/pantheon/student.py
+5
-5
paddleslim/pantheon/teacher.py
paddleslim/pantheon/teacher.py
+19
-11
未找到文件。
demo/pantheon/run_student.py
浏览文件 @
91a3b648
...
...
@@ -80,8 +80,8 @@ def run(args):
student
.
start
()
if
args
.
test_send_recv
:
for
t
in
x
range
(
2
):
for
i
in
x
range
(
3
):
for
t
in
range
(
2
):
for
i
in
range
(
3
):
print
(
student
.
recv
(
t
))
student
.
send
(
"message from student!"
)
...
...
docs/zh_cn/api_cn/pantheon_api.md
浏览文件 @
91a3b648
#
多进程蒸馏
#
大规模可扩展知识蒸馏框架 Pantheon
## Teacher
...
...
@@ -100,7 +100,8 @@ pantheon.Teacher.start\_knowledge\_service(feed\_list, schema, program, reader\_
-
**times (int):**
The maximum repeated serving times, default 1. Whenever
the public method
**get\_knowledge\_generator()**
in
**Student**
object called once, the serving times will be added one,
until reaching the maximum and ending the service.
until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
**Return:**
None
...
...
paddleslim/__init__.py
浏览文件 @
91a3b648
...
...
@@ -19,4 +19,5 @@ from paddleslim import nas
from
paddleslim
import
analysis
from
paddleslim
import
dist
from
paddleslim
import
quant
__all__
=
[
'models'
,
'prune'
,
'nas'
,
'analysis'
,
'dist'
,
'quant'
]
from
paddleslim
import
pantheon
__all__
=
[
'models'
,
'prune'
,
'nas'
,
'analysis'
,
'dist'
,
'quant'
,
'pantheon'
]
paddleslim/pantheon/README.md
浏览文件 @
91a3b648
...
...
@@ -13,7 +13,7 @@ The illustration below shows an application of Pantheon, where the sudent model
## Prerequisites
-
Python 2.7.x or 3.x
-
PaddlePaddle >= 1.
6
.0
-
PaddlePaddle >= 1.
7
.0
## APIs
...
...
paddleslim/pantheon/student.py
浏览文件 @
91a3b648
...
...
@@ -158,7 +158,7 @@ class Student(object):
if
end_recved
:
break
with
open
(
in_path
,
'r'
)
as
fin
:
with
open
(
in_path
,
'r
b
'
)
as
fin
:
# get knowledge desc
desc
=
pickle
.
load
(
fin
)
out_queue
.
put
(
desc
)
...
...
@@ -222,7 +222,7 @@ class Student(object):
self
.
_started
=
True
def
_merge_knowledge
(
self
,
knowledge
):
for
k
,
tensors
in
knowledge
.
items
(
):
for
k
,
tensors
in
list
(
knowledge
.
items
()
):
if
len
(
tensors
)
==
0
:
del
knowledge
[
k
]
elif
len
(
tensors
)
==
1
:
...
...
@@ -308,7 +308,7 @@ class Student(object):
print
(
"Knowledge merging strategy: {}"
.
format
(
self
.
_merge_strategy
))
print
(
"Knowledge description after merging:"
)
for
schema
,
desc
in
knowledge_desc
.
items
(
):
for
schema
,
desc
in
list
(
knowledge_desc
.
items
()
):
print
(
"{}: {}"
.
format
(
schema
,
desc
))
self
.
_knowledge_desc
=
knowledge_desc
...
...
@@ -426,13 +426,13 @@ class Student(object):
end_received
=
[
0
]
*
len
(
queues
)
while
True
:
knowledge
=
OrderedDict
(
[(
k
,
[])
for
k
,
v
in
self
.
_knowledge_desc
.
items
(
)])
[(
k
,
[])
for
k
,
v
in
list
(
self
.
_knowledge_desc
.
items
()
)])
for
idx
,
receiver
in
enumerate
(
data_receivers
):
if
not
end_received
[
idx
]:
batch_samples
=
receiver
.
next
(
)
if
six
.
PY2
else
receiver
.
__next__
()
if
not
isinstance
(
batch_samples
,
EndSignal
):
for
k
,
v
in
batch_samples
.
items
(
):
for
k
,
v
in
list
(
batch_samples
.
items
()
):
knowledge
[
k
].
append
(
v
)
else
:
end_received
[
idx
]
=
1
...
...
paddleslim/pantheon/teacher.py
浏览文件 @
91a3b648
...
...
@@ -151,7 +151,7 @@ class Teacher(object):
self
.
_t2s_queue
=
None
self
.
_cmd_queue
=
None
self
.
_out_file
=
open
(
self
.
_out_path
,
"w"
)
if
self
.
_out_path
else
None
self
.
_out_file
=
open
(
self
.
_out_path
,
"w
b
"
)
if
self
.
_out_path
else
None
if
self
.
_out_file
:
return
...
...
@@ -231,7 +231,7 @@ class Teacher(object):
"The knowledge data should be a dict or OrderedDict!"
)
knowledge_desc
=
{}
for
name
,
value
in
knowledge
.
items
(
):
for
name
,
value
in
list
(
knowledge
.
items
()
):
knowledge_desc
[
name
]
=
{
"shape"
:
[
-
1
]
+
list
(
value
.
shape
[
1
:]),
"dtype"
:
str
(
value
.
dtype
),
...
...
@@ -294,7 +294,8 @@ class Teacher(object):
times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one,
until reaching the maximum and ending the service.
until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
"""
if
not
self
.
_started
:
raise
ValueError
(
"The method start() should be called first!"
)
...
...
@@ -339,9 +340,12 @@ class Teacher(object):
if
not
times
>
0
:
raise
ValueError
(
"Repeated serving times should be positive!"
)
self
.
_times
=
times
if
self
.
_times
>
1
and
self
.
_out_file
:
self
.
_times
=
1
print
(
"WARNING: args 'times' will be ignored in offline mode"
)
desc
=
{}
for
name
,
var
in
schema
.
items
(
):
for
name
,
var
in
list
(
schema
.
items
()
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Variable
):
raise
ValueError
(
"The member of schema must be fluid Variable."
)
...
...
@@ -412,10 +416,14 @@ class Teacher(object):
else
:
if
self
.
_knowledge_queue
:
self
.
_knowledge_queue
.
put
(
EndSignal
())
# should close file in child thread to wait for all
# writing finished
if
self
.
_out_file
:
self
.
_out_file
.
close
()
# Asynchronous output
out_buf_queue
=
Queue
.
Queue
(
self
.
_buf_size
)
schema_keys
,
schema_vars
=
zip
(
*
self
.
_schema
.
items
(
))
schema_keys
,
schema_vars
=
zip
(
*
list
(
self
.
_schema
.
items
()
))
out_thread
=
Thread
(
target
=
writer
,
args
=
(
out_buf_queue
,
schema_keys
))
out_thread
.
daemon
=
True
out_thread
.
start
()
...
...
@@ -424,8 +432,9 @@ class Teacher(object):
self
.
_program
).
with_data_parallel
()
print
(
"Knowledge description {}"
.
format
(
self
.
_knowledge_desc
))
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher begins to serve ..."
)
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher begins to serve ..."
)
# For offline dump, write the knowledge description to the head of file
if
self
.
_out_file
:
self
.
_out_file
.
write
(
pickle
.
dumps
(
self
.
_knowledge_desc
))
...
...
@@ -491,11 +500,10 @@ class Teacher(object):
if
self
.
_knowledge_queue
:
self
.
_knowledge_queue
.
join
()
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher ends serving."
)
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher ends serving."
)
def
__del__
(
self
):
if
self
.
_manager
:
self
.
_manager
.
shutdown
()
if
self
.
_out_file
:
self
.
_out_file
.
close
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录