Unverified commit b5fe3f63, authored by zhenhailiu, committed by GitHub

Pipeline model: clean up self.data (#54374)

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
Parent 2230eda9
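In outline, this commit removes the `self.data` attribute on which `PipelineParallel` used to stash the whole training batch, and instead threads micro-batches through an explicit iterator (`FakeMicroDataset`) that is passed into `_forward_step`. The sketch below illustrates that iterator pattern in isolation; `MicroBatchIterator` and the plain-NumPy slicing are illustrative stand-ins, not the Paddle implementation shown in the diff.

# Minimal sketch of the pattern this commit adopts: the pipeline schedule pulls
# micro-batches from an iterator instead of indexing a batch stored on self.data.
# MicroBatchIterator is a hypothetical stand-in; the commit's real class is
# FakeMicroDataset, which additionally handles tuples, lists, None, and stage flags.
import numpy as np


class MicroBatchIterator:
    def __init__(self, data, acc_steps, micro_batch_size):
        self._data = data
        self._acc_steps = acc_steps
        self._micro_batch_size = micro_batch_size
        self._index = 0

    def __iter__(self):
        return self

    def __next__(self):
        # yield one slice per gradient-accumulation step, then stop
        if self._index >= self._acc_steps:
            raise StopIteration
        begin = self._index * self._micro_batch_size
        end = begin + self._micro_batch_size
        self._index += 1
        return self._data[begin:end]


batch = np.arange(8).reshape(4, 2)  # batch_size 4 -> 2 micro-batches of size 2
for micro_batch in MicroBatchIterator(batch, acc_steps=2, micro_batch_size=2):
    print(micro_batch.shape)  # (2, 2) at each step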
@@ -30,6 +30,86 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []
+
+# assume only the first stage and last stage need data, and data consumption is ordered
+# to be replaced by real micro dataset from reader
+class FakeMicroDataset:
+    def __init__(
+        self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
+    ):
+        self._data = data
+        self._index = 0
+        self._acc_steps = acc_steps
+        self._is_first_stage = is_first_stage
+        self._is_last_stage = is_last_stage
+        self._micro_batch_size = micro_batch_size
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self._index >= self._acc_steps:
+            raise StopIteration
+        assert self._is_first_stage or self._is_last_stage
+        micro_batch_data = self._load_micro_batch(self._index)
+        self._index += 1
+        return micro_batch_data
+
+    def _load_micro_batch(self, micro_step):
+        inputs = self._data
+        if self._is_first_stage or self._is_last_stage:
+            assert len(inputs) == 2, "length of input should be 2"
+            data = self._load_micro_batch_impl(inputs[0], micro_step)
+            label = self._load_micro_batch_impl(inputs[1], micro_step)
+            return (data, label)
+        else:
+            return (None, None)
+
+    def _load_micro_batch_impl(self, inputs, micro_step):
+        begin = micro_step * self._micro_batch_size
+        end = begin + self._micro_batch_size
+
+        if isinstance(inputs, tuple):
+            output = []
+            for data in inputs:
+                if isinstance(data, list):
+                    assert (
+                        len(data) == self._acc_steps
+                    ), "length of data should be %d, but it is %d" % (
+                        self._acc_steps,
+                        len(data),
+                    )
+                    output.append(data[micro_step].detach())
+                elif data is not None:
+                    self._check_data_vaild(data)
+                    output.append(data[begin:end, :].detach())
+                else:
+                    output.append(None)
+            return tuple(output)
+        elif isinstance(inputs, list):
+            assert (
+                len(inputs) == self._acc_steps
+            ), "length of data should be %d, but it is %d" % (
+                self._acc_steps,
+                len(inputs),
+            )
+            return inputs[micro_step].detach()
+        elif inputs is not None:
+            self._check_data_vaild(inputs)
+            return inputs[begin:end, :].detach()
+        else:
+            return None
+
+    def _check_data_vaild(self, data):
+        batch_size = data.shape[0]
+        assert self._micro_batch_size * self._acc_steps == batch_size, (
+            "batch_size needs to be divisible by micro_batch_size. Currently, "
+            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
+            % (batch_size, self._micro_batch_size, self._acc_steps)
+        )
+
+
class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
@@ -233,9 +313,6 @@ class PipelineParallel(MetaParallelBase):
        self.scaler = scaler

-        # store data for train
-        self.data = data
-
        # store total loss of entire batch
        self.total_loss = None
@@ -249,10 +326,12 @@ class PipelineParallel(MetaParallelBase):
        input_buffers = []
        output_buffers = []

+        micro_dataset = self._wrap_data(data)
+
        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
@@ -267,7 +346,7 @@ class PipelineParallel(MetaParallelBase):
        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
@@ -361,6 +440,22 @@ class PipelineParallel(MetaParallelBase):
        return data
+    def _wrap_data(self, data):
+        """
+        For backward compatibility, wrap data in a FakeMicroDataset if it is of type list or tuple.
+        """
+        if (not isinstance(data, tuple)) and (not isinstance(data, list)):
+            return data
+
+        micro_dataset = FakeMicroDataset(
+            data,
+            self.is_pipeline_first_stage(ignore_virtual=True),
+            self.is_pipeline_last_stage(ignore_virtual=True),
+            self.accumulate_steps,
+            self.micro_batch_size,
+        )
+        return micro_dataset
+
    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
@@ -379,8 +474,6 @@ class PipelineParallel(MetaParallelBase):
        self._layers.eval()
        self._compute_loss = compute_loss

-        # save data for eval
-        self.data = data
-
        # store data id for micro_batch
        self.micro_batch_id = 0
@@ -394,10 +487,13 @@ class PipelineParallel(MetaParallelBase):
        input_buffers = []
        output_buffers = []

+        # convert to micro dataset
+        micro_dataset = self._wrap_data(data)
+
        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
@@ -409,7 +505,7 @@ class PipelineParallel(MetaParallelBase):
        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
@@ -425,11 +521,12 @@ class PipelineParallel(MetaParallelBase):
        return self.train_loss

-    def _forward_step(self, input_tensor, chunk_id=None):
+    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
-            input_tensor = self._load_micro_batch(self.micro_batch_id)
+            input_tensor = next(micro_dataset)[0]
+            self._check_micro_batch_data_valid(input_tensor)

        assert chunk_id is None or isinstance(chunk_id, int)
@@ -441,7 +538,8 @@ class PipelineParallel(MetaParallelBase):
            assert (
                self._layers._loss_fn is not None
            ), "loss function should exist to compute loss"
-            labels = self._load_micro_batch(self.micro_batch_id)
+            labels = next(micro_dataset)[1]
+            self._check_micro_batch_data_valid(labels)
            output_tensor = self._layers._loss_fn(output_tensor, labels)
            assert isinstance(
                output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
@@ -499,56 +597,15 @@ class PipelineParallel(MetaParallelBase):
            self.timers("backward_step").stop()
        return input_tensor_grad

-    def _check_data_vaild(self, data):
-        batch_size = data.shape[0]
-        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
-            "batch_size needs to be divisible by micro_batch_size. Currently, "
-            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
-            % (batch_size, self.micro_batch_size, self.accumulate_steps)
-        )
-
-    def _load_micro_batch_impl(self, inputs, cache_id):
-        begin = cache_id * self.micro_batch_size
-        end = begin + self.micro_batch_size
-
-        if isinstance(inputs, tuple):
-            output = []
-            for data in inputs:
-                if isinstance(data, list):
-                    assert (
-                        len(data) == self.accumulate_steps
-                    ), "length of data should be %d, but it is %d" % (
-                        self.accumulate_steps,
-                        len(data),
-                    )
-                    output.append(data[cache_id].detach())
-                else:
-                    self._check_data_vaild(data)
-                    output.append(data[begin:end, :].detach())
-            return tuple(output)
-
-        elif isinstance(inputs, list):
-            assert (
-                len(inputs) == self.accumulate_steps
-            ), "length of data should be %d, but it is %d" % (
-                self.accumulate_steps,
-                len(inputs),
-            )
-            return inputs[cache_id].detach()
-        else:
-            self._check_data_vaild(inputs)
-            return inputs[begin:end, :].detach()
-
-    def _load_micro_batch(self, cache_id):
-        inputs = self.data
-        if self.is_pipeline_first_stage():
-            assert len(inputs) == 2, "length of input should be 2"
-            return self._load_micro_batch_impl(inputs[0], cache_id)
-        elif self.is_pipeline_last_stage():
-            assert len(inputs) == 2, "length of input should be 2"
-            return self._load_micro_batch_impl(inputs[1], cache_id)
-        else:
-            inputs = None
+    def _check_micro_batch_data_valid(self, micro_batch_data):
+        if isinstance(micro_batch_data, (tuple, list)):
+            for data in micro_batch_data:
+                self._check_micro_batch_data_valid(data)
+        elif micro_batch_data is not None:
+            micro_batch_size = micro_batch_data.shape[0]
+            assert (
+                micro_batch_size == self.micro_batch_size
+            ), f"expected micro_batch_size {self.micro_batch_size} but got {micro_batch_size}"
    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
@@ -654,7 +711,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

-    def _forward_step_helper(self, micro_step):
+    def _forward_step_helper(self, micro_dataset, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)
@@ -667,7 +724,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
                len(self.output_tensors[virtual_pp_rank]) + 1
            )
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
-        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
+        output_tensor = self._forward_step(
+            input_tensor, micro_dataset, virtual_pp_rank
+        )
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
@@ -715,7 +774,6 @@ class PipelineParallelWithInterleave(PipelineParallel):
        # init some attributes for this batch run
        self.scaler = scaler

-        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only
@@ -725,6 +783,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

+        micro_dataset = self._wrap_data(data)
+
        num_steps = self.accumulate_steps * self.num_model_chunks
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
@@ -747,7 +807,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
        # run startup steps
        for micro_step in range(startup_steps):
-            output_tensor = self._forward_step_helper(micro_step)
+            output_tensor = self._forward_step_helper(micro_dataset, micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
@@ -800,7 +860,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
-            output_tensor = self._forward_step_helper(forward_micro_step_id)
+            output_tensor = self._forward_step_helper(
+                micro_dataset, forward_micro_step_id
+            )

            # backward
            backward_micro_step_id = micro_step
...
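For orientation before the test diff below: the first pipeline stage consumes element 0 of each micro-batch (the features) and the last stage consumes element 1 (the labels), mirroring the `next(micro_dataset)` calls in `_forward_step` above. A small usage sketch, assuming paddle is installed and using the import path from the test file that follows:

# Usage sketch of FakeMicroDataset as _forward_step consumes it in the diff above.
import numpy as np
import paddle
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
    FakeMicroDataset,
)

micro_batch_size, acc_steps, length = 2, 2, 4
features = paddle.to_tensor(np.random.rand(4, length))
labels = paddle.to_tensor(np.random.randint(0, 2, size=[4, 1]))
batch = (features, labels)  # exactly two elements: (inputs, labels)

# First stage: _forward_step reads next(micro_dataset)[0] -> a [2, 4] feature slice.
first_stage = FakeMicroDataset(batch, True, False, acc_steps, micro_batch_size)
x, _ = next(first_stage)
print(x.shape)  # [2, 4]

# Last stage: _forward_step reads next(micro_dataset)[1] -> the matching [2, 1] labels.
last_stage = FakeMicroDataset(batch, False, True, acc_steps, micro_batch_size)
_, y = next(last_stage)
print(y.shape)  # [2, 1]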
@@ -17,6 +17,8 @@ import unittest

from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus

+import paddle
+

class TestHybridPipeParallel(TestMultipleGpus):
    def test_hybrid_parallel_pp_layer(self):
@@ -55,5 +57,52 @@ class TestHybridPipeParallel(TestMultipleGpus):
        self.run_mnist_2gpu('hybrid_parallel_pp_transformer_unbalanced_data.py')

+
+class TestFakeMicroDataSet(unittest.TestCase):
+    def test_fake_micro_data_set(self):
+        import numpy as np
+
+        from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
+            FakeMicroDataset,
+        )
+
+        batch_size = 4
+        micro_batch_size = 2
+        acc_step = 2
+        length = 4
+        x_data = np.random.randint(0, batch_size, size=[batch_size, length])
+        data1 = paddle.to_tensor(x_data)
+        data1.stop_gradient = True
+
+        data2 = [
+            data1[
+                (i * micro_batch_size) : ((i + 1) * micro_batch_size), :
+            ].detach()
+            for i in range(acc_step)
+        ]
+
+        data3 = None
+
+        batch = [(data1, data2, data3), None]
+
+        for micro_batch in FakeMicroDataset(
+            batch, True, False, acc_step, micro_batch_size
+        ):
+            x, y = micro_batch
+            self.assertEqual(len(x), 3)
+            for e in [x[0], x[1]]:
+                self.assertEqual(e.shape[0], micro_batch_size)
+                self.assertEqual(e.shape[1], length)
+            self.assertTrue(x[2] is None)
+            self.assertTrue(y is None)
+
+        # not first stage or last stage
+        micro_batches = FakeMicroDataset(
+            batch, False, False, acc_step, micro_batch_size
+        )
+        x, y = micro_batches._load_micro_batch(0)
+        self.assertTrue(x is None)
+        self.assertTrue(y is None)
+

if __name__ == "__main__":
    unittest.main()