From 7df043ec08dcef15de024e95cb83ab728831773b Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 13 Jun 2023 15:45:45 +0800 Subject: [PATCH] =?UTF-8?q?pipeline=20model=20=E7=A7=BB=E9=99=A4=20self.da?= =?UTF-8?q?ta=20(#54387)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * polish * polish * polish * polish * polish * polish --- .../fleet/meta_parallel/pipeline_parallel.py | 196 ++++++++++++------ 1 file changed, 128 insertions(+), 68 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 7dc2c6cd99f..a3f840cf770 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -30,6 +30,85 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size __all__ = [] +# assume only the first stage and last stage need data, and data consumption are ordred; +# to be replaced by real micro dataset from reader +class FakeMicroDataset: + def __init__( + self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size + ): + self._data = data + self._index = 0 + self._acc_steps = acc_steps + self._is_first_stage = is_first_stage + self._is_last_stage = is_last_stage + self._micro_batch_size = micro_batch_size + + def __iter__(self): + return self + + def __next__(self): + assert self._index < self._acc_steps + assert self._is_first_stage or self._is_last_stage + micro_batch_data = self._load_micro_batch(self._index) + self._index += 1 + return micro_batch_data + + def _load_micro_batch(self, micro_step): + inputs = self._data + + if self._is_first_stage or self._is_last_stage: + assert len(inputs) == 2, "length of input should be 2" + data = self._load_micro_batch_impl(inputs[0], micro_step) + label = self._load_micro_batch_impl(inputs[1], micro_step) + return (data, label) + else: + return (None, 
None)
+
+    def _load_micro_batch_impl(self, inputs, micro_step):
+        begin = micro_step * self._micro_batch_size
+        end = begin + self._micro_batch_size
+
+        if isinstance(inputs, tuple):
+            output = []
+            for data in inputs:
+                if isinstance(data, list):
+                    assert (
+                        len(data) == self._acc_steps
+                    ), "length of data should be %d, but it is %d" % (
+                        self._acc_steps,
+                        len(data),
+                    )
+                    output.append(data[micro_step].detach())
+                elif data is not None:
+                    self._check_data_vaild(data)
+                    output.append(data[begin:end, :].detach())
+                else:
+                    output.append(None)
+            return tuple(output)
+
+        elif isinstance(inputs, list):
+            assert (
+                len(inputs) == self._acc_steps
+            ), "length of data should be %d, but it is %d" % (
+                self._acc_steps,
+                len(inputs),
+            )
+            return inputs[micro_step].detach()
+        elif inputs is not None:
+            self._check_data_vaild(inputs)
+            return inputs[begin:end, :].detach()
+        else:
+            return None
+
+    def _check_data_vaild(self, data):
+        batch_size = data.shape[0]
+        assert self._micro_batch_size * self._acc_steps == batch_size, (
+            "batch_size needs to be divisible by micro_batch_size. Currently, "
+            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
+ % (batch_size, self._micro_batch_size, self._acc_steps) + ) + + class PipelineParallel(MetaParallelBase): def __init__(self, layers, hcg, strategy): if not isinstance(layers, PipelineLayer): @@ -237,9 +316,6 @@ class PipelineParallel(MetaParallelBase): self.scaler = scaler - # store data for train - self.data = data - # store total loss of entire batch self.total_loss = None @@ -253,10 +329,12 @@ class PipelineParallel(MetaParallelBase): input_buffers = [] output_buffers = [] + micro_dataset = self._wrap_data(data) + for step_id in range(startup_steps): input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) - output_tensor = self._forward_step(input_tensor) + output_tensor = self._forward_step(input_tensor, micro_dataset) p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) input_buffers.append(input_tensor) @@ -271,7 +349,7 @@ class PipelineParallel(MetaParallelBase): for i in range(steady_steps): last_iter = i == (steady_steps - 1) - output_tensor = self._forward_step(input_tensor) + output_tensor = self._forward_step(input_tensor, micro_dataset) output_tensor_grad = p2p.send_forward_recv_backward( output_tensor, self.is_pipeline_last_stage() @@ -365,6 +443,22 @@ class PipelineParallel(MetaParallelBase): return data + def _wrap_data(self, data): + """ + for backward compatibilty, wrap data to Fake FakeMicroDataset if it is of type list or tuple + """ + if (not isinstance(data, tuple)) and (not isinstance(data, list)): + return data + + micro_dataset = FakeMicroDataset( + data, + self.is_pipeline_first_stage(ignore_virtual=True), + self.is_pipeline_last_stage(ignore_virtual=True), + self.accumulate_steps, + self.micro_batch_size, + ) + return micro_dataset + def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): data = self._prepare_training(data, optimizer, lr_scheduler) # 1f1b scheduler for pipeline parallel @@ -383,8 +477,6 @@ class PipelineParallel(MetaParallelBase): self._layers.eval() self._compute_loss = compute_loss - 
# save data for eval - self.data = data # store data id for micro_batch self.micro_batch_id = 0 @@ -398,10 +490,12 @@ class PipelineParallel(MetaParallelBase): input_buffers = [] output_buffers = [] + micro_dataset = self._wrap_data(data) + for step_id in range(startup_steps): input_tensor = p2p.recv_forward(self.is_pipeline_first_stage()) - output_tensor = self._forward_step(input_tensor) + output_tensor = self._forward_step(input_tensor, micro_dataset) p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) input_buffers.append(input_tensor) @@ -413,7 +507,7 @@ class PipelineParallel(MetaParallelBase): for i in range(steady_steps): last_iter = i == (steady_steps - 1) - output_tensor = self._forward_step(input_tensor) + output_tensor = self._forward_step(input_tensor, micro_dataset) p2p.send_forward(output_tensor, self.is_pipeline_last_stage()) input_buffers.append(input_tensor) @@ -429,11 +523,12 @@ class PipelineParallel(MetaParallelBase): return self.train_loss - def _forward_step(self, input_tensor, chunk_id=None): + def _forward_step(self, input_tensor, micro_dataset, chunk_id=None): if self._enable_timer: self.timers("forward_step").start() if self.is_pipeline_first_stage(): - input_tensor = self._load_micro_batch(self.micro_batch_id) + input_tensor = next(micro_dataset)[0] + self._check_micro_batch_data_valid(input_tensor) assert chunk_id is None or isinstance(chunk_id, int) @@ -445,7 +540,8 @@ class PipelineParallel(MetaParallelBase): assert ( self._layers._loss_fn is not None ), "loss function should exist to compute loss" - labels = self._load_micro_batch(self.micro_batch_id) + labels = next(micro_dataset)[1] + self._check_micro_batch_data_valid(labels) output_tensor = self._layers._loss_fn(output_tensor, labels) assert isinstance( output_tensor, (paddle.Tensor, framework.core.eager.Tensor) @@ -467,6 +563,16 @@ class PipelineParallel(MetaParallelBase): self.timers("forward_step").stop() return output_tensor + def 
_check_micro_batch_data_valid(self, micro_batch_data): + if isinstance(micro_batch_data, (tuple, list)): + for data in micro_batch_data: + self._check_micro_batch_data_valid(data) + elif micro_batch_data is not None: + micro_batch_size = micro_batch_data.shape[0] + assert ( + micro_batch_size == self.micro_batch_size + ), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}" + def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): if self._enable_timer: self.timers("backward_step").start() @@ -503,57 +609,6 @@ class PipelineParallel(MetaParallelBase): self.timers("backward_step").stop() return input_tensor_grad - def _check_data_vaild(self, data): - batch_size = data.shape[0] - assert self.micro_batch_size * self.accumulate_steps == batch_size, ( - "batch_size needs to be divisible by micro_batch_size. Currently, " - "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d." - % (batch_size, self.micro_batch_size, self.accumulate_steps) - ) - - def _load_micro_batch_impl(self, inputs, cache_id): - begin = cache_id * self.micro_batch_size - end = begin + self.micro_batch_size - - if isinstance(inputs, tuple): - output = [] - for data in inputs: - if isinstance(data, list): - assert ( - len(data) == self.accumulate_steps - ), "length of data should be %d, but it is %d" % ( - self.accumulate_steps, - len(data), - ) - output.append(data[cache_id].detach()) - else: - self._check_data_vaild(data) - output.append(data[begin:end, :].detach()) - return tuple(output) - - elif isinstance(inputs, list): - assert ( - len(inputs) == self.accumulate_steps - ), "length of data should be %d, but it is %d" % ( - self.accumulate_steps, - len(inputs), - ) - return inputs[cache_id].detach() - else: - self._check_data_vaild(inputs) - return inputs[begin:end, :].detach() - - def _load_micro_batch(self, cache_id): - inputs = self.data - if self.is_pipeline_first_stage(): - assert len(inputs) == 2, "length of input should be 2" - return 
self._load_micro_batch_impl(inputs[0], cache_id) - elif self.is_pipeline_last_stage(): - assert len(inputs) == 2, "length of input should be 2" - return self._load_micro_batch_impl(inputs[1], cache_id) - else: - inputs = None - def _broadcast_final_loss(self): # Since the last backward run in interleave will set the virtual rank to 0, # here we need to check last stage ignoring virtual stage. @@ -658,7 +713,7 @@ class PipelineParallelWithInterleave(PipelineParallel): virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1 return virtual_pp_stage - def _forward_step_helper(self, micro_step): + def _forward_step_helper(self, micro_dataset, micro_step): virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True) self.set_virtual_pipeline_rank(virtual_pp_rank) @@ -674,7 +729,9 @@ class PipelineParallelWithInterleave(PipelineParallel): ): self.input_tensors[virtual_pp_rank].append(None) input_tensor = self.input_tensors[virtual_pp_rank][-1] - output_tensor = self._forward_step(input_tensor, virtual_pp_rank) + output_tensor = self._forward_step( + input_tensor, micro_dataset, virtual_pp_rank + ) self.output_tensors[virtual_pp_rank].append(output_tensor) if self._forward_only: @@ -719,7 +776,6 @@ class PipelineParallelWithInterleave(PipelineParallel): # init some attributes for this batch run self.scaler = scaler - self.data = data self.total_loss = None self.micro_batch_id = 0 self._forward_only = forward_only @@ -729,6 +785,8 @@ class PipelineParallelWithInterleave(PipelineParallel): self.output_tensors = [[] for _ in range(self.num_model_chunks)] self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)] + micro_dataset = self._wrap_data(data) + num_steps = self.accumulate_steps * self.num_model_chunks all_startup_steps = False if forward_only: @@ -752,7 +810,7 @@ class PipelineParallelWithInterleave(PipelineParallel): # run startup steps for micro_step in range(startup_steps): - output_tensor = self._forward_step_helper(micro_step) + 
output_tensor = self._forward_step_helper(micro_dataset, micro_step) # determine whether recv forward tensor or not next_virtual_pp_rank = self._get_virtual_pp_rank( @@ -806,7 +864,9 @@ class PipelineParallelWithInterleave(PipelineParallel): for micro_step in range(steady_steps): # forward forward_micro_step_id = micro_step + startup_steps - output_tensor = self._forward_step_helper(forward_micro_step_id) + output_tensor = self._forward_step_helper( + micro_dataset, forward_micro_step_id + ) # backward backward_micro_step_id = micro_step -- GitLab