未验证 提交 802613cc 编写于 作者: zhenhailiu's avatar zhenhailiu 提交者: GitHub

relax_micro_batch_check (#54788)

上级 772d9481
......@@ -715,10 +715,7 @@ class PipelineParallel(MetaParallelBase):
for data in micro_batch_data:
self._check_micro_batch_data_valid(data)
elif micro_batch_data is not None:
micro_batch_size = micro_batch_data.shape[0]
assert (
micro_batch_size == self.micro_batch_size
), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"
assert isinstance(micro_batch_data, paddle.Tensor)
def _broadcast_final_loss(self):
# Since the last backward run in interleave will set the virtual rank to 0,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册