From 7e7611d06753a6eafafda8042ad473535895e07f Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Mon, 9 Apr 2018 12:53:28 +0800
Subject: [PATCH] when the number of samples of current batch is less than the count of devices, let it crash.

---
 paddle/fluid/framework/parallel_executor.cc | 5 +++++
 python/paddle/fluid/parallel_executor.py | 3 ++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 74945fb4f..99b3065d8 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -174,6 +174,11 @@ void ParallelExecutor::SplitTensorToPlaces(
     const std::unordered_map &feed_tensors) {
   for (auto it : feed_tensors) {
     auto lod_tensors = it.second.SplitLoDTensor(member_->places_);
+    PADDLE_ENFORCE_EQ(
+        member_->places_.size(), lod_tensors.size(),
+        "The number of samples of current batch is less than the count of "
+        "devices, currently, it is not allowed. (%d vs %d)",
+        member_->places_.size(), lod_tensors.size());
     for (size_t j = 0; j < member_->places_.size(); ++j) {
       // TODO(panxy0718): Do I need to delete this var?
       member_->local_scopes_[j]
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index b93f2f974..24dfa6144 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -87,7 +87,8 @@ class ParallelExecutor(object):
                 # performance. Worth tunning for other models in the future.
                 num_threads = len(self._places)
             else:
-                min(len(self._places) * 2, multiprocessing.cpu_count())
+                num_threads = min(
+                    len(self._places) * 2, multiprocessing.cpu_count())
 
         main = main_program
         main = main if main else framework.default_main_program()
-- 
GitLab