Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
76710e5f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
76710e5f
编写于
7月 29, 2021
作者:
K
Kaipeng Deng
提交者:
GitHub
7月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add persistent_workers (#34017)
* add persistent_workers. test=develop
上级
b451ff26
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
208 addition
and
79 deletion
+208
-79
python/paddle/fluid/dataloader/dataloader_iter.py
python/paddle/fluid/dataloader/dataloader_iter.py
+104
-24
python/paddle/fluid/dataloader/worker.py
python/paddle/fluid/dataloader/worker.py
+11
-0
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+11
-1
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
...d/tests/unittests/test_multiprocess_dataloader_dynamic.py
+21
-14
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
.../test_multiprocess_dataloader_iterable_dataset_dynamic.py
+19
-12
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
...s/test_multiprocess_dataloader_iterable_dataset_static.py
+20
-13
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
...id/tests/unittests/test_multiprocess_dataloader_static.py
+22
-15
未找到文件。
python/paddle/fluid/dataloader/dataloader_iter.py
浏览文件 @
76710e5f
...
...
@@ -37,7 +37,8 @@ from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
from
.batch_sampler
import
_InfiniteIterableSampler
from
.collate
import
default_collate_fn
,
default_convert_fn
from
.worker
import
ParentWatchDog
,
get_worker_info
,
_worker_loop
,
\
_DatasetKind
,
_IterableDatasetStopIteration
,
_WorkerException
_DatasetKind
,
_IterableDatasetStopIteration
,
_WorkerException
,
\
_ResumeIteration
from
.flat
import
_flatten_batch
,
_restore_batch
__all__
=
[
'get_worker_info'
]
...
...
@@ -67,15 +68,10 @@ class _DataLoaderIterBase(object):
self
.
_dataset_kind
=
loader
.
dataset_kind
self
.
_pin_memory
=
loader
.
pin_memory
self
.
_sampler_iter
=
iter
(
self
.
_index_sampler
)
if
self
.
_auto_collate_batch
:
self
.
_sampler_iter
=
iter
(
loader
.
batch_sampler
)
self
.
_collate_fn
=
loader
.
collate_fn
or
default_collate_fn
else
:
if
self
.
_dataset_kind
==
_DatasetKind
.
MAP
:
self
.
_sampler_iter
=
iter
(
list
(
range
(
len
(
self
.
_dataset
))))
else
:
self
.
_sampler_iter
=
iter
(
_InfiniteIterableSampler
(
self
.
_dataset
,
1
))
self
.
_collate_fn
=
loader
.
collate_fn
or
default_convert_fn
# LoDTensorBlockingQueue instance for create_py_reader and a thread
...
...
@@ -87,6 +83,16 @@ class _DataLoaderIterBase(object):
self
.
_thread
=
None
self
.
_thread_done_event
=
threading
.
Event
()
@
property
def
_index_sampler
(
self
):
if
self
.
_auto_collate_batch
:
return
self
.
_batch_sampler
else
:
if
self
.
_dataset_kind
==
_DatasetKind
.
MAP
:
return
list
(
range
(
len
(
self
.
_dataset
)))
else
:
return
_InfiniteIterableSampler
(
self
.
_dataset
,
1
)
def
__iter__
(
self
):
return
self
...
...
@@ -242,6 +248,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
def
__init__
(
self
,
loader
):
super
(
_DataLoaderIterMultiProcess
,
self
).
__init__
(
loader
)
self
.
_persistent_workers
=
loader
.
_persistent_workers
self
.
_resume_worker_cnt
=
0
assert
self
.
_num_workers
>
0
,
"Multi-process DataLoader "
\
"invalid num_workers({})"
.
format
(
self
.
_num_workers
)
...
...
@@ -336,13 +345,65 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_pin_memory
)
self
.
_thread_done_event
=
threading
.
Event
()
# thread event is only need in multi-processing mode
self
.
_thread
=
threading
.
Thread
(
target
=
self
.
_thread_loop
,
args
=
(
_current_expected_place
(),
))
self
.
_thread
.
daemon
=
True
self
.
_thread
.
start
()
def
_shutdown_worker
(
self
,
worker_id
):
if
self
.
_worker_status
[
worker_id
]:
def
_reset
(
self
):
# resume iteration in following steps
# 1. Resume workers, clear worker caches
# put _ResumeIteration to all worker as resume iteration flag
with
self
.
_thread_lock
:
self
.
_resume_worker_cnt
=
self
.
_num_workers
for
worker_id
in
range
(
self
.
_num_workers
):
self
.
_indices_queues
[
worker_id
].
put
(
_ResumeIteration
())
self
.
_batches_outstanding
+=
1
# all flag will be check in _thread_loop, simply wait here
while
self
.
_resume_worker_cnt
>
0
:
time
.
sleep
(
0.5
)
# 2. clear blocking_queue caches
# in order not to restart the thread, we just clear
# the blocking_queue cachees instead of recreating one
while
self
.
_blocking_queue
.
size
()
>=
len
(
self
.
_places
):
if
in_dygraph_mode
():
self
.
_reader
.
read_next_var_list
()
elif
self
.
_return_list
:
self
.
_reader
.
read_next_list
()
else
:
data
=
self
.
_reader
.
read_next
()
# 3. reset all states
self
.
_send_idx
=
0
self
.
_rcvd_idx
=
0
self
.
_batches_outstanding
=
0
self
.
_task_infos
=
{}
self
.
_structure_infos
=
[]
# set all worker status available
self
.
_worker_status
=
[
True
]
*
self
.
_num_workers
# 4. reset _sampler_iter and put prefetch indices to start next epoch
# init workers and indices queues and put 2 indices in each indices queue
self
.
_sampler_iter
=
iter
(
self
.
_index_sampler
)
for
_
in
range
(
self
.
_outstanding_capacity
):
self
.
_try_put_indices
()
def
_clear_and_remove_data_queue
(
self
):
if
self
.
_data_queue
is
not
None
:
while
True
:
try
:
self
.
_data_queue
.
get_nowait
()
except
:
self
.
_data_queue
.
cancel_join_thread
()
self
.
_data_queue
.
close
()
break
def
_shutdown_worker
(
self
,
worker_id
,
shutdown
=
False
):
if
self
.
_worker_status
[
worker_id
]
or
(
self
.
_persistent_workers
and
shutdown
):
self
.
_indices_queues
[
worker_id
].
put
(
None
)
self
.
_worker_status
[
worker_id
]
=
False
...
...
@@ -357,7 +418,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# indices_queue
self
.
_workers_done_event
.
set
()
for
i
in
range
(
self
.
_num_workers
):
self
.
_shutdown_worker
(
i
)
self
.
_shutdown_worker
(
i
,
shutdown
=
True
)
if
not
self
.
_shutdown
:
for
w
in
self
.
_workers
:
...
...
@@ -392,6 +453,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if
batch
is
None
:
self
.
_exit_thread_expectedly
()
else
:
if
isinstance
(
batch
,
_ResumeIteration
):
assert
self
.
_resume_worker_cnt
>
0
self
.
_resume_worker_cnt
-=
1
continue
try
:
# pack as LoDTensorArray
array
=
core
.
LoDTensorArray
()
...
...
@@ -412,7 +477,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if
not
self
.
_blocking_queue
.
push
(
array
):
self
.
_blocking_queue
.
close
()
except
:
except
Exception
as
e
:
self
.
_exit_thread_unexpectedly
()
six
.
reraise
(
*
sys
.
exc_info
())
finally
:
...
...
@@ -428,7 +493,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# batch indices and increase _rcvd_idx
if
self
.
_dataset_kind
==
_DatasetKind
.
ITER
:
while
self
.
_rcvd_idx
<
self
.
_send_idx
:
sys
.
stdout
.
flush
()
info
=
self
.
_task_infos
[
self
.
_rcvd_idx
]
if
len
(
info
)
==
3
or
self
.
_worker_status
[
info
[
0
]]:
break
...
...
@@ -436,6 +500,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_rcvd_idx
+=
1
self
.
_batches_outstanding
-=
1
else
:
# NOTE: in persistent workers mode, do not check data
# drained here, simply let it go to _data_queue
# reading to get _ResumeIteration
if
not
self
.
_persistent_workers
:
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers, if batches among workers drained, there
# may also be data in blocking queue
...
...
@@ -493,12 +561,20 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# is discard, outstanding batch number should be decrease
# and another indices should be put for other workers
# may still working.
if
self
.
_persistent_workers
:
self
.
_worker_status
[
data
.
worker_id
]
=
False
else
:
self
.
_shutdown_worker
(
data
.
worker_id
)
self
.
_batches_outstanding
-=
1
self
.
_try_put_indices
()
continue
idx
,
batch
,
structure
=
data
if
isinstance
(
idx
,
_ResumeIteration
)
and
batch
is
None
\
and
structure
is
None
:
return
idx
if
isinstance
(
batch
,
_WorkerException
):
self
.
_exit_thread_unexpectedly
()
batch
.
reraise
()
...
...
@@ -557,6 +633,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# set _thread_done_event here, py_reader will raise StopIteration,
# end workers and indices_queues in StopIteration handling
if
self
.
_batches_outstanding
<
len
(
self
.
_places
):
if
self
.
_persistent_workers
:
raise
StopIteration
else
:
self
.
_thread_done_event
.
set
()
self
.
_blocking_queue
.
close
()
...
...
@@ -583,6 +662,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_on_output_batch
()
return
data
except
StopIteration
:
if
not
self
.
_persistent_workers
:
self
.
_reader
.
shutdown
()
self
.
_try_shutdown_all
()
six
.
reraise
(
*
sys
.
exc_info
())
...
...
python/paddle/fluid/dataloader/worker.py
浏览文件 @
76710e5f
...
...
@@ -36,6 +36,10 @@ class _IterableDatasetStopIteration(object):
self
.
worker_id
=
worker_id
class
_ResumeIteration
(
object
):
pass
class
_DatasetKind
(
object
):
MAP
=
0
ITER
=
1
...
...
@@ -292,6 +296,13 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
except
queue
.
Empty
:
continue
if
isinstance
(
data
,
_ResumeIteration
):
out_queue
.
put
((
data
,
None
,
None
))
iterator_drained
=
False
fetcher
=
_DatasetKind
.
create_fetcher
(
dataset_kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
True
)
continue
# None as poison piil, so worker event should be set
if
data
is
None
:
assert
done_event
.
is_set
()
or
iterator_drained
,
\
...
...
python/paddle/fluid/reader.py
浏览文件 @
76710e5f
...
...
@@ -325,7 +325,8 @@ class DataLoader(object):
use_buffer_reader
=
True
,
use_shared_memory
=
True
,
timeout
=
0
,
worker_init_fn
=
None
):
worker_init_fn
=
None
,
persistent_workers
=
False
):
self
.
return_list
=
return_list
self
.
collate_fn
=
collate_fn
self
.
use_buffer_reader
=
use_buffer_reader
...
...
@@ -407,6 +408,9 @@ class DataLoader(object):
self
.
pin_memory
=
True
if
use_pinned_memory
(
)
is
None
else
use_pinned_memory
()
self
.
_persistent_workers
=
persistent_workers
self
.
_iterator
=
None
def
__len__
(
self
):
if
self
.
dataset_kind
==
_DatasetKind
.
ITER
:
raise
ValueError
(
"length of IterableDataset not supported"
)
...
...
@@ -419,6 +423,12 @@ class DataLoader(object):
def
__iter__
(
self
):
if
self
.
num_workers
==
0
:
return
_DataLoaderIterSingleProcess
(
self
)
elif
self
.
_persistent_workers
:
if
self
.
_iterator
is
None
:
self
.
_iterator
=
_DataLoaderIterMultiProcess
(
self
)
else
:
self
.
_iterator
.
_reset
()
return
self
.
_iterator
else
:
return
_DataLoaderIterMultiProcess
(
self
)
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
浏览文件 @
76710e5f
...
...
@@ -66,7 +66,7 @@ class SimpleFCNet(fluid.dygraph.Layer):
class
TestDygraphDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
...
...
@@ -78,7 +78,8 @@ class TestDygraphDataLoader(unittest.TestCase):
dataset
,
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
step_list
=
[]
...
...
@@ -110,11 +111,16 @@ class TestDygraphDataLoader(unittest.TestCase):
def
test_main
(
self
):
# dynamic graph do not run with_data_parallel
for
p
in
prepare_places
(
False
):
for
persistent_workers
in
[
False
,
True
]:
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
,
persistent_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
persistent_workers
=
persistent_workers
)
results
.
append
(
ret
)
diff
=
np
.
max
(
np
.
abs
(
results
[
0
][
'loss'
]
-
results
[
1
][
'loss'
])
/
...
...
@@ -123,7 +129,7 @@ class TestDygraphDataLoader(unittest.TestCase):
class
TestDygraphDataLoaderWithBatchedDataset
(
TestDygraphDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
...
...
@@ -135,7 +141,8 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
dataset
,
num_workers
=
num_workers
,
batch_size
=
None
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
step_list
=
[]
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
浏览文件 @
76710e5f
...
...
@@ -66,7 +66,7 @@ class SimpleFCNet(fluid.dygraph.Layer):
class
TestDygraphDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
...
...
@@ -78,7 +78,8 @@ class TestDygraphDataLoader(unittest.TestCase):
dataset
,
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
step_list
=
[]
loss_list
=
[]
...
...
@@ -109,18 +110,23 @@ class TestDygraphDataLoader(unittest.TestCase):
def
test_main
(
self
):
# dynamic graph do not run with_data_parallel
for
p
in
prepare_places
(
False
):
for
persistent_workers
in
[
False
,
True
]:
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
,
persistent_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
persistent_workers
=
persistent_workers
)
results
.
append
(
ret
)
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
][
'loss'
].
shape
[
0
]
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
]
[
'loss'
].
shape
[
0
]
class
TestDygraphDataLoaderWithBatchedDataset
(
TestDygraphDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
...
...
@@ -132,7 +138,8 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
dataset
,
num_workers
=
num_workers
,
batch_size
=
None
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
step_list
=
[]
loss_list
=
[]
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
浏览文件 @
76710e5f
...
...
@@ -93,14 +93,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
if
with_gpu
and
fluid
.
core
.
is_compiled_with_cuda
():
tmp
=
fluid
.
cuda_places
()[:
2
]
assert
len
(
tmp
)
>
0
,
"no gpu detected"
if
with_data_parallel
:
if
with_data_parallel
and
len
(
tmp
)
>
1
:
places
.
append
(
tmp
)
places
.
append
([
tmp
[
0
]])
return
places
class
TestStaticDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -113,7 +113,8 @@ class TestStaticDataLoader(unittest.TestCase):
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
return_list
=
False
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
# assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
...
...
@@ -158,14 +159,19 @@ class TestStaticDataLoader(unittest.TestCase):
def
test_main
(
self
):
for
p
in
prepare_places
(
True
):
for
persistent_workers
in
[
False
,
True
]:
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
,
persistent_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
persistent_workers
=
persistent_workers
)
results
.
append
(
ret
)
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
][
'loss'
].
shape
[
0
]
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
]
[
'loss'
].
shape
[
0
]
class
RandomBatchedDataset
(
IterableDataset
):
...
...
@@ -188,7 +194,7 @@ class RandomBatchedDataset(IterableDataset):
class
TestStaticDataLoaderWithBatchedDataset
(
TestStaticDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -201,7 +207,8 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
num_workers
=
num_workers
,
batch_size
=
None
,
return_list
=
False
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
exe
.
run
(
startup_prog
)
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
浏览文件 @
76710e5f
...
...
@@ -94,14 +94,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
if
with_gpu
and
fluid
.
core
.
is_compiled_with_cuda
():
tmp
=
fluid
.
cuda_places
()[:
2
]
assert
len
(
tmp
)
>
0
,
"no gpu detected"
if
with_data_parallel
:
if
with_data_parallel
and
len
(
tmp
)
>
1
:
places
.
append
(
tmp
)
places
.
append
([
tmp
[
0
]])
return
places
class
TestStaticDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
,
use_pe
=
True
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
,
use_pe
=
True
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -114,7 +114,8 @@ class TestStaticDataLoader(unittest.TestCase):
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
return_list
=
False
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
...
...
@@ -162,11 +163,16 @@ class TestStaticDataLoader(unittest.TestCase):
def
test_main
(
self
):
for
p
in
prepare_places
(
True
):
for
persistent_workers
in
[
True
,
False
]:
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
,
persistent_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
persistent_workers
=
persistent_workers
)
results
.
append
(
ret
)
diff
=
np
.
max
(
np
.
abs
(
results
[
0
][
'loss'
]
-
results
[
1
][
'loss'
])
/
...
...
@@ -241,7 +247,7 @@ class RandomBatchedDataset(Dataset):
class
TestStaticDataLoaderWithBatchedDataset
(
TestStaticDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -254,7 +260,8 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
num_workers
=
num_workers
,
batch_size
=
None
,
return_list
=
False
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录