未验证 提交 7df043ec 编写于 作者: zhenhailiu's avatar zhenhailiu 提交者: GitHub

pipeline model 移除 self.data (#54387)

* polish

* polish

* polish

* polish

* polish

* polish
上级 161dad50
...@@ -30,6 +30,85 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size ...@@ -30,6 +30,85 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
__all__ = [] __all__ = []
# assume only the first stage and last stage need data, and data consumption are ordred;
# to be replaced by real micro dataset from reader
class FakeMicroDataset:
def __init__(
self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
):
self._data = data
self._index = 0
self._acc_steps = acc_steps
self._is_first_stage = is_first_stage
self._is_last_stage = is_last_stage
self._micro_batch_size = micro_batch_size
def __iter__(self):
return self
def __next__(self):
assert self._index < self._acc_steps
assert self._is_first_stage or self._is_last_stage
micro_batch_data = self._load_micro_batch(self._index)
self._index += 1
return micro_batch_data
def _load_micro_batch(self, micro_step):
inputs = self._data
if self._is_first_stage or self._is_last_stage:
assert len(inputs) == 2, "length of input should be 2"
data = self._load_micro_batch_impl(inputs[0], micro_step)
label = self._load_micro_batch_impl(inputs[1], micro_step)
return (data, label)
else:
return (None, None)
def _load_micro_batch_impl(self, inputs, micro_step):
begin = micro_step * self._micro_batch_size
end = begin + self._micro_batch_size
if isinstance(inputs, tuple):
output = []
for data in inputs:
if isinstance(data, list):
assert (
len(data) == self._acc_steps
), "length of data should be %d, but it is %d" % (
self._acc_steps,
len(data),
)
output.append(data[micro_step].detach())
elif data is not None:
self._check_data_vaild(data)
output.append(data[begin:end, :].detach())
else:
output.append(None)
return tuple(output)
elif isinstance(inputs, list):
assert (
len(inputs) == self._acc_steps
), "length of data should be %d, but it is %d" % (
self.accumulate_steps,
len(inputs),
)
return inputs[micro_step].detach()
elif inputs is not None:
self._check_data_vaild(inputs)
return inputs[begin:end, :].detach()
else:
return None
def _check_data_vaild(self, data):
batch_size = data.shape[0]
assert self._micro_batch_size * self._acc_steps == batch_size, (
"batch_size needs to be divisible by micro_batch_size. Currently, "
"batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
% (batch_size, self._micro_batch_size, self._acc_steps)
)
class PipelineParallel(MetaParallelBase): class PipelineParallel(MetaParallelBase):
def __init__(self, layers, hcg, strategy): def __init__(self, layers, hcg, strategy):
if not isinstance(layers, PipelineLayer): if not isinstance(layers, PipelineLayer):
...@@ -237,9 +316,6 @@ class PipelineParallel(MetaParallelBase): ...@@ -237,9 +316,6 @@ class PipelineParallel(MetaParallelBase):
self.scaler = scaler self.scaler = scaler
# store data for train
self.data = data
# store total loss of entire batch # store total loss of entire batch
self.total_loss = None self.total_loss = None
...@@ -253,10 +329,12 @@ class PipelineParallel(MetaParallelBase): ...@@ -253,10 +329,12 @@ class PipelineParallel(MetaParallelBase):
input_buffers = [] input_buffers = []
output_buffers = [] output_buffers = []
micro_dataset = self._wrap_data(data)
for step_id in range(startup_steps): for step_id in range(startup_steps):
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
output_tensor = self._forward_step(input_tensor) output_tensor = self._forward_step(input_tensor, micro_dataset)
p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
input_buffers.append(input_tensor) input_buffers.append(input_tensor)
...@@ -271,7 +349,7 @@ class PipelineParallel(MetaParallelBase): ...@@ -271,7 +349,7 @@ class PipelineParallel(MetaParallelBase):
for i in range(steady_steps): for i in range(steady_steps):
last_iter = i == (steady_steps - 1) last_iter = i == (steady_steps - 1)
output_tensor = self._forward_step(input_tensor) output_tensor = self._forward_step(input_tensor, micro_dataset)
output_tensor_grad = p2p.send_forward_recv_backward( output_tensor_grad = p2p.send_forward_recv_backward(
output_tensor, self.is_pipeline_last_stage() output_tensor, self.is_pipeline_last_stage()
...@@ -365,6 +443,22 @@ class PipelineParallel(MetaParallelBase): ...@@ -365,6 +443,22 @@ class PipelineParallel(MetaParallelBase):
return data return data
def _wrap_data(self, data):
"""
for backward compatibilty, wrap data to Fake FakeMicroDataset if it is of type list or tuple
"""
if (not isinstance(data, tuple)) and (not isinstance(data, list)):
return data
micro_dataset = FakeMicroDataset(
data,
self.is_pipeline_first_stage(ignore_virtual=True),
self.is_pipeline_last_stage(ignore_virtual=True),
self.accumulate_steps,
self.micro_batch_size,
)
return micro_dataset
def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
data = self._prepare_training(data, optimizer, lr_scheduler) data = self._prepare_training(data, optimizer, lr_scheduler)
# 1f1b scheduler for pipeline parallel # 1f1b scheduler for pipeline parallel
...@@ -383,8 +477,6 @@ class PipelineParallel(MetaParallelBase): ...@@ -383,8 +477,6 @@ class PipelineParallel(MetaParallelBase):
self._layers.eval() self._layers.eval()
self._compute_loss = compute_loss self._compute_loss = compute_loss
# save data for eval
self.data = data
# store data id for micro_batch # store data id for micro_batch
self.micro_batch_id = 0 self.micro_batch_id = 0
...@@ -398,10 +490,12 @@ class PipelineParallel(MetaParallelBase): ...@@ -398,10 +490,12 @@ class PipelineParallel(MetaParallelBase):
input_buffers = [] input_buffers = []
output_buffers = [] output_buffers = []
micro_dataset = self._wrap_data(data)
for step_id in range(startup_steps): for step_id in range(startup_steps):
input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
output_tensor = self._forward_step(input_tensor) output_tensor = self._forward_step(input_tensor, micro_dataset)
p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
input_buffers.append(input_tensor) input_buffers.append(input_tensor)
...@@ -413,7 +507,7 @@ class PipelineParallel(MetaParallelBase): ...@@ -413,7 +507,7 @@ class PipelineParallel(MetaParallelBase):
for i in range(steady_steps): for i in range(steady_steps):
last_iter = i == (steady_steps - 1) last_iter = i == (steady_steps - 1)
output_tensor = self._forward_step(input_tensor) output_tensor = self._forward_step(input_tensor, micro_dataset)
p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
input_buffers.append(input_tensor) input_buffers.append(input_tensor)
...@@ -429,11 +523,12 @@ class PipelineParallel(MetaParallelBase): ...@@ -429,11 +523,12 @@ class PipelineParallel(MetaParallelBase):
return self.train_loss return self.train_loss
def _forward_step(self, input_tensor, chunk_id=None): def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
if self._enable_timer: if self._enable_timer:
self.timers("forward_step").start() self.timers("forward_step").start()
if self.is_pipeline_first_stage(): if self.is_pipeline_first_stage():
input_tensor = self._load_micro_batch(self.micro_batch_id) input_tensor = next(micro_dataset)[0]
self._check_micro_batch_data_valid(input_tensor)
assert chunk_id is None or isinstance(chunk_id, int) assert chunk_id is None or isinstance(chunk_id, int)
...@@ -445,7 +540,8 @@ class PipelineParallel(MetaParallelBase): ...@@ -445,7 +540,8 @@ class PipelineParallel(MetaParallelBase):
assert ( assert (
self._layers._loss_fn is not None self._layers._loss_fn is not None
), "loss function should exist to compute loss" ), "loss function should exist to compute loss"
labels = self._load_micro_batch(self.micro_batch_id) labels = next(micro_dataset)[1]
self._check_micro_batch_data_valid(labels)
output_tensor = self._layers._loss_fn(output_tensor, labels) output_tensor = self._layers._loss_fn(output_tensor, labels)
assert isinstance( assert isinstance(
output_tensor, (paddle.Tensor, framework.core.eager.Tensor) output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
...@@ -467,6 +563,16 @@ class PipelineParallel(MetaParallelBase): ...@@ -467,6 +563,16 @@ class PipelineParallel(MetaParallelBase):
self.timers("forward_step").stop() self.timers("forward_step").stop()
return output_tensor return output_tensor
def _check_micro_batch_data_valid(self, micro_batch_data):
if isinstance(micro_batch_data, (tuple, list)):
for data in micro_batch_data:
self._check_micro_batch_data_valid(data)
elif micro_batch_data is not None:
micro_batch_size = micro_batch_data.shape[0]
assert (
micro_batch_size == self.micro_batch_size
), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"
def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
if self._enable_timer: if self._enable_timer:
self.timers("backward_step").start() self.timers("backward_step").start()
...@@ -503,57 +609,6 @@ class PipelineParallel(MetaParallelBase): ...@@ -503,57 +609,6 @@ class PipelineParallel(MetaParallelBase):
self.timers("backward_step").stop() self.timers("backward_step").stop()
return input_tensor_grad return input_tensor_grad
def _check_data_vaild(self, data):
batch_size = data.shape[0]
assert self.micro_batch_size * self.accumulate_steps == batch_size, (
"batch_size needs to be divisible by micro_batch_size. Currently, "
"batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
% (batch_size, self.micro_batch_size, self.accumulate_steps)
)
def _load_micro_batch_impl(self, inputs, cache_id):
begin = cache_id * self.micro_batch_size
end = begin + self.micro_batch_size
if isinstance(inputs, tuple):
output = []
for data in inputs:
if isinstance(data, list):
assert (
len(data) == self.accumulate_steps
), "length of data should be %d, but it is %d" % (
self.accumulate_steps,
len(data),
)
output.append(data[cache_id].detach())
else:
self._check_data_vaild(data)
output.append(data[begin:end, :].detach())
return tuple(output)
elif isinstance(inputs, list):
assert (
len(inputs) == self.accumulate_steps
), "length of data should be %d, but it is %d" % (
self.accumulate_steps,
len(inputs),
)
return inputs[cache_id].detach()
else:
self._check_data_vaild(inputs)
return inputs[begin:end, :].detach()
def _load_micro_batch(self, cache_id):
inputs = self.data
if self.is_pipeline_first_stage():
assert len(inputs) == 2, "length of input should be 2"
return self._load_micro_batch_impl(inputs[0], cache_id)
elif self.is_pipeline_last_stage():
assert len(inputs) == 2, "length of input should be 2"
return self._load_micro_batch_impl(inputs[1], cache_id)
else:
inputs = None
def _broadcast_final_loss(self): def _broadcast_final_loss(self):
# Since the last backward run in interleave will set the virtual rank to 0, # Since the last backward run in interleave will set the virtual rank to 0,
# here we need to check last stage ignoring virtual stage. # here we need to check last stage ignoring virtual stage.
...@@ -658,7 +713,7 @@ class PipelineParallelWithInterleave(PipelineParallel): ...@@ -658,7 +713,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1 virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
return virtual_pp_stage return virtual_pp_stage
def _forward_step_helper(self, micro_step): def _forward_step_helper(self, micro_dataset, micro_step):
virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True) virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
self.set_virtual_pipeline_rank(virtual_pp_rank) self.set_virtual_pipeline_rank(virtual_pp_rank)
...@@ -674,7 +729,9 @@ class PipelineParallelWithInterleave(PipelineParallel): ...@@ -674,7 +729,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
): ):
self.input_tensors[virtual_pp_rank].append(None) self.input_tensors[virtual_pp_rank].append(None)
input_tensor = self.input_tensors[virtual_pp_rank][-1] input_tensor = self.input_tensors[virtual_pp_rank][-1]
output_tensor = self._forward_step(input_tensor, virtual_pp_rank) output_tensor = self._forward_step(
input_tensor, micro_dataset, virtual_pp_rank
)
self.output_tensors[virtual_pp_rank].append(output_tensor) self.output_tensors[virtual_pp_rank].append(output_tensor)
if self._forward_only: if self._forward_only:
...@@ -719,7 +776,6 @@ class PipelineParallelWithInterleave(PipelineParallel): ...@@ -719,7 +776,6 @@ class PipelineParallelWithInterleave(PipelineParallel):
# init some attributes for this batch run # init some attributes for this batch run
self.scaler = scaler self.scaler = scaler
self.data = data
self.total_loss = None self.total_loss = None
self.micro_batch_id = 0 self.micro_batch_id = 0
self._forward_only = forward_only self._forward_only = forward_only
...@@ -729,6 +785,8 @@ class PipelineParallelWithInterleave(PipelineParallel): ...@@ -729,6 +785,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
self.output_tensors = [[] for _ in range(self.num_model_chunks)] self.output_tensors = [[] for _ in range(self.num_model_chunks)]
self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)] self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
micro_dataset = self._wrap_data(data)
num_steps = self.accumulate_steps * self.num_model_chunks num_steps = self.accumulate_steps * self.num_model_chunks
all_startup_steps = False all_startup_steps = False
if forward_only: if forward_only:
...@@ -752,7 +810,7 @@ class PipelineParallelWithInterleave(PipelineParallel): ...@@ -752,7 +810,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
# run startup steps # run startup steps
for micro_step in range(startup_steps): for micro_step in range(startup_steps):
output_tensor = self._forward_step_helper(micro_step) output_tensor = self._forward_step_helper(micro_dataset, micro_step)
# determine whether recv forward tensor or not # determine whether recv forward tensor or not
next_virtual_pp_rank = self._get_virtual_pp_rank( next_virtual_pp_rank = self._get_virtual_pp_rank(
...@@ -806,7 +864,9 @@ class PipelineParallelWithInterleave(PipelineParallel): ...@@ -806,7 +864,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
for micro_step in range(steady_steps): for micro_step in range(steady_steps):
# forward # forward
forward_micro_step_id = micro_step + startup_steps forward_micro_step_id = micro_step + startup_steps
output_tensor = self._forward_step_helper(forward_micro_step_id) output_tensor = self._forward_step_helper(
micro_dataset, forward_micro_step_id
)
# backward # backward
backward_micro_step_id = micro_step backward_micro_step_id = micro_step
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册