Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
98a4359f
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
98a4359f
编写于
4月 08, 2020
作者:
Y
Yibing Liu
提交者:
GitHub
4月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize knowledge transfer in pantheon (#210)
上级
77c64ef4
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
300 addition
and
68 deletion
+300
-68
demo/pantheon/run_teacher1.py
demo/pantheon/run_teacher1.py
+1
-0
paddleslim/pantheon/README.md
paddleslim/pantheon/README.md
+2
-0
paddleslim/pantheon/student.py
paddleslim/pantheon/student.py
+88
-11
paddleslim/pantheon/teacher.py
paddleslim/pantheon/teacher.py
+209
-57
未找到文件。
demo/pantheon/run_teacher1.py
浏览文件 @
98a4359f
...
@@ -72,6 +72,7 @@ def run(args):
...
@@ -72,6 +72,7 @@ def run(args):
program
=
program
,
program
=
program
,
reader_config
=
reader_config
,
reader_config
=
reader_config
,
exe
=
exe
,
exe
=
exe
,
use_fp16
=
True
,
times
=
args
.
serving_times
)
times
=
args
.
serving_times
)
...
...
paddleslim/pantheon/README.md
浏览文件 @
98a4359f
...
@@ -106,6 +106,7 @@ Usually, the public methods of these two classes work in the pairwise way. Their
...
@@ -106,6 +106,7 @@ Usually, the public methods of these two classes work in the pairwise way. Their
<br>
reader_config,
<br>
reader_config,
<br>
exe,
<br>
exe,
<br>
buf_size=10,
<br>
buf_size=10,
<br>
use_fp16=False,
<br>
times=1)
</td>
<br>
times=1)
</td>
<td><strong>
get_knowledge_desc
</strong>
()
</td>
<td><strong>
get_knowledge_desc
</strong>
()
</td>
<td><center>
✅
</center></td>
<td><center>
✅
</center></td>
...
@@ -213,6 +214,7 @@ The toy "knowledge distillation" system can be launched in three different modes
...
@@ -213,6 +214,7 @@ The toy "knowledge distillation" system can be launched in three different modes
```
shell
```
shell
export
PYTHONPATH
=
../../:
$PYTHONPATH
export
PYTHONPATH
=
../../:
$PYTHONPATH
export
CUDA_VISIBLE_DEVICES
=
0,1
export
CUDA_VISIBLE_DEVICES
=
0,1
export
NUM_POSTPROCESS_THREADS
=
10
# default 8
nohup
python
-u
run_teacher1.py
--use_cuda
true
--out_path
teacher1_offline.dat
>
teacher1_offline.log 2>&1&
nohup
python
-u
run_teacher1.py
--use_cuda
true
--out_path
teacher1_offline.dat
>
teacher1_offline.log 2>&1&
export
CUDA_VISIBLE_DEVICES
=
2
export
CUDA_VISIBLE_DEVICES
=
2
nohup
python
-u
run_teacher2.py
--use_cuda
true
--out_path
teacher2_offline.dat
>
teacher2_offline.log 2>&1&
nohup
python
-u
run_teacher2.py
--use_cuda
true
--out_path
teacher2_offline.dat
>
teacher2_offline.log 2>&1&
...
...
paddleslim/pantheon/student.py
浏览文件 @
98a4359f
...
@@ -28,7 +28,7 @@ from multiprocessing.managers import BaseManager
...
@@ -28,7 +28,7 @@ from multiprocessing.managers import BaseManager
from
threading
import
Thread
from
threading
import
Thread
from
paddleslim.pantheon.utils
import
EndSignal
,
SyncSignal
,
StartSignal
,
public_authkey
from
paddleslim.pantheon.utils
import
EndSignal
,
SyncSignal
,
StartSignal
,
public_authkey
,
convert_dtype
__all__
=
[
"Student"
]
__all__
=
[
"Student"
]
...
@@ -114,7 +114,60 @@ class Student(object):
...
@@ -114,7 +114,60 @@ class Student(object):
except
:
except
:
time
.
sleep
(
1.0
)
time
.
sleep
(
1.0
)
knowledge_queue
=
manager
.
get_knowledge_queue
()
def
merge
(
knowledge_queues
):
num
=
len
(
knowledge_queues
)
if
num
==
1
:
return
knowledge_queues
[
0
]
local_queues
=
[
Queue
.
Queue
(
100
)
for
_
in
range
(
num
)]
def
receive
(
queue
,
local_queue
):
while
True
:
data
=
queue
.
get
()
queue
.
task_done
()
local_queue
.
put
(
data
)
if
isinstance
(
data
,
EndSignal
):
break
knowledge_queue
=
Queue
.
Queue
(
100
)
def
gather
(
local_queues
,
knowledge_queue
):
num
=
len
(
local_queues
)
end_received
=
False
while
True
:
for
i
in
range
(
num
):
data
=
local_queues
[
i
].
get
()
local_queues
[
i
].
task_done
()
if
isinstance
(
data
,
SyncSignal
)
and
i
>
0
:
continue
elif
isinstance
(
data
,
EndSignal
):
end_received
=
True
knowledge_queue
.
put
(
data
)
if
end_received
:
break
# threads to receive knowledge from the online teacher
for
i
in
range
(
num
):
p
=
Thread
(
target
=
receive
,
args
=
(
knowledge_queues
[
i
],
local_queues
[
i
]))
p
.
daemon
=
True
p
.
start
()
# thread to gather data from different local queues
p
=
Thread
(
target
=
gather
,
args
=
(
local_queues
,
knowledge_queue
))
p
.
daemon
=
True
p
.
start
()
return
knowledge_queue
# get knowledge queues
knowledge_queues
,
idx
=
[],
0
while
True
:
q
=
manager
.
get_knowledge_queue
(
idx
)
if
hasattr
(
q
,
"get"
):
knowledge_queues
.
append
(
q
)
idx
+=
1
else
:
break
knowledge_queue
=
merge
(
knowledge_queues
)
self
.
_t2s_queues
.
append
(
manager
.
get_t2s_queue
())
self
.
_t2s_queues
.
append
(
manager
.
get_t2s_queue
())
self
.
_s2t_queues
.
append
(
manager
.
get_s2t_queue
())
self
.
_s2t_queues
.
append
(
manager
.
get_s2t_queue
())
self
.
_cmd_queues
.
append
(
manager
.
get_cmd_queue
())
self
.
_cmd_queues
.
append
(
manager
.
get_cmd_queue
())
...
@@ -237,6 +290,10 @@ class Student(object):
...
@@ -237,6 +290,10 @@ class Student(object):
knowledge
[
k
]
=
result
knowledge
[
k
]
=
result
elif
self
.
_merge_strategy
[
k
]
==
"mean"
:
elif
self
.
_merge_strategy
[
k
]
==
"mean"
:
knowledge
[
k
]
=
result
/
len
(
tensors
)
knowledge
[
k
]
=
result
/
len
(
tensors
)
# cast back to original data type if necessary
tgt_dtype
=
self
.
_knowledge_desc
[
k
][
"dtype"
]
if
str
(
knowledge
[
k
].
dtype
)
!=
tgt_dtype
:
knowledge
[
k
]
=
knowledge
[
k
].
astype
(
tgt_dtype
)
return
knowledge
return
knowledge
def
send
(
self
,
data
,
teacher_ids
=
None
):
def
send
(
self
,
data
,
teacher_ids
=
None
):
...
@@ -383,11 +440,23 @@ class Student(object):
...
@@ -383,11 +440,23 @@ class Student(object):
[
batches
[
i
][
key
]
for
i
in
range
(
len
(
batches
))])
[
batches
[
i
][
key
]
for
i
in
range
(
len
(
batches
))])
return
ret_batch
return
ret_batch
def
listen
(
in_queue
,
out_queue
,
batch_size
):
def
listen
(
knowledge_queue
,
out_queue
):
"""
listen on the knowledge queue for one teacher, get knowledge data
and put it into a local queue (out_queue).
"""
while
True
:
data
=
knowledge_queue
.
get
()
knowledge_queue
.
task_done
()
out_queue
.
put
(
data
)
if
isinstance
(
data
,
EndSignal
):
break
def
make_new_batch
(
in_queue
,
out_queue
,
batch_size
):
"""
"""
listen on the knowledge queue for one teacher, get knowledge
Get knowledge data from a local queue and make a new batch data in
data and make a new batch data in the batch size of student,
the batch size of student, then put it into the intermediate
then put it into the intermediate
queue (out_queue).
queue (out_queue).
"""
"""
batches
,
num_samples
=
[],
0
batches
,
num_samples
=
[],
0
while
True
:
while
True
:
...
@@ -467,17 +536,25 @@ class Student(object):
...
@@ -467,17 +536,25 @@ class Student(object):
queue
.
put
(
StartSignal
())
queue
.
put
(
StartSignal
())
queue
.
join
()
queue
.
join
()
# launch
multiple
threads to listen on all knowledge queues
# launch threads to listen on all knowledge queues
med
_queues
=
[
Queue
.
Queue
(
100
)
for
i
in
range
(
self
.
_num_teachers
)]
local
_queues
=
[
Queue
.
Queue
(
100
)
for
i
in
range
(
self
.
_num_teachers
)]
for
i
in
range
(
self
.
_num_teachers
):
for
i
in
range
(
self
.
_num_teachers
):
listen_thread
=
Thread
(
listen_thread
=
Thread
(
target
=
listen
,
target
=
listen
,
args
=
(
self
.
_teacher_knowledge_queues
[
i
],
med_queues
[
i
],
args
=
(
self
.
_teacher_knowledge_queues
[
i
],
local_queues
[
i
]))
self
.
_batch_size
))
listen_thread
.
dameon
=
True
listen_thread
.
start
()
# launch threads to make new batch for student
med_queues
=
[
Queue
.
Queue
(
100
)
for
i
in
range
(
self
.
_num_teachers
)]
for
i
in
range
(
self
.
_num_teachers
):
listen_thread
=
Thread
(
target
=
make_new_batch
,
args
=
(
local_queues
[
i
],
med_queues
[
i
],
self
.
_batch_size
))
listen_thread
.
dameon
=
True
listen_thread
.
dameon
=
True
listen_thread
.
start
()
listen_thread
.
start
()
# launch another thread to merge knowledge
# launch another thread to merge knowledge
from different teachers.
merge_thread
=
Thread
(
merge_thread
=
Thread
(
target
=
gather_and_merge
,
args
=
(
med_queues
,
self
.
_knowledge_queue
))
target
=
gather_and_merge
,
args
=
(
med_queues
,
self
.
_knowledge_queue
))
merge_thread
.
dameon
=
True
merge_thread
.
dameon
=
True
...
...
paddleslim/pantheon/teacher.py
浏览文件 @
98a4359f
...
@@ -35,7 +35,11 @@ from paddleslim.pantheon.utils import convert_dtype, EndSignal, SyncSignal, Star
...
@@ -35,7 +35,11 @@ from paddleslim.pantheon.utils import convert_dtype, EndSignal, SyncSignal, Star
__all__
=
[
"Teacher"
]
__all__
=
[
"Teacher"
]
knowledge_queue
=
Queue
.
Queue
(
100
)
# Num of threads for post-processing, including generating and transferring
# knowledge data
num_postprocess_threads
=
int
(
os
.
getenv
(
"NUM_POSTPROCESS_THREADS"
,
8
))
knowledge_queues
=
[
Queue
.
Queue
(
100
)
for
i
in
range
(
num_postprocess_threads
)]
t2s_queue
=
Queue
.
Queue
(
100
)
t2s_queue
=
Queue
.
Queue
(
100
)
s2t_queue
=
Queue
.
Queue
(
100
)
s2t_queue
=
Queue
.
Queue
(
100
)
cmd_queue
=
Queue
.
Queue
(
5
)
cmd_queue
=
Queue
.
Queue
(
5
)
...
@@ -75,6 +79,84 @@ class MixedDataReader(object):
...
@@ -75,6 +79,84 @@ class MixedDataReader(object):
self
.
_tail_data
=
[]
self
.
_tail_data
=
[]
class
WorkerParallel
(
object
):
"""
Process data from the input queue by given worker in parallel, and put the
result into output queue in order.
Args:
num_postprocess_threads (int): Number of threads for data processing.
in_queue (object): The input queue.
out_queue (object|list): The output queue(s). Its length should be equal
to arg 'num_postprocess_threads' when it is a list.
"""
def
__init__
(
self
,
num_postprocess_threads
,
in_queue
,
out_queue
):
self
.
_num_postprocess_threads
=
num_postprocess_threads
self
.
_in_queue
=
in_queue
self
.
_local_in_queues
=
[
Queue
.
Queue
(
5
)
for
i
in
range
(
num_postprocess_threads
)
]
if
isinstance
(
out_queue
,
list
):
if
len
(
out_queue
)
!=
num_postprocess_threads
:
raise
ValueError
(
"When out_queue is a list, its length must "
"equal to num_postprocess_threads!"
)
self
.
_local_out_queues
=
out_queue
self
.
_out_queue
=
None
else
:
self
.
_local_out_queues
=
[
Queue
.
Queue
(
5
)
for
i
in
range
(
num_postprocess_threads
)
]
self
.
_out_queue
=
out_queue
def
_distribute
(
self
):
def
func
():
idx
=
0
while
True
:
data
=
self
.
_in_queue
.
get
()
self
.
_in_queue
.
task_done
()
if
not
isinstance
(
data
,
EndSignal
):
self
.
_local_in_queues
[
idx
%
self
.
_num_postprocess_threads
].
put
(
data
)
idx
+=
1
else
:
for
q
in
self
.
_local_in_queues
:
q
.
put
(
EndSignal
())
t
=
Thread
(
target
=
func
)
t
.
daemon
=
True
t
.
start
()
def
_run
(
self
,
worker
,
args
):
for
i
in
range
(
self
.
_num_postprocess_threads
):
t
=
Thread
(
target
=
worker
,
args
=
(
self
.
_local_in_queues
[
i
],
self
.
_local_out_queues
[
i
])
+
args
)
t
.
daemon
=
True
t
.
start
()
def
_gather
(
self
):
def
func
():
while
True
:
for
idx
,
q
in
enumerate
(
self
.
_local_out_queues
):
data
=
q
.
get
()
q
.
task_done
()
if
isinstance
(
data
,
EndSignal
)
and
idx
>
0
:
continue
self
.
_out_queue
.
put
(
data
)
t
=
Thread
(
target
=
func
)
t
.
daemon
=
True
t
.
start
()
def
__call__
(
self
,
worker
,
args
):
self
.
_distribute
()
self
.
_run
(
worker
,
args
)
if
self
.
_out_queue
:
self
.
_gather
()
class
Teacher
(
object
):
class
Teacher
(
object
):
"""
"""
The class defined for the teacher model. Generate knowledge data and
The class defined for the teacher model. Generate knowledge data and
...
@@ -102,9 +184,12 @@ class Teacher(object):
...
@@ -102,9 +184,12 @@ class Teacher(object):
self
.
_started
=
False
self
.
_started
=
False
def
_start_manager
(
self
):
def
_start_manager
(
self
):
def
get_knowledge_queue
():
def
get_knowledge_queue
(
idx
):
global
knowledge_queue
global
knowledge_queues
return
knowledge_queue
if
idx
<
len
(
knowledge_queues
):
return
knowledge_queues
[
idx
]
else
:
return
None
def
get_s2t_queue
():
def
get_s2t_queue
():
global
s2t_queue
global
s2t_queue
...
@@ -141,12 +226,17 @@ class Teacher(object):
...
@@ -141,12 +226,17 @@ class Teacher(object):
self
.
_started
=
True
self
.
_started
=
True
self
.
_manager
=
self
.
_start_manager
()
if
self
.
_out_port
else
None
self
.
_manager
=
self
.
_start_manager
()
if
self
.
_out_port
else
None
if
self
.
_manager
:
if
self
.
_manager
:
self
.
_knowledge_queue
=
self
.
_manager
.
get_knowledge_queue
()
self
.
_knowledge_queues
=
[
self
.
_manager
.
get_knowledge_queue
(
i
)
for
i
in
range
(
num_postprocess_threads
)
]
print
(
"Num of knowledge queues: {}"
.
format
(
num_postprocess_threads
))
self
.
_s2t_queue
=
self
.
_manager
.
get_s2t_queue
()
self
.
_s2t_queue
=
self
.
_manager
.
get_s2t_queue
()
self
.
_t2s_queue
=
self
.
_manager
.
get_t2s_queue
()
self
.
_t2s_queue
=
self
.
_manager
.
get_t2s_queue
()
self
.
_cmd_queue
=
self
.
_manager
.
get_cmd_queue
()
self
.
_cmd_queue
=
self
.
_manager
.
get_cmd_queue
()
else
:
else
:
self
.
_knowledge_queue
=
None
self
.
_knowledge_queue
s
=
None
self
.
_s2t_queue
=
None
self
.
_s2t_queue
=
None
self
.
_t2s_queue
=
None
self
.
_t2s_queue
=
None
self
.
_cmd_queue
=
None
self
.
_cmd_queue
=
None
...
@@ -173,8 +263,9 @@ class Teacher(object):
...
@@ -173,8 +263,9 @@ class Teacher(object):
while
True
:
while
True
:
if
self
.
_sync_required
:
if
self
.
_sync_required
:
self
.
_knowledge_queue
.
put
(
SyncSignal
())
for
q
in
self
.
_knowledge_queues
:
self
.
_knowledge_queue
.
join
()
q
.
put
(
SyncSignal
())
q
.
join
()
self
.
_sync_required
=
False
self
.
_sync_required
=
False
break
break
...
@@ -256,6 +347,7 @@ class Teacher(object):
...
@@ -256,6 +347,7 @@ class Teacher(object):
reader_config
,
reader_config
,
exe
,
exe
,
buf_size
=
10
,
buf_size
=
10
,
use_fp16
=
False
,
times
=
1
):
times
=
1
):
"""
"""
Start the knowledge service to generate and transfer knowledge data.
Start the knowledge service to generate and transfer knowledge data.
...
@@ -291,6 +383,11 @@ class Teacher(object):
...
@@ -291,6 +383,11 @@ class Teacher(object):
exe (fluid.Executor): The executor to run the input program.
exe (fluid.Executor): The executor to run the input program.
buf_size (int): The size of buffers for data reader and knowledge
buf_size (int): The size of buffers for data reader and knowledge
writer on each device.
writer on each device.
use_fp16 (bool): Whether to transfer/store knowledge data in float16
if their data type is float32/float64. In the offline
mode, it will reduce the size of dumped knowledge file,
and in the online mode, it will speedup the online
transfer, with the sacrifice in precision . Default False.
times (int): The maximum repeated serving times. Default 1. Whenever
times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student
the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one,
object called once, the serving times will be added one,
...
@@ -333,6 +430,8 @@ class Teacher(object):
...
@@ -333,6 +430,8 @@ class Teacher(object):
raise
ValueError
(
"Input argument should be a fluid Executor!"
)
raise
ValueError
(
"Input argument should be a fluid Executor!"
)
self
.
_exe
=
exe
self
.
_exe
=
exe
self
.
_use_fp16
=
use_fp16
if
not
buf_size
>
0
:
if
not
buf_size
>
0
:
raise
ValueError
(
"The buffer size should be positive!"
)
raise
ValueError
(
"The buffer size should be positive!"
)
self
.
_buf_size
=
buf_size
self
.
_buf_size
=
buf_size
...
@@ -402,84 +501,136 @@ class Teacher(object):
...
@@ -402,84 +501,136 @@ class Teacher(object):
"generator type, which should be one of 'sample_generator', "
"generator type, which should be one of 'sample_generator', "
"'sample_list_generator', and 'batch_generator'."
)
"'sample_list_generator', and 'batch_generator'."
)
def
writer
(
buf_queue
,
schema_keys
):
def
cast2fp16
(
know
):
samples_sent
,
batches_sent
=
0
,
0
for
k
,
v
in
list
(
know
.
items
()):
if
not
isinstance
(
v
,
np
.
ndarray
):
break
if
v
.
dtype
==
np
.
float32
or
v
.
dtype
==
np
.
float64
:
v
=
v
.
astype
(
"float16"
)
know
[
k
]
=
v
return
know
feed_var_names
=
[
var
.
name
for
var
in
self
.
_feed_list
]
schema_in_feed
,
schema_in_fetch
=
{},
{}
for
k
,
v
in
list
(
self
.
_schema
.
items
()):
if
k
in
feed_var_names
:
schema_in_feed
[
k
]
=
v
else
:
schema_in_fetch
[
k
]
=
v
schema_in_fetch_keys
,
schema_in_fetch_vars
=
zip
(
*
list
(
schema_in_fetch
.
items
()))
def
know_maker
(
in_queue
,
out_queue
,
use_fp16
):
while
True
:
while
True
:
outputs
=
buf_queue
.
get
()
data
=
in_queue
.
get
()
buf_queue
.
task_done
()
in_queue
.
task_done
()
if
not
isinstance
(
outputs
,
EndSignal
):
if
isinstance
(
data
,
tuple
):
batch_samples
=
dict
(
zip
(
schema_keys
,
outputs
))
dev_batches
,
outputs
=
data
if
self
.
_knowledge_queue
:
know
=
{}
self
.
_knowledge_queue
.
put
(
batch_samples
)
for
k
in
schema_in_feed
.
keys
():
if
self
.
_out_file
:
batch_know
=
[
self
.
_out_file
.
write
(
pickle
.
dumps
(
batch_samples
))
np
.
array
(
batch
[
k
])
for
batch
in
dev_batches
]
know
[
k
]
=
np
.
concatenate
(
batch_know
)
know
.
update
(
dict
(
zip
(
schema_in_fetch_keys
,
outputs
)))
if
use_fp16
:
know
=
cast2fp16
(
know
)
out_queue
.
put
(
know
)
else
:
else
:
if
self
.
_knowledge_queue
:
# forward other types of data directly (maybe knowledge desc or EndSignal)
self
.
_knowledge_queue
.
put
(
EndSignal
())
out_queue
.
put
(
data
)
# should close file in child thread to wait for all
# writing finished
know_make_queue
=
Queue
.
Queue
(
self
.
_buf_size
)
if
self
.
_out_file
:
if
self
.
_out_file
:
# For offline dump, write the knowledge description to the head of file
self
.
_out_file
.
write
(
pickle
.
dumps
(
self
.
_knowledge_desc
))
print
(
"output path: %s"
%
self
.
_out_path
)
offline_write_queue
=
Queue
.
Queue
(
self
.
_buf_size
)
def
offline_write
(
queue
):
while
True
:
know
=
queue
.
get
()
queue
.
task_done
()
if
not
isinstance
(
know
,
EndSignal
):
self
.
_out_file
.
write
(
pickle
.
dumps
(
know
))
else
:
# should close file in child thread to wait for all
# writing finished
self
.
_out_file
.
close
()
self
.
_out_file
.
close
()
# Asynchronous output
t
=
Thread
(
target
=
offline_write
,
args
=
(
offline_write_queue
,
))
out_buf_queue
=
Queue
.
Queue
(
self
.
_buf_size
)
t
.
daemon
=
True
schema_keys
,
schema_vars
=
zip
(
*
list
(
self
.
_schema
.
items
()))
t
.
start
()
out_thread
=
Thread
(
target
=
writer
,
args
=
(
out_buf_queue
,
schema_keys
))
make_knowledge
=
WorkerParallel
(
out_thread
.
daemon
=
True
num_postprocess_threads
,
know_make_queue
,
offline_write_queue
)
out_thread
.
start
()
if
self
.
_knowledge_queues
:
make_knowledge
=
WorkerParallel
(
num_postprocess_threads
,
know_make_queue
,
self
.
_knowledge_queues
)
make_knowledge
(
worker
=
know_maker
,
args
=
(
self
.
_use_fp16
,
))
compiled_program
=
fluid
.
compiler
.
CompiledProgram
(
compiled_program
=
fluid
.
compiler
.
CompiledProgram
(
self
.
_program
).
with_data_parallel
()
self
.
_program
).
with_data_parallel
()
print
(
"Knowledge description {}"
.
format
(
self
.
_knowledge_desc
))
print
(
"Knowledge description {}"
.
format
(
self
.
_knowledge_desc
))
print
(
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher begins to serve ..."
)
" Teacher begins to serve ..."
)
# For offline dump, write the knowledge description to the head of file
if
self
.
_out_file
:
self
.
_out_file
.
write
(
pickle
.
dumps
(
self
.
_knowledge_desc
))
print
(
"output path: %s"
%
self
.
_out_path
)
data_reader
=
MixedDataReader
(
data_loader
,
dev_count
)
data_reader
=
MixedDataReader
(
data_loader
,
dev_count
)
# For online mode, send knowledge description every time
# For online mode, send knowledge description every time
for
repeated
in
range
(
self
.
_times
):
for
repeated
in
range
(
self
.
_times
):
if
self
.
_knowledge_queue
:
if
self
.
_knowledge_queue
s
:
# wait for the accessing of knowledge desc and data
# wait for the accessing of knowledge desc and data
while
True
:
while
True
:
if
self
.
_sync_required
:
if
self
.
_sync_required
:
self
.
_knowledge_queue
.
put
(
SyncSignal
())
for
q
in
self
.
_knowledge_queues
:
self
.
_knowledge_queue
.
put
(
self
.
_knowledge_desc
)
q
.
put
(
SyncSignal
())
know_make_queue
.
put
(
self
.
_knowledge_desc
)
self
.
_sync_required
=
False
self
.
_sync_required
=
False
if
self
.
_data_required
:
if
self
.
_data_required
:
self
.
_data_required
=
False
self
.
_data_required
=
False
break
break
self
.
_knowledge_queue
.
join
()
for
q
in
self
.
_knowledge_queues
:
q
.
join
()
print
(
"No.{} time serving ... "
.
format
(
repeated
))
print
(
"No.{} time serving ... "
.
format
(
repeated
))
num_batches_sent
=
0
num_batches_sent
=
0
for
dev_batches
in
data_reader
.
multi_dev_generator
():
for
index
,
dev_batches
in
enumerate
(
data_reader
.
multi_dev_generator
()):
if
self
.
_sync_required
:
if
self
.
_sync_required
:
break
break
tic
=
time
.
time
()
outputs
=
self
.
_exe
.
run
(
compiled_program
,
outputs
=
self
.
_exe
.
run
(
compiled_program
,
feed
=
dev_batches
,
feed
=
dev_batches
,
fetch_list
=
schema_vars
)
fetch_list
=
schema_in_fetch_vars
)
out_buf_queue
.
put
(
outputs
)
toc
=
time
.
time
()
print
(
"teacher predict time = {}"
.
format
(
toc
-
tic
))
know_make_queue
.
put
((
dev_batches
,
outputs
))
#out_buf_queue.put(know)
tic
=
time
.
time
()
print
(
"teacher out time = {}"
.
format
(
tic
-
toc
))
num_batches_sent
+=
dev_count
num_batches_sent
+=
dev_count
if
num_batches_sent
%
(
100
*
dev_count
)
==
0
:
if
num_batches_sent
%
(
100
*
dev_count
)
==
0
:
log
=
"Processed {} batch samples."
.
format
(
log
=
"Processed {} batch samples."
.
format
(
num_batches_sent
)
num_batches_sent
)
if
self
.
_knowledge_queue
:
if
self
.
_knowledge_queues
:
log
+=
" Knowledge queue size {}."
.
format
(
qsize
=
0
self
.
_knowledge_queue
.
qsize
())
for
q
in
self
.
_knowledge_queues
:
qsize
+=
q
.
qsize
()
log
+=
" Knowledge queue size {}."
.
format
(
qsize
)
print
(
log
)
print
(
log
)
outputs
=
[]
dev_batches
,
outputs
=
[],
[]
for
index
,
batch
in
enumerate
(
data_reader
.
tail_generator
()):
for
index
,
batch
in
enumerate
(
data_reader
.
tail_generator
()):
if
self
.
_sync_required
:
if
self
.
_sync_required
:
break
break
dev_batches
.
append
(
batch
)
output
=
self
.
_exe
.
run
(
self
.
_program
,
output
=
self
.
_exe
.
run
(
self
.
_program
,
feed
=
batch
,
feed
=
batch
,
fetch_list
=
schema_vars
)
fetch_list
=
schema_
in_fetch_
vars
)
if
outputs
:
if
outputs
:
outputs
=
[
outputs
=
[
np
.
concatenate
(
np
.
concatenate
(
...
@@ -488,21 +639,22 @@ class Teacher(object):
...
@@ -488,21 +639,22 @@ class Teacher(object):
]
]
else
:
else
:
outputs
=
copy
.
deepcopy
(
output
)
outputs
=
copy
.
deepcopy
(
output
)
if
outputs
:
if
dev_batches
or
outputs
:
out_buf_queue
.
put
(
outputs
)
know_make_queue
.
put
((
dev_batches
,
outputs
))
#out_buf_queue.put(know)
num_batches_sent
+=
(
index
+
1
)
num_batches_sent
+=
(
index
+
1
)
print
(
"Processed {} batch samples in total."
.
format
(
print
(
"Processed {} batch samples in total."
.
format
(
num_batches_sent
))
num_batches_sent
))
out_buf
_queue
.
put
(
EndSignal
())
know_make
_queue
.
put
(
EndSignal
())
out_buf
_queue
.
join
()
know_make
_queue
.
join
()
if
self
.
_knowledge_queue
:
if
self
.
_knowledge_queue
s
:
self
.
_knowledge_queue
.
join
()
for
q
in
self
.
_knowledge_queues
:
print
(
q
.
join
()
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
print
(
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
+
" Teacher ends serving."
)
" Teacher ends serving."
)
def
__del__
(
self
):
def
__del__
(
self
):
if
self
.
_manager
:
if
self
.
_manager
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录