diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 6f77dbd4e07c87144619adb9fa1fd8546ca64dad..1e30467c4f722779f03d22a2a37a467190dbffb3 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -155,6 +155,11 @@ class Parallelizer: time.time() - time0, self._mode ) ) + micro_bsz = ( + 1 + if not self._strategy.pipeline.enable + else self._strategy.pipeline.micro_batch_size + ) time0 = time.time() resharder = Resharder( dist_main_prog, @@ -162,7 +167,7 @@ class Parallelizer: rank, self._dist_context, [], - 1, + micro_bsz, ) resharder.reshard() self._logger.debug(