From 8ae9094e0759db04bfd80cbda0ead703c053ebdf Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 4 Jan 2019 11:32:34 +0800 Subject: [PATCH] polish and resolve conflicts test=develop --- paddle/fluid/framework/parallel_executor.cc | 2 +- python/paddle/fluid/executor.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 5c8776b62fe..f61c9e3a911 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -200,7 +200,7 @@ ParallelExecutor::ParallelExecutor( member_->build_strategy_ = build_strategy; member_->use_all_reduce_ = build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce; - member_->nranks_ = num_trainers * places.size(); + member_->nranks_ = build_strategy.num_trainers_ * places.size(); if (!member_->use_all_reduce_) { PADDLE_ENFORCE(places.size() > 1, diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 7c417cd8285..4003e988f22 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -375,7 +375,6 @@ class Executor(object): self._closed = True def _run_parallel(self, - exe, scope, feed=None, fetch_list=None, @@ -391,7 +390,8 @@ class Executor(object): feed_tensor.set(feed[feed_name], core.CPUPlace()) feed_tensor_dict[feed_name] = feed_tensor - exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict) + self.executor.feed_and_split_tensor_into_local_scopes( + feed_tensor_dict) elif isinstance(feed, list) or isinstance(feed, tuple): if len(feed) != len(self._places): raise ValueError( @@ -412,10 +412,10 @@ class Executor(object): tensor = tmp res_dict[feed_name] = tensor res.append(res_dict) - exe.feed_tensors_into_local_scopes(res) + self.executor.feed_tensors_into_local_scopes(res) fetch_var_name = '@FETCHED_VAR_NAME@' - exe.run(fetch_list, fetch_var_name) + self.executor.run(fetch_list, fetch_var_name) arr = scope.find_var(fetch_var_name).get_lod_tensor_array() if return_numpy: @@ -502,12 +502,13 @@ class Executor(object): self.executor = program._executor if program._is_data_parallel: return self._run_parallel( - exe=program._executor, scope=scope, feed=feed, fetch_list=fetch_list, return_numpy=return_numpy) else: + # TODO(panyx0718): Can compile program to optimize executor + # performance. return self._run( program._program, feed=feed, -- GitLab