From a64844ad00e47dda549aba0e1846efcc185609d6 Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 26 Jun 2018 11:46:26 +0800 Subject: [PATCH] enable PE return numpy (#11704) --- python/paddle/fluid/executor.py | 2 ++ python/paddle/fluid/parallel_executor.py | 7 ++++++- .../tests/unittests/test_parallel_executor_fetch_feed.py | 4 +++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index dc27567461..145f1423e4 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -78,6 +78,8 @@ def as_numpy(tensor): Returns: numpy.ndarray """ + if isinstance(tensor, core.LoDTensorArray): + return [as_numpy(t) for t in tensor] if isinstance(tensor, list): return [as_numpy(t) for t in tensor] assert isinstance(tensor, core.LoDTensor) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 25cc1355d5..bb7b7d82f0 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -160,7 +160,7 @@ class ParallelExecutor(object): build_strategy, num_trainers, trainer_id) self.scope = scope - def run(self, fetch_list, feed=None, feed_dict=None): + def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=False): """ Run a parallel executor with fetch_list. @@ -196,6 +196,8 @@ class ParallelExecutor(object): to each device. Default None. feed_dict: Alias for feed parameter, for backward compatibility. This parameter has been deprecated. Default None. + return_numpy(bool): Whether converts the fetched tensor to numpy. + Default: False. Returns: List: The fetched result list. @@ -270,6 +272,9 @@ class ParallelExecutor(object): if self.is_dist: self.bcast_params() + if return_numpy: + return executor.as_numpy(arr) + return [arr[i] for i in range(len(arr))] def bcast_params(self): diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py index 79702475cc..3b18072c7b 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py @@ -75,7 +75,9 @@ class TestFetchOp(unittest.TestCase): fetch_list.append(k) for data in train_inputs: - ret = pe.run(fetch_list, feed=feeder.feed(data)) + ret = pe.run(fetch_list, + feed=feeder.feed(data), + return_numpy=True) for i in range(len(fetch_list)): assert not math.isnan(np.sum(ret[i])) and \ not math.isinf(np.sum(ret[i])) -- GitLab