未验证 提交 23c24af9 编写于 作者: S ShenLiang 提交者: GitHub

fix bug of pp (#54831)

上级 51fcceb2
...@@ -56,13 +56,17 @@ class FakeMicroDataset: ...@@ -56,13 +56,17 @@ class FakeMicroDataset:
def _load_micro_batch(self, micro_step): def _load_micro_batch(self, micro_step):
inputs = self._data inputs = self._data
if self._is_first_stage or self._is_last_stage: data = None
label = None
if self._is_first_stage:
assert len(inputs) == 2, "length of input should be 2" assert len(inputs) == 2, "length of input should be 2"
data = self._load_micro_batch_impl(inputs[0], micro_step) data = self._load_micro_batch_impl(inputs[0], micro_step)
if self._is_last_stage:
assert len(inputs) == 2, "length of input should be 2"
label = self._load_micro_batch_impl(inputs[1], micro_step) label = self._load_micro_batch_impl(inputs[1], micro_step)
return (data, label)
else: return (data, label)
return (None, None)
def _load_micro_batch_impl(self, inputs, micro_step): def _load_micro_batch_impl(self, inputs, micro_step):
begin = micro_step * self._micro_batch_size begin = micro_step * self._micro_batch_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册