Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
89d27de9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
89d27de9
编写于
11月 16, 2020
作者:
K
Kaipeng Deng
提交者:
GitHub
11月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
DataLoader support not auto collate batch (#28425)
* DataLoader support not auto collate batch. test=develop
上级
c5c273c1
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
327 addition
and
36 deletion
+327
-36
python/paddle/fluid/dataloader/dataloader_iter.py
python/paddle/fluid/dataloader/dataloader_iter.py
+24
-10
python/paddle/fluid/dataloader/fetcher.py
python/paddle/fluid/dataloader/fetcher.py
+31
-18
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+32
-4
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
...d/tests/unittests/test_multiprocess_dataloader_dynamic.py
+44
-1
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
...tests/unittests/test_multiprocess_dataloader_exception.py
+2
-2
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
.../test_multiprocess_dataloader_iterable_dataset_dynamic.py
+42
-1
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
...s/test_multiprocess_dataloader_iterable_dataset_static.py
+75
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
...id/tests/unittests/test_multiprocess_dataloader_static.py
+77
-0
未找到文件。
python/paddle/fluid/dataloader/dataloader_iter.py
浏览文件 @
89d27de9
...
...
@@ -36,6 +36,7 @@ from .. import core, layers
from
..framework
import
in_dygraph_mode
from
..multiprocess_utils
import
CleanupFuncRegistrar
,
_cleanup_mmap
,
_set_SIGCHLD_handler
from
.fetcher
import
_IterableDatasetFetcher
,
_MapDatasetFetcher
from
.batch_sampler
import
_InfiniteIterableSampler
__all__
=
[
'get_worker_info'
]
...
...
@@ -100,11 +101,13 @@ class _DatasetKind(object):
ITER
=
1
@
staticmethod
def
create_fetcher
(
kind
,
dataset
,
collate_fn
,
drop_last
):
def
create_fetcher
(
kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
):
if
kind
==
_DatasetKind
.
MAP
:
return
_MapDatasetFetcher
(
dataset
,
collate_fn
,
drop_last
)
return
_MapDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
elif
kind
==
_DatasetKind
.
ITER
:
return
_IterableDatasetFetcher
(
dataset
,
collate_fn
,
drop_last
)
return
_IterableDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
else
:
raise
NotImplementedError
(
"unknown Dataset kind {}"
.
format
(
kind
))
...
...
@@ -221,8 +224,7 @@ class _DataLoaderIterBase(object):
self
.
_places
=
loader
.
places
self
.
_return_list
=
loader
.
return_list
self
.
_batch_sampler
=
loader
.
batch_sampler
self
.
_sampler_iter
=
iter
(
loader
.
batch_sampler
)
self
.
_collate_fn
=
loader
.
collate_fn
or
default_collate_fn
self
.
_auto_collate_batch
=
loader
.
auto_collate_batch
self
.
_num_workers
=
loader
.
num_workers
self
.
_use_buffer_reader
=
loader
.
use_buffer_reader
self
.
_use_shared_memory
=
loader
.
use_shared_memory
...
...
@@ -231,6 +233,16 @@ class _DataLoaderIterBase(object):
self
.
_dataset_kind
=
loader
.
dataset_kind
self
.
_pin_memory
=
loader
.
pin_memory
if
self
.
_auto_collate_batch
:
self
.
_sampler_iter
=
iter
(
loader
.
batch_sampler
)
self
.
_collate_fn
=
loader
.
collate_fn
or
default_collate_fn
else
:
if
self
.
_dataset_kind
==
_DatasetKind
.
MAP
:
self
.
_sampler_iter
=
iter
(
list
(
range
(
len
(
self
.
_dataset
))))
else
:
self
.
_sampler_iter
=
iter
(
_InfiniteIterableSampler
(
self
.
_dataset
,
1
))
self
.
_collate_fn
=
loader
.
collate_fn
# LoDTensorBlockingQueue instance for create_py_reader and a thread
# to put mini-batch data to self._blocking_queue, mini-batch data
# will be fetched from:
...
...
@@ -257,7 +269,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
super
(
_DataLoaderIterSingleProcess
,
self
).
__init__
(
loader
)
self
.
_dataset_fetcher
=
_DatasetKind
.
create_fetcher
(
self
.
_dataset_kind
,
self
.
_dataset
,
self
.
_collate_fn
,
True
)
self
.
_dataset_kind
,
self
.
_dataset
,
self
.
_auto_collate_batch
,
self
.
_collate_fn
,
True
)
# NOTE: len(self._places) batches of data make up one output
# iteration; set blocking_queue so that it can cache 2 iterations of data
...
...
@@ -367,7 +380,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
# NOTE(chenweihang): _worker_loop must be top level method to be pickled
def
_worker_loop
(
dataset
,
dataset_kind
,
indices_queue
,
out_queue
,
done_event
,
collate_fn
,
init_fn
,
worker_id
,
num_workers
,
auto_collate_batch
,
collate_fn
,
init_fn
,
worker_id
,
num_workers
,
use_shared_memory
):
try
:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
...
...
@@ -388,7 +401,7 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
if
init_fn
is
not
None
:
init_fn
(
worker_id
)
fetcher
=
_DatasetKind
.
create_fetcher
(
dataset_kind
,
dataset
,
collate_fn
,
True
)
auto_collate_batch
,
collate_fn
,
True
)
except
:
init_exception
=
Exception
(
"init_fn failed in worker {}: "
\
"{}"
.
format
(
worker_id
,
sys
.
exc_info
()))
...
...
@@ -511,8 +524,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
target
=
_worker_loop
,
args
=
(
self
.
_dataset
,
self
.
_dataset_kind
,
indices_queue
,
self
.
_data_queue
,
self
.
_workers_done_event
,
self
.
_collate_fn
,
self
.
_worker_init_fn
,
i
,
self
.
_num_workers
,
self
.
_use_shared_memory
))
self
.
_auto_collate_batch
,
self
.
_collate_fn
,
self
.
_worker_init_fn
,
i
,
self
.
_num_workers
,
self
.
_use_shared_memory
))
worker
.
daemon
=
True
worker
.
start
()
self
.
_workers
.
append
(
worker
)
...
...
python/paddle/fluid/dataloader/fetcher.py
浏览文件 @
89d27de9
...
...
@@ -14,8 +14,9 @@
class
_DatasetFetcher
(
object
):
def
__init__
(
self
,
dataset
,
collate_fn
,
drop_last
):
def
__init__
(
self
,
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
):
self
.
dataset
=
dataset
self
.
auto_collate_batch
=
auto_collate_batch
self
.
collate_fn
=
collate_fn
self
.
drop_last
=
drop_last
...
...
@@ -25,29 +26,41 @@ class _DatasetFetcher(object):
class
_IterableDatasetFetcher
(
_DatasetFetcher
):
def
__init__
(
self
,
dataset
,
collate_fn
,
drop_last
):
super
(
_IterableDatasetFetcher
,
self
).
__init__
(
dataset
,
collate_fn
,
drop_last
)
def
__init__
(
self
,
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
):
super
(
_IterableDatasetFetcher
,
self
).
__init__
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
self
.
dataset_iter
=
iter
(
dataset
)
def
fetch
(
self
,
batch_indices
):
data
=
[]
for
_
in
batch_indices
:
try
:
data
.
append
(
next
(
self
.
dataset_iter
))
except
StopIteration
:
break
if
len
(
data
)
==
0
or
(
self
.
drop_last
and
len
(
data
)
<
len
(
batch_indices
)):
raise
StopIteration
return
self
.
collate_fn
(
data
)
if
self
.
auto_collate_batch
:
data
=
[]
for
_
in
batch_indices
:
try
:
data
.
append
(
next
(
self
.
dataset_iter
))
except
StopIteration
:
break
if
len
(
data
)
==
0
or
(
self
.
drop_last
and
len
(
data
)
<
len
(
batch_indices
)):
raise
StopIteration
else
:
data
=
next
(
self
.
dataset_iter
)
if
self
.
collate_fn
:
data
=
self
.
collate_fn
(
data
)
return
data
class
_MapDatasetFetcher
(
_DatasetFetcher
):
def
__init__
(
self
,
dataset
,
collate_fn
,
drop_last
):
super
(
_MapDatasetFetcher
,
self
).
__init__
(
dataset
,
collate_fn
,
drop_last
)
def
__init__
(
self
,
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
):
super
(
_MapDatasetFetcher
,
self
).
__init__
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
def
fetch
(
self
,
batch_indices
):
data
=
[
self
.
dataset
[
idx
]
for
idx
in
batch_indices
]
return
self
.
collate_fn
(
data
)
if
self
.
auto_collate_batch
:
data
=
[
self
.
dataset
[
idx
]
for
idx
in
batch_indices
]
else
:
data
=
self
.
dataset
[
batch_indices
]
if
self
.
collate_fn
:
data
=
self
.
collate_fn
(
data
)
return
data
python/paddle/fluid/reader.py
浏览文件 @
89d27de9
...
...
@@ -163,6 +163,21 @@ class DataLoader(object):
For :code:`batch_sampler` please see :code:`paddle.io.BatchSampler`
**Disable automatic batching**
In certain cases, such as some NLP tasks, users need to handle batching
manually in the dataset instead of relying on automatic batching. For these
cases, automatic batching is disabled if both :attr:`batch_size` and
:attr:`batch_sampler` are set to None; each item fetched from :attr:`dataset`
should already be a batch of data and will be processed by the function defined by
:attr:`collate_fn` or :attr:`default_collate_fn`.
.. note::
When automatic batching is disabled, :attr:`default_collate_fn` leaves
the data from the dataset unchanged.
Args:
dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset` or
...
...
@@ -185,7 +200,7 @@ class DataLoader(object):
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset`
and combine a batch. Default None.
batch_size(int): sample number in a mini-batch, a substitution
batch_size(int
|None
): sample number in a mini-batch, a substitution
parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
is not set, a default `paddle.io.BatchSampler` will be used
and initialized by :attr:`batch_size`, :attr:`shuffle` and
...
...
@@ -358,10 +373,15 @@ class DataLoader(object):
"batch_size/shuffle/drop_last should not be set when "
\
"batch_sampler is given"
self
.
batch_sampler
=
batch_sampler
self
.
batch_size
=
None
elif
batch_size
is
None
:
self
.
batch_sampler
=
None
self
.
batch_size
=
None
else
:
assert
batch_size
is
not
None
and
batch_size
>
0
,
\
"batch_size should be a positive value when "
\
assert
batch_size
>
0
,
\
"batch_size should be
None or
a positive value when "
\
"batch_sampler is not given"
self
.
batch_size
=
batch_size
if
isinstance
(
dataset
,
IterableDataset
):
self
.
batch_sampler
=
_InfiniteIterableSampler
(
dataset
,
batch_size
)
...
...
@@ -372,13 +392,21 @@ class DataLoader(object):
shuffle
=
shuffle
,
drop_last
=
drop_last
)
self
.
auto_collate_batch
=
self
.
batch_sampler
is
not
None
self
.
pin_memory
=
False
if
in_dygraph_mode
():
self
.
pin_memory
=
True
if
use_pinned_memory
(
)
is
None
else
use_pinned_memory
()
def
__len__
(
self
):
return
len
(
self
.
batch_sampler
)
if
self
.
dataset_kind
==
_DatasetKind
.
ITER
:
raise
ValueError
(
"length of IterableDataset not supported"
)
else
:
if
self
.
batch_size
is
None
:
return
len
(
self
.
dataset
)
else
:
return
len
(
self
.
batch_sampler
)
def
__iter__
(
self
):
if
self
.
num_workers
==
0
:
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
浏览文件 @
89d27de9
...
...
@@ -27,7 +27,7 @@ from paddle.io import Dataset, BatchSampler, DataLoader
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
test_multiprocess_dataloader_static
import
RandomDataset
,
prepare_places
from
test_multiprocess_dataloader_static
import
RandomDataset
,
RandomBatchedDataset
,
prepare_places
from
test_multiprocess_dataloader_static
import
EPOCH_NUM
,
BATCH_SIZE
,
IMAGE_SIZE
,
SAMPLE_NUM
,
CLASS_NUM
...
...
@@ -122,5 +122,48 @@ class TestDygraphDataLoader(unittest.TestCase):
self
.
assertLess
(
diff
,
1e-2
)
class
TestDygraphDataLoaderWithBatchedDataset
(
TestDygraphDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
fc_net
=
SimpleFCNet
()
optimizer
=
fluid
.
optimizer
.
Adam
(
parameter_list
=
fc_net
.
parameters
())
dataset
=
RandomBatchedDataset
(
SAMPLE_NUM
,
CLASS_NUM
)
dataloader
=
DataLoader
(
dataset
,
num_workers
=
num_workers
,
batch_size
=
None
,
drop_last
=
True
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
step_list
=
[]
loss_list
=
[]
start_t
=
time
.
time
()
for
_
in
six
.
moves
.
range
(
EPOCH_NUM
):
step
=
0
for
image
,
label
in
dataloader
():
out
=
fc_net
(
image
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
fluid
.
layers
.
reduce_mean
(
loss
)
avg_loss
.
backward
()
optimizer
.
minimize
(
avg_loss
)
fc_net
.
clear_gradients
()
loss_list
.
append
(
np
.
mean
(
avg_loss
.
numpy
()))
step
+=
1
step_list
.
append
(
step
)
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
,
"loss"
:
np
.
array
(
loss_list
)
}
print
(
"time cost"
,
ret
[
'time'
],
'step_list'
,
ret
[
'step'
])
return
ret
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
浏览文件 @
89d27de9
...
...
@@ -188,7 +188,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
indices_queue
.
put
(
None
)
_worker_loop
(
loader
.
_dataset
,
0
,
indices_queue
,
loader
.
_data_queue
,
loader
.
_workers_done_event
,
_collate_fn
,
_init_fn
,
0
,
1
,
True
,
_collate_fn
,
_init_fn
,
0
,
1
,
loader
.
_use_shared_memory
)
self
.
assertTrue
(
False
)
except
AssertionError
:
...
...
@@ -232,7 +232,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
loader
.
_workers_done_event
.
set
()
_worker_loop
(
loader
.
_dataset
,
0
,
indices_queue
,
loader
.
_data_queue
,
loader
.
_workers_done_event
,
_collate_fn
,
_init_fn
,
0
,
1
,
True
,
_collate_fn
,
_init_fn
,
0
,
1
,
loader
.
_use_shared_memory
)
self
.
assertTrue
(
True
)
except
AssertionError
:
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
浏览文件 @
89d27de9
...
...
@@ -27,7 +27,7 @@ from paddle.io import Dataset, BatchSampler, DataLoader
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
test_multiprocess_dataloader_iterable_dataset_static
import
RandomDataset
,
prepare_places
from
test_multiprocess_dataloader_iterable_dataset_static
import
RandomDataset
,
RandomBatchedDataset
,
prepare_places
from
test_multiprocess_dataloader_iterable_dataset_static
import
EPOCH_NUM
,
BATCH_SIZE
,
IMAGE_SIZE
,
SAMPLE_NUM
,
CLASS_NUM
...
...
@@ -119,5 +119,46 @@ class TestDygraphDataLoader(unittest.TestCase):
0
]
class
TestDygraphDataLoaderWithBatchedDataset
(
TestDygraphDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
fc_net
=
SimpleFCNet
()
optimizer
=
fluid
.
optimizer
.
Adam
(
parameter_list
=
fc_net
.
parameters
())
dataset
=
RandomBatchedDataset
(
SAMPLE_NUM
,
CLASS_NUM
)
dataloader
=
DataLoader
(
dataset
,
num_workers
=
num_workers
,
batch_size
=
None
,
drop_last
=
True
)
step_list
=
[]
loss_list
=
[]
start_t
=
time
.
time
()
for
_
in
six
.
moves
.
range
(
EPOCH_NUM
):
step
=
0
for
image
,
label
in
dataloader
():
out
=
fc_net
(
image
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
fluid
.
layers
.
reduce_mean
(
loss
)
avg_loss
.
backward
()
optimizer
.
minimize
(
avg_loss
)
fc_net
.
clear_gradients
()
loss_list
.
append
(
np
.
mean
(
avg_loss
.
numpy
()))
step
+=
1
step_list
.
append
(
step
)
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
,
"loss"
:
np
.
array
(
loss_list
)
}
print
(
"time cost"
,
ret
[
'time'
],
'step_list'
,
ret
[
'step'
])
return
ret
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
浏览文件 @
89d27de9
...
...
@@ -167,5 +167,80 @@ class TestStaticDataLoader(unittest.TestCase):
0
]
class
RandomBatchedDataset
(
IterableDataset
):
def
__init__
(
self
,
sample_num
,
class_num
):
self
.
sample_num
=
sample_num
//
BATCH_SIZE
self
.
class_num
=
class_num
def
__iter__
(
self
):
for
i
in
range
(
self
.
sample_num
):
np
.
random
.
seed
(
i
)
images
=
[]
labels
=
[]
for
_
in
range
(
BATCH_SIZE
):
image
=
np
.
random
.
random
([
IMAGE_SIZE
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
self
.
class_num
-
1
,
(
1
,
)).
astype
(
'int64'
)
images
.
append
(
image
)
labels
.
append
(
label
)
yield
np
.
stack
(
images
,
axis
=
0
),
np
.
stack
(
labels
,
axis
=
0
)
class
TestStaticDataLoaderWithBatchedDataset
(
TestStaticDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
dataset
=
RandomBatchedDataset
(
SAMPLE_NUM
,
CLASS_NUM
)
dataloader
=
DataLoader
(
dataset
,
feed_list
=
[
image
,
label
],
places
=
places
,
num_workers
=
num_workers
,
batch_size
=
None
,
drop_last
=
True
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
exe
.
run
(
startup_prog
)
prog
=
fluid
.
CompiledProgram
(
main_prog
)
if
len
(
places
)
>
1
:
prog
=
prog
.
with_data_parallel
(
loss_name
=
loss
.
name
,
places
=
places
)
step_list
=
[]
loss_list
=
[]
start_t
=
time
.
time
()
for
i
in
six
.
moves
.
range
(
EPOCH_NUM
):
step
=
0
for
d
in
dataloader
:
assert
len
(
d
)
==
len
(
places
),
"{} != {}"
.
format
(
len
(
d
),
len
(
places
))
for
i
,
item
in
enumerate
(
d
):
image
=
item
[
'image'
]
label
=
item
[
'label'
]
assert
image
.
shape
()
==
[
BATCH_SIZE
,
IMAGE_SIZE
]
assert
label
.
shape
()
==
[
BATCH_SIZE
,
1
]
assert
image
.
_place
().
_equals
(
places
[
i
])
assert
label
.
_place
().
_equals
(
places
[
i
])
L
,
=
exe
.
run
(
program
=
prog
,
feed
=
d
,
fetch_list
=
[
loss
],
use_program_cache
=
True
)
loss_list
.
append
(
np
.
mean
(
L
))
step
+=
1
step_list
.
append
(
step
)
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
,
"loss"
:
np
.
array
(
loss_list
)
}
print
(
"time cost"
,
ret
[
'time'
],
'step_list'
,
ret
[
'step'
])
return
ret
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
浏览文件 @
89d27de9
...
...
@@ -215,5 +215,82 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
assert
isinstance
(
d
[
1
],
list
)
class
RandomBatchedDataset
(
Dataset
):
def
__init__
(
self
,
sample_num
,
class_num
):
self
.
sample_num
=
int
(
sample_num
/
BATCH_SIZE
)
self
.
class_num
=
class_num
def
__getitem__
(
self
,
idx
):
np
.
random
.
seed
(
idx
)
images
=
[]
labels
=
[]
for
_
in
range
(
BATCH_SIZE
):
image
=
np
.
random
.
random
([
IMAGE_SIZE
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
self
.
class_num
-
1
,
(
1
,
)).
astype
(
'int64'
)
images
.
append
(
image
)
labels
.
append
(
label
)
return
np
.
stack
(
images
,
axis
=
0
),
np
.
stack
(
labels
,
axis
=
0
)
def
__len__
(
self
):
return
self
.
sample_num
class
TestStaticDataLoaderWithBatchedDataset
(
TestStaticDataLoader
):
def
run_main
(
self
,
num_workers
,
places
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
dataset
=
RandomBatchedDataset
(
SAMPLE_NUM
,
CLASS_NUM
)
dataloader
=
DataLoader
(
dataset
,
feed_list
=
[
image
,
label
],
places
=
places
,
num_workers
=
num_workers
,
batch_size
=
None
,
drop_last
=
True
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
exe
.
run
(
startup_prog
)
prog
=
fluid
.
CompiledProgram
(
main_prog
)
if
len
(
places
)
>
1
:
prog
=
prog
.
with_data_parallel
(
loss_name
=
loss
.
name
,
places
=
places
)
step_list
=
[]
loss_list
=
[]
start_t
=
time
.
time
()
for
_
in
six
.
moves
.
range
(
EPOCH_NUM
):
step
=
0
for
d
in
dataloader
:
assert
len
(
d
)
==
len
(
places
),
"{} != {}"
.
format
(
len
(
d
),
len
(
places
))
for
i
,
item
in
enumerate
(
d
):
image
=
item
[
'image'
]
label
=
item
[
'label'
]
assert
image
.
shape
()
==
[
BATCH_SIZE
,
IMAGE_SIZE
]
assert
label
.
shape
()
==
[
BATCH_SIZE
,
1
]
assert
image
.
_place
().
_equals
(
places
[
i
])
assert
label
.
_place
().
_equals
(
places
[
i
])
L
,
=
exe
.
run
(
program
=
prog
,
feed
=
d
,
fetch_list
=
[
loss
],
use_program_cache
=
True
)
loss_list
.
append
(
np
.
mean
(
L
))
step
+=
1
step_list
.
append
(
step
)
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
,
"loss"
:
np
.
array
(
loss_list
)
}
print
(
"time cost"
,
ret
[
'time'
],
'step_list'
,
ret
[
'step'
])
return
ret
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录