From 802613cca75aa4af9a7b5f8fed9a913d456d24f0 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Mon, 3 Jul 2023 15:19:29 +0800
Subject: [PATCH] relax_micro_batch_check (#54788)

---
 .../distributed/fleet/meta_parallel/pipeline_parallel.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 8fcc3d855e1..808019651a5 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -715,10 +715,7 @@ class PipelineParallel(MetaParallelBase):
             for data in micro_batch_data:
                 self._check_micro_batch_data_valid(data)
         elif micro_batch_data is not None:
-            micro_batch_size = micro_batch_data.shape[0]
-            assert (
-                micro_batch_size == self.micro_batch_size
-            ), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"
+            assert isinstance(micro_batch_data, paddle.Tensor)
 
     def _broadcast_final_loss(self):
         # Since the last backward run in interleave will set the virtual rank to 0,
--
GitLab
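
Below is a minimal sketch (not part of the commit) of how the relaxed check reads in context after this change. PipelineParallelSketch, its constructor, and the call at the bottom are hypothetical stand-ins chosen for illustration; only the body of _check_micro_batch_data_valid mirrors the hunk above.

import paddle


class PipelineParallelSketch:
    """Illustrative stand-in for PipelineParallel; only the check below mirrors the patch."""

    def __init__(self, micro_batch_size):
        self.micro_batch_size = micro_batch_size

    def _check_micro_batch_data_valid(self, micro_batch_data):
        if isinstance(micro_batch_data, (tuple, list)):
            # Recurse into nested containers of tensors.
            for data in micro_batch_data:
                self._check_micro_batch_data_valid(data)
        elif micro_batch_data is not None:
            # After the patch: only require a paddle.Tensor; the strict
            # shape[0] == self.micro_batch_size assertion is dropped.
            assert isinstance(micro_batch_data, paddle.Tensor)


# Usage: a tensor whose leading dimension differs from micro_batch_size now
# passes the check, which is the relaxation the commit title refers to.
pp = PipelineParallelSketch(micro_batch_size=2)
pp._check_micro_batch_data_valid(paddle.ones([3, 8]))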