提交 7e7611d0 编写于 作者: C chengduoZH

when the number of samples of current batch is less than the count of devices, let it crash.

上级 b1a5a3ca
...@@ -174,6 +174,11 @@ void ParallelExecutor::SplitTensorToPlaces( ...@@ -174,6 +174,11 @@ void ParallelExecutor::SplitTensorToPlaces(
const std::unordered_map<std::string, LoDTensor> &feed_tensors) { const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
for (auto it : feed_tensors) { for (auto it : feed_tensors) {
auto lod_tensors = it.second.SplitLoDTensor(member_->places_); auto lod_tensors = it.second.SplitLoDTensor(member_->places_);
PADDLE_ENFORCE_EQ(
member_->places_.size(), lod_tensors.size(),
"The number of samples of current batch is less than the count of "
"devices, currently, it is not allowed. (%d vs %d)",
member_->places_.size(), lod_tensors.size());
for (size_t j = 0; j < member_->places_.size(); ++j) { for (size_t j = 0; j < member_->places_.size(); ++j) {
// TODO(panxy0718): Do I need to delete this var? // TODO(panxy0718): Do I need to delete this var?
member_->local_scopes_[j] member_->local_scopes_[j]
......
...@@ -87,7 +87,8 @@ class ParallelExecutor(object): ...@@ -87,7 +87,8 @@ class ParallelExecutor(object):
# performance. Worth tunning for other models in the future. # performance. Worth tunning for other models in the future.
num_threads = len(self._places) num_threads = len(self._places)
else: else:
min(len(self._places) * 2, multiprocessing.cpu_count()) num_threads = min(
len(self._places) * 2, multiprocessing.cpu_count())
main = main_program main = main_program
main = main if main else framework.default_main_program() main = main if main else framework.default_main_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册