未验证 提交 23c24af9 编写于 作者: S ShenLiang 提交者: GitHub

fix bug of pp (#54831)

上级 51fcceb2
......@@ -56,13 +56,17 @@ class FakeMicroDataset:
def _load_micro_batch(self, micro_step):
inputs = self._data
if self._is_first_stage or self._is_last_stage:
data = None
label = None
if self._is_first_stage:
assert len(inputs) == 2, "length of input should be 2"
data = self._load_micro_batch_impl(inputs[0], micro_step)
if self._is_last_stage:
assert len(inputs) == 2, "length of input should be 2"
label = self._load_micro_batch_impl(inputs[1], micro_step)
return (data, label)
else:
return (None, None)
return (data, label)
def _load_micro_batch_impl(self, inputs, micro_step):
begin = micro_step * self._micro_batch_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册