diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 190efe6c1bcc7916461b1069eb0e59b4d108a13d..07cc1e29341bd497e88097a9ee5653631b79d734 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -16,6 +16,7 @@ import core import multiprocessing import framework import executor +import sys __all__ = ['ParallelExecutor'] @@ -125,6 +126,30 @@ class ParallelExecutor(object): def run(self, fetch_list, feed=None, feed_dict=None): """ + Run a parallel executor with fetch_list. + + The feed parameter can be a dict or a list. If feed is a dict, the + feed data will be split into multiple devices. If feed is a list, we + assume the data has been splitted into multiple devices, the each + element in the list will be copied to each device directly. + + For example, if the feed is a dict: + >>> exe = ParallelExecutor() + >>> # the image will be splitted into devices. If there is two devices + >>> # each device will process an image with shape (24, 1, 28, 28) + >>> exe.run(feed={'image': numpy.random.random(size=(48, 1, 28, 28))}) + + For example, if the feed is a list: + >>> exe = ParallelExecutor() + >>> # each device will process each element in the list. + >>> # the 1st device will process an image with shape (48, 1, 28, 28) + >>> # the 2nd device will process an image with shape (32, 1, 28, 28) + >>> # + >>> # you can use exe.device_count to get the device number. + >>> exe.run(feed=[{"image": numpy.random.random(size=(48, 1, 28, 28))}, + >>> {"image": numpy.random.random(size=(32, 1, 28, 28))}, + >>> ]) + Args: fetch_list(list): The fetched variable names @@ -133,12 +158,14 @@ class ParallelExecutor(object): the feed is a list, each element of the list will be copied to each device. feed_dict: Alias for feed parameter, for backward compatibility. + This parameter is deprecated. Returns: fetched result list. """ if feed is None: feed = feed_dict + print >> sys.stderr, "`feed_dict` is deprecated. Please use `feed=`" if isinstance(feed, dict): feed_tensor_dict = dict()