diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 4a1b1ad72c05279fcf938b89c4d3c23015a50327..de6e1920bbc6111fa5d6cb2ed865c49dd4a9b6b3 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -56,13 +56,17 @@ class FakeMicroDataset: def _load_micro_batch(self, micro_step): inputs = self._data - if self._is_first_stage or self._is_last_stage: + data = None + label = None + if self._is_first_stage: assert len(inputs) == 2, "length of input should be 2" data = self._load_micro_batch_impl(inputs[0], micro_step) + + if self._is_last_stage: + assert len(inputs) == 2, "length of input should be 2" label = self._load_micro_batch_impl(inputs[1], micro_step) - return (data, label) - else: - return (None, None) + + return (data, label) def _load_micro_batch_impl(self, inputs, micro_step): begin = micro_step * self._micro_batch_size