Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
76710e5f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
76710e5f
编写于
7月 29, 2021
作者:
K
Kaipeng Deng
提交者:
GitHub
7月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add persistent_workers (#34017)
* add persistent_workers. test=develop
上级
b451ff26
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
208 addition
and
79 deletion
+208
-79
python/paddle/fluid/dataloader/dataloader_iter.py
python/paddle/fluid/dataloader/dataloader_iter.py
+104
-24
python/paddle/fluid/dataloader/worker.py
python/paddle/fluid/dataloader/worker.py
+11
-0
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+11
-1
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
...d/tests/unittests/test_multiprocess_dataloader_dynamic.py
+21
-14
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
.../test_multiprocess_dataloader_iterable_dataset_dynamic.py
+19
-12
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
...s/test_multiprocess_dataloader_iterable_dataset_static.py
+20
-13
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
...id/tests/unittests/test_multiprocess_dataloader_static.py
+22
-15
未找到文件。
python/paddle/fluid/dataloader/dataloader_iter.py
浏览文件 @
76710e5f
...
...
@@ -37,7 +37,8 @@ from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
from
.batch_sampler
import
_InfiniteIterableSampler
from
.collate
import
default_collate_fn
,
default_convert_fn
from
.worker
import
ParentWatchDog
,
get_worker_info
,
_worker_loop
,
\
_DatasetKind
,
_IterableDatasetStopIteration
,
_WorkerException
_DatasetKind
,
_IterableDatasetStopIteration
,
_WorkerException
,
\
_ResumeIteration
from
.flat
import
_flatten_batch
,
_restore_batch
__all__
=
[
'get_worker_info'
]
...
...
@@ -67,15 +68,10 @@ class _DataLoaderIterBase(object):
self
.
_dataset_kind
=
loader
.
dataset_kind
self
.
_pin_memory
=
loader
.
pin_memory
self
.
_sampler_iter
=
iter
(
self
.
_index_sampler
)
if
self
.
_auto_collate_batch
:
self
.
_sampler_iter
=
iter
(
loader
.
batch_sampler
)
self
.
_collate_fn
=
loader
.
collate_fn
or
default_collate_fn
else
:
if
self
.
_dataset_kind
==
_DatasetKind
.
MAP
:
self
.
_sampler_iter
=
iter
(
list
(
range
(
len
(
self
.
_dataset
))))
else
:
self
.
_sampler_iter
=
iter
(
_InfiniteIterableSampler
(
self
.
_dataset
,
1
))
self
.
_collate_fn
=
loader
.
collate_fn
or
default_convert_fn
# LoDTensorBlockingQueue instance for create_py_reader and a thread
...
...
@@ -87,6 +83,16 @@ class _DataLoaderIterBase(object):
self
.
_thread
=
None
self
.
_thread_done_event
=
threading
.
Event
()
@
property
def
_index_sampler
(
self
):
if
self
.
_auto_collate_batch
:
return
self
.
_batch_sampler
else
:
if
self
.
_dataset_kind
==
_DatasetKind
.
MAP
:
return
list
(
range
(
len
(
self
.
_dataset
)))
else
:
return
_InfiniteIterableSampler
(
self
.
_dataset
,
1
)
def
__iter__
(
self
):
return
self
...
...
@@ -242,6 +248,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
def
__init__
(
self
,
loader
):
super
(
_DataLoaderIterMultiProcess
,
self
).
__init__
(
loader
)
self
.
_persistent_workers
=
loader
.
_persistent_workers
self
.
_resume_worker_cnt
=
0
assert
self
.
_num_workers
>
0
,
"Multi-process DataLoader "
\
"invalid num_workers({})"
.
format
(
self
.
_num_workers
)
...
...
@@ -336,13 +345,65 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_pin_memory
)
self
.
_thread_done_event
=
threading
.
Event
()
# thread event is only need in multi-processing mode
self
.
_thread
=
threading
.
Thread
(
target
=
self
.
_thread_loop
,
args
=
(
_current_expected_place
(),
))
self
.
_thread
.
daemon
=
True
self
.
_thread
.
start
()
def
_shutdown_worker
(
self
,
worker_id
):
if
self
.
_worker_status
[
worker_id
]:
def
_reset
(
self
):
# resume iteration in following steps
# 1. Resume workers, clear worker caches
# put _ResumeIteration to all worker as resume iteration flag
with
self
.
_thread_lock
:
self
.
_resume_worker_cnt
=
self
.
_num_workers
for
worker_id
in
range
(
self
.
_num_workers
):
self
.
_indices_queues
[
worker_id
].
put
(
_ResumeIteration
())
self
.
_batches_outstanding
+=
1
# all flag will be check in _thread_loop, simply wait here
while
self
.
_resume_worker_cnt
>
0
:
time
.
sleep
(
0.5
)
# 2. clear blocking_queue caches
# in order not to restart the thread, we just clear
# the blocking_queue cachees instead of recreating one
while
self
.
_blocking_queue
.
size
()
>=
len
(
self
.
_places
):
if
in_dygraph_mode
():
self
.
_reader
.
read_next_var_list
()
elif
self
.
_return_list
:
self
.
_reader
.
read_next_list
()
else
:
data
=
self
.
_reader
.
read_next
()
# 3. reset all states
self
.
_send_idx
=
0
self
.
_rcvd_idx
=
0
self
.
_batches_outstanding
=
0
self
.
_task_infos
=
{}
self
.
_structure_infos
=
[]
# set all worker status available
self
.
_worker_status
=
[
True
]
*
self
.
_num_workers
# 4. reset _sampler_iter and put prefetch indices to start next epoch
# init workers and indices queues and put 2 indices in each indices queue
self
.
_sampler_iter
=
iter
(
self
.
_index_sampler
)
for
_
in
range
(
self
.
_outstanding_capacity
):
self
.
_try_put_indices
()
def
_clear_and_remove_data_queue
(
self
):
if
self
.
_data_queue
is
not
None
:
while
True
:
try
:
self
.
_data_queue
.
get_nowait
()
except
:
self
.
_data_queue
.
cancel_join_thread
()
self
.
_data_queue
.
close
()
break
def
_shutdown_worker
(
self
,
worker_id
,
shutdown
=
False
):
if
self
.
_worker_status
[
worker_id
]
or
(
self
.
_persistent_workers
and
shutdown
):
self
.
_indices_queues
[
worker_id
].
put
(
None
)
self
.
_worker_status
[
worker_id
]
=
False
...
...
@@ -357,7 +418,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# indices_queue
self
.
_workers_done_event
.
set
()
for
i
in
range
(
self
.
_num_workers
):
self
.
_shutdown_worker
(
i
)
self
.
_shutdown_worker
(
i
,
shutdown
=
True
)
if
not
self
.
_shutdown
:
for
w
in
self
.
_workers
:
...
...
@@ -392,6 +453,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if
batch
is
None
:
self
.
_exit_thread_expectedly
()
else
:
if
isinstance
(
batch
,
_ResumeIteration
):
assert
self
.
_resume_worker_cnt
>
0
self
.
_resume_worker_cnt
-=
1
continue
try
:
# pack as LoDTensorArray
array
=
core
.
LoDTensorArray
()
...
...
@@ -412,7 +477,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if
not
self
.
_blocking_queue
.
push
(
array
):
self
.
_blocking_queue
.
close
()
except
:
except
Exception
as
e
:
self
.
_exit_thread_unexpectedly
()
six
.
reraise
(
*
sys
.
exc_info
())
finally
:
...
...
@@ -428,7 +493,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# batch indices and increase _rcvd_idx
if
self
.
_dataset_kind
==
_DatasetKind
.
ITER
:
while
self
.
_rcvd_idx
<
self
.
_send_idx
:
sys
.
stdout
.
flush
()
info
=
self
.
_task_infos
[
self
.
_rcvd_idx
]
if
len
(
info
)
==
3
or
self
.
_worker_status
[
info
[
0
]]:
break
...
...
@@ -436,6 +500,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_rcvd_idx
+=
1
self
.
_batches_outstanding
-=
1
else
:
# NOTE: in persistent workers mode, do not check data
# drained here, simply let it go to _data_queue
# reading to get _ResumeIteration
if
not
self
.
_persistent_workers
:
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers, if batches among workers drained, there
# may also be data in blocking queue
...
...
@@ -493,12 +561,20 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# is discard, outstanding batch number should be decrease
# and another indices should be put for other workers
# may still working.
if
self
.
_persistent_workers
:
self
.
_worker_status
[
data
.
worker_id
]
=
False
else
:
self
.
_shutdown_worker
(
data
.
worker_id
)
self
.
_batches_outstanding
-=
1
self
.
_try_put_indices
()
continue
idx
,
batch
,
structure
=
data
if
isinstance
(
idx
,
_ResumeIteration
)
and
batch
is
None
\
and
structure
is
None
:
return
idx
if
isinstance
(
batch
,
_WorkerException
):
self
.
_exit_thread_unexpectedly
()
batch
.
reraise
()
...
...
@@ -557,6 +633,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# set _thread_done_event here, py_reader will raise StopIteration,
# end workers and indices_queues in StopIteration handling
if
self
.
_batches_outstanding
<
len
(
self
.
_places
):
if
self
.
_persistent_workers
:
raise
StopIteration
else
:
self
.
_thread_done_event
.
set
()
self
.
_blocking_queue
.
close
()
...
...
@@ -583,6 +662,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_on_output_batch
()
return
data
except
StopIteration
:
if
not
self
.
_persistent_workers
:
self
.
_reader
.
shutdown
()
self
.
_try_shutdown_all
()
six
.
reraise
(
*
sys
.
exc_info
())
...
...
python/paddle/fluid/dataloader/worker.py
浏览文件 @
76710e5f
...
...
@@ -36,6 +36,10 @@ class _IterableDatasetStopIteration(object):
self
.
worker_id
=
worker_id
class
_ResumeIteration
(
object
):
pass
class
_DatasetKind
(
object
):
MAP
=
0
ITER
=
1
...
...
@@ -292,6 +296,13 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
except
queue
.
Empty
:
continue
if
isinstance
(
data
,
_ResumeIteration
):
out_queue
.
put
((
data
,
None
,
None
))
iterator_drained
=
False
fetcher
=
_DatasetKind
.
create_fetcher
(
dataset_kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
True
)
continue
# None as poison piil, so worker event should be set
if
data
is
None
:
assert
done_event
.
is_set
()
or
iterator_drained
,
\
...
...
python/paddle/fluid/reader.py
浏览文件 @
76710e5f
...
...
@@ -325,7 +325,8 @@ class DataLoader(object):
use_buffer_reader
=
True
,
use_shared_memory
=
True
,
timeout
=
0
,
worker_init_fn
=
None
):
worker_init_fn
=
None
,
persistent_workers
=
False
):
self
.
return_list
=
return_list
self
.
collate_fn
=
collate_fn
self
.
use_buffer_reader
=
use_buffer_reader
...
...
@@ -407,6 +408,9 @@ class DataLoader(object):
self
.
pin_memory
=
True
if
use_pinned_memory
(
)
is
None
else
use_pinned_memory
()
self
.
_persistent_workers
=
persistent_workers
self
.
_iterator
=
None
def
__len__
(
self
):
if
self
.
dataset_kind
==
_DatasetKind
.
ITER
:
raise
ValueError
(
"length of IterableDataset not supported"
)
...
...
@@ -419,6 +423,12 @@ class DataLoader(object):
def
__iter__
(
self
):
if
self
.
num_workers
==
0
:
return
_DataLoaderIterSingleProcess
(
self
)
elif
self
.
_persistent_workers
:
if
self
.
_iterator
is
None
:
self
.
_iterator
=
_DataLoaderIterMultiProcess
(
self
)
else
:
self
.
_iterator
.
_reset
()
return
self
.
_iterator
else
:
return
_DataLoaderIterMultiProcess
(
self
)
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
浏览文件 @
76710e5f
...
...
@@ -66,7 +66,7 @@ class SimpleFCNet(fluid.dygraph.Layer):
class
TestDygraphDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
...
...
@@ -78,7 +78,8 @@ class TestDygraphDataLoader(unittest.TestCase):
dataset
,
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
step_list
=
[]
...
...
@@ -110,11 +111,16 @@ class TestDygraphDataLoader(unittest.TestCase):
def
test_main
(
self
):
# dynamic graph do not run with_data_parallel
for
p
in
prepare_places
(
False
):
for
persistent_workers
in
[
False
,
True
]:
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
,
persistent_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
persistent_workers
=
persistent_workers
)
results
.
append
(
ret
)
diff
=
np
.
max
(
np
.
abs
(
results
[
0
][
'loss'
]
-
results
[
1
][
'loss'
])
/
...
...
@@ -123,7 +129,7 @@ class TestDygraphDataLoader(unittest.TestCase):
class
TestDygraphDataLoaderWithBatchedDataset
(
TestDygraphDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
...
...
@@ -135,7 +141,8 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
dataset
,
num_workers
=
num_workers
,
batch_size
=
None
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
step_list
=
[]
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
浏览文件 @
76710e5f
...
...
@@ -66,7 +66,7 @@ class SimpleFCNet(fluid.dygraph.Layer):
class
TestDygraphDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
...
...
@@ -78,7 +78,8 @@ class TestDygraphDataLoader(unittest.TestCase):
dataset
,
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
step_list
=
[]
loss_list
=
[]
...
...
@@ -109,18 +110,23 @@ class TestDygraphDataLoader(unittest.TestCase):
def
test_main
(
self
):
# dynamic graph do not run with_data_parallel
for
p
in
prepare_places
(
False
):
for
persistent_workers
in
[
False
,
True
]:
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
,
persistent_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
persistent_workers
=
persistent_workers
)
results
.
append
(
ret
)
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
][
'loss'
].
shape
[
0
]
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
]
[
'loss'
].
shape
[
0
]
class
TestDygraphDataLoaderWithBatchedDataset
(
TestDygraphDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
...
...
@@ -132,7 +138,8 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
dataset
,
num_workers
=
num_workers
,
batch_size
=
None
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
step_list
=
[]
loss_list
=
[]
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
浏览文件 @
76710e5f
...
...
@@ -93,14 +93,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
if
with_gpu
and
fluid
.
core
.
is_compiled_with_cuda
():
tmp
=
fluid
.
cuda_places
()[:
2
]
assert
len
(
tmp
)
>
0
,
"no gpu detected"
if
with_data_parallel
:
if
with_data_parallel
and
len
(
tmp
)
>
1
:
places
.
append
(
tmp
)
places
.
append
([
tmp
[
0
]])
return
places
class
TestStaticDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -113,7 +113,8 @@ class TestStaticDataLoader(unittest.TestCase):
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
return_list
=
False
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
# assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
...
...
@@ -158,14 +159,19 @@ class TestStaticDataLoader(unittest.TestCase):
def
test_main
(
self
):
for
p
in
prepare_places
(
True
):
for
persistent_workers
in
[
False
,
True
]:
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
,
persistent_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
persistent_workers
=
persistent_workers
)
results
.
append
(
ret
)
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
][
'loss'
].
shape
[
0
]
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
]
[
'loss'
].
shape
[
0
]
class
RandomBatchedDataset
(
IterableDataset
):
...
...
@@ -188,7 +194,7 @@ class RandomBatchedDataset(IterableDataset):
class
TestStaticDataLoaderWithBatchedDataset
(
TestStaticDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -201,7 +207,8 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
num_workers
=
num_workers
,
batch_size
=
None
,
return_list
=
False
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
exe
.
run
(
startup_prog
)
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
浏览文件 @
76710e5f
...
...
@@ -94,14 +94,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
if
with_gpu
and
fluid
.
core
.
is_compiled_with_cuda
():
tmp
=
fluid
.
cuda_places
()[:
2
]
assert
len
(
tmp
)
>
0
,
"no gpu detected"
if
with_data_parallel
:
if
with_data_parallel
and
len
(
tmp
)
>
1
:
places
.
append
(
tmp
)
places
.
append
([
tmp
[
0
]])
return
places
class
TestStaticDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
,
use_pe
=
True
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
,
use_pe
=
True
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -114,7 +114,8 @@ class TestStaticDataLoader(unittest.TestCase):
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
return_list
=
False
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
...
...
@@ -162,11 +163,16 @@ class TestStaticDataLoader(unittest.TestCase):
def
test_main
(
self
):
for
p
in
prepare_places
(
True
):
for
persistent_workers
in
[
True
,
False
]:
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
,
persistent_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
persistent_workers
=
persistent_workers
)
results
.
append
(
ret
)
diff
=
np
.
max
(
np
.
abs
(
results
[
0
][
'loss'
]
-
results
[
1
][
'loss'
])
/
...
...
@@ -241,7 +247,7 @@ class RandomBatchedDataset(Dataset):
class
TestStaticDataLoaderWithBatchedDataset
(
TestStaticDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
persistent_workers
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -254,7 +260,8 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
num_workers
=
num_workers
,
batch_size
=
None
,
return_list
=
False
,
drop_last
=
True
)
drop_last
=
True
,
persistent_workers
=
persistent_workers
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录