提交 2e67f8ae 编写于 作者: S sneaxiy

add doc

test=develop
上级 c545f1ed
...@@ -96,6 +96,27 @@ def _cpu_num(): ...@@ -96,6 +96,27 @@ def _cpu_num():
def cuda_places(device_ids=None): def cuda_places(device_ids=None):
'''
Create a list of :code:`fluid.CUDAPlace` objects.
If :code:`device_ids` is None, environment variable of
:code:`FLAGS_selected_gpus` would be checked first. If
:code:`FLAGS_selected_gpus=0,1,2`, the returned list would
be [fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)].
If :code:`FLAGS_selected_gpus` is not set, all visible
gpu places would be returned.
If :code:`device_ids` is not None, it should be the device
ids of gpus. For example, if :code:`device_ids=[0,1,2]`,
the returned list would be
[fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)].
Args:
device_ids (None|list(int)|tuple(int)): gpu device id list.
Returns:
out (list(fluid.CUDAPlace)): gpu place list.
'''
assert core.is_compiled_with_cuda(), \ assert core.is_compiled_with_cuda(), \
"Not compiled with CUDA" "Not compiled with CUDA"
if device_ids is None: if device_ids is None:
...@@ -110,12 +131,40 @@ def cuda_places(device_ids=None): ...@@ -110,12 +131,40 @@ def cuda_places(device_ids=None):
def cpu_places(device_count=None): def cpu_places(device_count=None):
'''
Create a list of :code:`fluid.CPUPlace` objects.
If :code:`device_count` is None, the device count would
be determined by environment variable :code:`CPU_NUM`.
If :code:`CPU_NUM` is not set, the device count would
be determined by :code:`multiprocessing.cpu_count()`.
Args:
device_count (None|int): device number.
Returns:
out (list(fluid.CPUPlace)): cpu place list.
'''
if device_count is None: if device_count is None:
device_count = _cpu_num() device_count = _cpu_num()
return [core.CPUPlace()] * device_count return [core.CPUPlace()] * device_count
def cuda_pinned_places(device_count=None): def cuda_pinned_places(device_count=None):
'''
Create a list of :code:`fluid.CUDAPinnedPlace` objects.
If :code:`device_count` is None, the device count would
be determined by environment variable :code:`CPU_NUM`.
If :code:`CPU_NUM` is not set, the device count would
be determined by :code:`multiprocessing.cpu_count()`.
Args:
device_count (None|int): device number.
Returns:
out (list(fluid.CUDAPinnedPlace)): cuda pinned place list.
'''
assert core.is_compiled_with_cuda(), \ assert core.is_compiled_with_cuda(), \
"Not compiled with CUDA" "Not compiled with CUDA"
if device_count is None: if device_count is None:
......
...@@ -531,8 +531,7 @@ def _py_reader(capacity, ...@@ -531,8 +531,7 @@ def _py_reader(capacity,
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name) startup_var = startup_blk.create_var(name=reader_name)
startup_blk.append_op( startup_blk.append_op(
type='create_py_reader' type='create_py_reader',
if not lock_free else 'create_lock_free_py_reader',
inputs={'blocking_queue': [queue_name]}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]}, outputs={'Out': [startup_var]},
attrs={ attrs={
......
...@@ -47,6 +47,76 @@ class PyReader(object): ...@@ -47,6 +47,76 @@ class PyReader(object):
capacity, capacity,
use_double_buffer=True, use_double_buffer=True,
iterable=True): iterable=True):
"""
Create a reader object for data feeding in Python.
Data would be prefetched using Python thread and be pushed
into a queue asynchronously. Data in the queue would be extracted
automatically when `Executor.run(...)` is called.
Args:
feed_list (list(Variable)|tuple(Variable)): feed variable list.
The variables should be created by :code:`fluid.layers.data()`.
capacity (int): capacity of the queue maintained in PyReader object.
use_double_buffer (bool): whether to use double_buffer_reader to
speed up data feeding.
iterable (bool): whether the created reader object is iterable.
Returns:
reader (Reader): the created reader object.
Examples:
1. If iterable = False, the created PyReader object is almost the
same as :code:`fluid.layers.py_reader()`. Operators would be
inserted into the program. User should call :code:`start()`
before each epoch and catch :code:`fluid.core.EOFException`
thrown by :code:`Executor.run()` when epoch ends. Once the
exception is caught, user should call :code:`reset()` to reset
the reader manually.
.. code-block:: python
image = fluid.layers.data(
name='image', shape=[784], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
reader = fluid.io.PyReader(feed_list=[image, label],
capacity=4, iterable=False)
reader.decorate_paddle_reader(user_defined_reader)
... # definition of network is omitted
executor.run(fluid.default_main_program())
for _ in range(EPOCH_NUM):
reader.start()
while True:
try:
executor.run(feed=None, ...)
except fluid.core.EOFException:
reader.reset()
break
2. If iterable=True, the created PyReader object is decoupled with
the program. No operator would be inserted into the program.
In this case, the created reader is a Python generator, which
is iterable. User should feed the data yielded from PyReader
object into :code:`Executor.run(feed=...)`.
.. code-block:: python
image = fluid.layers.data(
name='image', shape=[784], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
reader = fluid.io.PyReader(feed_list=[image, label],
capacity=4, iterable=True)
reader.decorate_paddle_reader(user_defined_reader,
places=fluid.cuda_places())
... # definition of network is omitted
executor.run(fluid.default_main_program())
for _ in range(EPOCH_NUM):
for data in reader():
executor.run(feed=data, ...)
"""
self._tensor_reader = None self._tensor_reader = None
self._thread = None self._thread = None
self._iterable = iterable self._iterable = iterable
...@@ -161,10 +231,18 @@ class PyReader(object): ...@@ -161,10 +231,18 @@ class PyReader(object):
self._thread.join() self._thread.join()
def start(self): def start(self):
'''
Start the data feeding thread.
Can only call when the reader object is not iterable.
'''
assert not self._iterable, "start() cannot be called when PyReader is iterable" assert not self._iterable, "start() cannot be called when PyReader is iterable"
self._start() self._start()
def reset(self): def reset(self):
'''
Reset the reader object when :code:`fluid.core.EOFException` raises.
Can only call when the reader object is not iterable.
'''
assert not self._iterable, "reset() cannot be called when PyReader is iterable" assert not self._iterable, "reset() cannot be called when PyReader is iterable"
self._reset() self._reset()
...@@ -190,6 +268,18 @@ class PyReader(object): ...@@ -190,6 +268,18 @@ class PyReader(object):
self._thread.start() self._thread.start()
def decorate_paddle_reader(self, reader, places=None): def decorate_paddle_reader(self, reader, places=None):
'''
Set the data source of the PyReader object.
The provided :code:`reader` should be a Python generator,
which yields numpy-typed batched data.
:code:`places` must be set when the PyReader object is iterable.
Args:
reader (generator): Python generator that yields numpy-typed
batched data.
'''
assert self._tensor_reader is None, \ assert self._tensor_reader is None, \
"Cannot reset the data source of PyReader" "Cannot reset the data source of PyReader"
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
...@@ -204,6 +294,18 @@ class PyReader(object): ...@@ -204,6 +294,18 @@ class PyReader(object):
self.decorate_tensor_provider(__tensor_reader_impl__, places) self.decorate_tensor_provider(__tensor_reader_impl__, places)
def decorate_tensor_provider(self, reader, places=None): def decorate_tensor_provider(self, reader, places=None):
'''
Set the data source of the PyReader object.
The provided :code:`reader` should be a Python generator,
which yields LoDTensor-typed batched data.
:code:`places` must be set when the PyReader object is iterable.
Args:
reader (generator): Python generator that yields LoDTensor-typed
batched data.
'''
assert self._tensor_reader is None, \ assert self._tensor_reader is None, \
"Cannot reset the data source of PyReader" "Cannot reset the data source of PyReader"
self._tensor_reader = reader self._tensor_reader = reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册