From 23c24af91851d59e4cdbd1ccc09960daaa0e0777 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Thu, 29 Jun 2023 19:26:12 +0800
Subject: [PATCH] fix bug of pp (#54831)

---
 .../fleet/meta_parallel/pipeline_parallel.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 4a1b1ad72c0..de6e1920bbc 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -56,13 +56,17 @@ class FakeMicroDataset:
 
     def _load_micro_batch(self, micro_step):
         inputs = self._data
-        if self._is_first_stage or self._is_last_stage:
+        data = None
+        label = None
+        if self._is_first_stage:
             assert len(inputs) == 2, "length of input should be 2"
             data = self._load_micro_batch_impl(inputs[0], micro_step)
+
+        if self._is_last_stage:
+            assert len(inputs) == 2, "length of input should be 2"
             label = self._load_micro_batch_impl(inputs[1], micro_step)
-            return (data, label)
-        else:
-            return (None, None)
+
+        return (data, label)
 
     def _load_micro_batch_impl(self, inputs, micro_step):
         begin = micro_step * self._micro_batch_size
-- 
GitLab
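
For context, the patched control flow can be exercised on its own. The sketch below is illustrative only: FakeMicroDatasetSketch and its plain-list batches are hypothetical stand-ins (the real FakeMicroDataset slices paddle tensors and carries more state), but _load_micro_batch mirrors the hunk above. Before the change, any stage that was first or last loaded both inputs[0] and inputs[1]; after it, the first stage loads only the input data, the last stage only the label, and intermediate stages return (None, None).

    # Hypothetical, simplified stand-in for FakeMicroDataset; only the
    # control flow of _load_micro_batch matches the patched method.
    class FakeMicroDatasetSketch:
        def __init__(self, data, micro_batch_size, is_first_stage, is_last_stage):
            self._data = data  # (inputs, labels) on first/last stages
            self._micro_batch_size = micro_batch_size
            self._is_first_stage = is_first_stage
            self._is_last_stage = is_last_stage

        def _load_micro_batch(self, micro_step):
            inputs = self._data
            data = None
            label = None
            # The first stage consumes only the model inputs ...
            if self._is_first_stage:
                assert len(inputs) == 2, "length of input should be 2"
                data = self._load_micro_batch_impl(inputs[0], micro_step)
            # ... and the last stage consumes only the labels. A stage that
            # is both first and last (single-stage pipeline) loads both.
            if self._is_last_stage:
                assert len(inputs) == 2, "length of input should be 2"
                label = self._load_micro_batch_impl(inputs[1], micro_step)
            # Intermediate stages fall through with (None, None).
            return (data, label)

        def _load_micro_batch_impl(self, inputs, micro_step):
            # Simplified slicing stand-in for the real implementation.
            begin = micro_step * self._micro_batch_size
            return inputs[begin : begin + self._micro_batch_size]

    # First stage of a multi-stage pipeline: data is loaded, label stays None.
    ds = FakeMicroDatasetSketch(
        ([0, 1, 2, 3], [10, 11, 12, 13]),
        micro_batch_size=2,
        is_first_stage=True,
        is_last_stage=False,
    )
    assert ds._load_micro_batch(1) == ([2, 3], None)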