提交 72869543 编写于 作者: Y Yu Yang

Add more comments

上级 a1d910b1
...@@ -16,6 +16,7 @@ import core ...@@ -16,6 +16,7 @@ import core
import multiprocessing import multiprocessing
import framework import framework
import executor import executor
import sys
__all__ = ['ParallelExecutor'] __all__ = ['ParallelExecutor']
...@@ -125,6 +126,30 @@ class ParallelExecutor(object): ...@@ -125,6 +126,30 @@ class ParallelExecutor(object):
def run(self, fetch_list, feed=None, feed_dict=None): def run(self, fetch_list, feed=None, feed_dict=None):
""" """
Run a parallel executor with fetch_list.
The feed parameter can be a dict or a list. If feed is a dict, the
feed data will be split into multiple devices. If feed is a list, we
assume the data has been splitted into multiple devices, the each
element in the list will be copied to each device directly.
For example, if the feed is a dict:
>>> exe = ParallelExecutor()
>>> # the image will be splitted into devices. If there is two devices
>>> # each device will process an image with shape (24, 1, 28, 28)
>>> exe.run(feed={'image': numpy.random.random(size=(48, 1, 28, 28))})
For example, if the feed is a list:
>>> exe = ParallelExecutor()
>>> # each device will process each element in the list.
>>> # the 1st device will process an image with shape (48, 1, 28, 28)
>>> # the 2nd device will process an image with shape (32, 1, 28, 28)
>>> #
>>> # you can use exe.device_count to get the device number.
>>> exe.run(feed=[{"image": numpy.random.random(size=(48, 1, 28, 28))},
>>> {"image": numpy.random.random(size=(32, 1, 28, 28))},
>>> ])
Args: Args:
fetch_list(list): The fetched variable names fetch_list(list): The fetched variable names
...@@ -133,12 +158,14 @@ class ParallelExecutor(object): ...@@ -133,12 +158,14 @@ class ParallelExecutor(object):
the feed is a list, each element of the list will be copied the feed is a list, each element of the list will be copied
to each device. to each device.
feed_dict: Alias for feed parameter, for backward compatibility. feed_dict: Alias for feed parameter, for backward compatibility.
This parameter is deprecated.
Returns: fetched result list. Returns: fetched result list.
""" """
if feed is None: if feed is None:
feed = feed_dict feed = feed_dict
print >> sys.stderr, "`feed_dict` is deprecated. Please use `feed=`"
if isinstance(feed, dict): if isinstance(feed, dict):
feed_tensor_dict = dict() feed_tensor_dict = dict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册