From b5fe3f63aadf9fc74c997979b7c596e3045454b2 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Tue, 13 Jun 2023 15:44:35 +0800
Subject: [PATCH] Pipeline model: clean up self.data (#54374)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
---
 .../fleet/meta_parallel/pipeline_parallel.py  | 194 ++++++++++++------
 ...test_parallel_dygraph_pipeline_parallel.py |  49 +++++
 2 files changed, 177 insertions(+), 66 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 7cd0bf19c0d..13788f7e165 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -30,6 +30,86 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
 
 __all__ = []
 
+# assume only the first stage and the last stage need data, and data consumption is ordered;
+# to be replaced by a real micro dataset from the reader
+class FakeMicroDataset:
+    def __init__(
+        self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
+    ):
+        self._data = data
+        self._index = 0
+        self._acc_steps = acc_steps
+        self._is_first_stage = is_first_stage
+        self._is_last_stage = is_last_stage
+        self._micro_batch_size = micro_batch_size
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self._index >= self._acc_steps:
+            raise StopIteration
+        assert self._is_first_stage or self._is_last_stage
+        micro_batch_data = self._load_micro_batch(self._index)
+        self._index += 1
+        return micro_batch_data
+
+    def _load_micro_batch(self, micro_step):
+        inputs = self._data
+
+        if self._is_first_stage or self._is_last_stage:
+            assert len(inputs) == 2, "length of input should be 2"
+            data = self._load_micro_batch_impl(inputs[0], micro_step)
+            label = self._load_micro_batch_impl(inputs[1], micro_step)
+            return (data, label)
+        else:
+            return (None, None)
+
+    def _load_micro_batch_impl(self, inputs, micro_step):
+        begin = micro_step * self._micro_batch_size
+        end = begin + self._micro_batch_size
+
+        if isinstance(inputs, tuple):
+            output = []
+            for data in inputs:
+                if isinstance(data, list):
+                    assert (
+                        len(data) == self._acc_steps
+                    ), "length of data should be %d, but it is %d" % (
+                        self._acc_steps,
+                        len(data),
+                    )
+                    output.append(data[micro_step].detach())
+                elif data is not None:
+                    self._check_data_valid(data)
+                    output.append(data[begin:end, :].detach())
+                else:
+                    output.append(None)
+            return tuple(output)
+
+        elif isinstance(inputs, list):
+            assert (
+                len(inputs) == self._acc_steps
+            ), "length of data should be %d, but it is %d" % (
+                self._acc_steps,
+                len(inputs),
+            )
+            return inputs[micro_step].detach()
+        elif inputs is not None:
+            self._check_data_valid(inputs)
+            return inputs[begin:end, :].detach()
+        else:
+            return None
+
+    def _check_data_valid(self, data):
+        batch_size = data.shape[0]
+        assert self._micro_batch_size * self._acc_steps == batch_size, (
+            "batch_size must equal micro_batch_size * accumulate_steps. Currently, "
+            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
+            % (batch_size, self._micro_batch_size, self._acc_steps)
+        )
+
+
 class PipelineParallel(MetaParallelBase):
     def __init__(self, layers, hcg, strategy):
         if not isinstance(layers, PipelineLayer):
@@ -233,9 +313,6 @@ class PipelineParallel(MetaParallelBase):
 
         self.scaler = scaler
 
-        # store data for train
-        self.data = data
-
         # store total loss of entire batch
         self.total_loss = None
 
@@ -249,10 +326,12 @@ class PipelineParallel(MetaParallelBase):
         input_buffers = []
         output_buffers = []
 
+        micro_dataset = self._wrap_data(data)
+
         for step_id in range(startup_steps):
             input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
 
-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
 
             input_buffers.append(input_tensor)
@@ -267,7 +346,7 @@ class PipelineParallel(MetaParallelBase):
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)
 
-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
 
             output_tensor_grad = p2p.send_forward_recv_backward(
                 output_tensor, self.is_pipeline_last_stage()
@@ -361,6 +440,22 @@ class PipelineParallel(MetaParallelBase):
 
         return data
 
+    def _wrap_data(self, data):
+        """
+        For backward compatibility, wrap data in a FakeMicroDataset if it is a list or tuple.
+        """
+        if (not isinstance(data, tuple)) and (not isinstance(data, list)):
+            return data
+
+        micro_dataset = FakeMicroDataset(
+            data,
+            self.is_pipeline_first_stage(ignore_virtual=True),
+            self.is_pipeline_last_stage(ignore_virtual=True),
+            self.accumulate_steps,
+            self.micro_batch_size,
+        )
+        return micro_dataset
+
     def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
         data = self._prepare_training(data, optimizer, lr_scheduler)
         # 1f1b scheduler for pipeline parallel
@@ -379,8 +474,6 @@ class PipelineParallel(MetaParallelBase):
         self._layers.eval()
         self._compute_loss = compute_loss
 
-        # save data for eval
-        self.data = data
         # store data id for micro_batch
         self.micro_batch_id = 0
 
@@ -394,10 +487,13 @@ class PipelineParallel(MetaParallelBase):
         input_buffers = []
         output_buffers = []
 
+        # convert to micro dataset
+        micro_dataset = self._wrap_data(data)
+
        for step_id in range(startup_steps):
             input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
 
-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
 
             input_buffers.append(input_tensor)
@@ -409,7 +505,7 @@ class PipelineParallel(MetaParallelBase):
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)
 
-            output_tensor = self._forward_step(input_tensor)
+            output_tensor = self._forward_step(input_tensor, micro_dataset)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
 
             input_buffers.append(input_tensor)
@@ -425,11 +521,12 @@ class PipelineParallel(MetaParallelBase):
 
         return self.train_loss
 
-    def _forward_step(self, input_tensor, chunk_id=None):
+    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
         if self._enable_timer:
             self.timers("forward_step").start()
         if self.is_pipeline_first_stage():
-            input_tensor = self._load_micro_batch(self.micro_batch_id)
+            input_tensor = next(micro_dataset)[0]
+            self._check_micro_batch_data_valid(input_tensor)
 
         assert chunk_id is None or isinstance(chunk_id, int)
 
@@ -441,7 +538,8 @@ class PipelineParallel(MetaParallelBase):
                 assert (
                     self._layers._loss_fn is not None
                 ), "loss function should exist to compute loss"
-            labels = self._load_micro_batch(self.micro_batch_id)
+            labels = next(micro_dataset)[1]
+            self._check_micro_batch_data_valid(labels)
             output_tensor = self._layers._loss_fn(output_tensor, labels)
             assert isinstance(
                 output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
@@ -499,56 +597,15 @@ class PipelineParallel(MetaParallelBase):
             self.timers("backward_step").stop()
         return input_tensor_grad
 
-    def _check_data_vaild(self, data):
-        batch_size = data.shape[0]
-        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
-            "batch_size needs to be divisible by micro_batch_size. Currently, "
-            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
-            % (batch_size, self.micro_batch_size, self.accumulate_steps)
-        )
-
-    def _load_micro_batch_impl(self, inputs, cache_id):
-        begin = cache_id * self.micro_batch_size
-        end = begin + self.micro_batch_size
-
-        if isinstance(inputs, tuple):
-            output = []
-            for data in inputs:
-                if isinstance(data, list):
-                    assert (
-                        len(data) == self.accumulate_steps
-                    ), "length of data should be %d, but it is %d" % (
-                        self.accumulate_steps,
-                        len(data),
-                    )
-                    output.append(data[cache_id].detach())
-                else:
-                    self._check_data_vaild(data)
-                    output.append(data[begin:end, :].detach())
-            return tuple(output)
-
-        elif isinstance(inputs, list):
+    def _check_micro_batch_data_valid(self, micro_batch_data):
+        if isinstance(micro_batch_data, (tuple, list)):
+            for data in micro_batch_data:
+                self._check_micro_batch_data_valid(data)
+        elif micro_batch_data is not None:
+            micro_batch_size = micro_batch_data.shape[0]
             assert (
-                len(inputs) == self.accumulate_steps
-            ), "length of data should be %d, but it is %d" % (
-                self.accumulate_steps,
-                len(inputs),
-            )
-            return inputs[cache_id].detach()
-        else:
-            self._check_data_vaild(inputs)
-            return inputs[begin:end, :].detach()
-
-    def _load_micro_batch(self, cache_id):
-        inputs = self.data
-        if self.is_pipeline_first_stage():
-            assert len(inputs) == 2, "length of input should be 2"
-            return self._load_micro_batch_impl(inputs[0], cache_id)
-        elif self.is_pipeline_last_stage():
-            assert len(inputs) == 2, "length of input should be 2"
-            return self._load_micro_batch_impl(inputs[1], cache_id)
-        else:
-            inputs = None
+                micro_batch_size == self.micro_batch_size
+            ), f"expected micro_batch_size {self.micro_batch_size} but got {micro_batch_size}"
 
     def _broadcast_final_loss(self):
         # Since the last backward run in interleave will set the virtual rank to 0,
@@ -654,7 +711,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
             virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
         return virtual_pp_stage
 
-    def _forward_step_helper(self, micro_step):
+    def _forward_step_helper(self, micro_dataset, micro_step):
         virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
         self.set_virtual_pipeline_rank(virtual_pp_rank)
 
@@ -667,7 +724,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 len(self.output_tensors[virtual_pp_rank]) + 1
             )
         input_tensor = self.input_tensors[virtual_pp_rank][-1]
-        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
+        output_tensor = self._forward_step(
+            input_tensor, micro_dataset, virtual_pp_rank
+        )
         self.output_tensors[virtual_pp_rank].append(output_tensor)
 
         if self._forward_only:
@@ -715,7 +774,6 @@ class PipelineParallelWithInterleave(PipelineParallel):
 
         # init some attributes for this batch run
         self.scaler = scaler
-        self.data = data
         self.total_loss = None
         self.micro_batch_id = 0
         self._forward_only = forward_only
@@ -725,6 +783,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self.output_tensors = [[] for _ in range(self.num_model_chunks)]
         self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
 
+        micro_dataset = self._wrap_data(data)
+
         num_steps = self.accumulate_steps * self.num_model_chunks
         if forward_only:
             # If only forward, since there is no backward during running, all steps are startup steps
@@ -747,7 +807,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
 
         # run startup steps
         for micro_step in range(startup_steps):
-            output_tensor = self._forward_step_helper(micro_step)
+            output_tensor = self._forward_step_helper(micro_dataset, micro_step)
 
             # determine whether recv forward tensor or not
             next_virtual_pp_rank = self._get_virtual_pp_rank(
@@ -800,7 +860,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
         for micro_step in range(steady_steps):
             # forward
             forward_micro_step_id = micro_step + startup_steps
-            output_tensor = self._forward_step_helper(forward_micro_step_id)
+            output_tensor = self._forward_step_helper(
+                micro_dataset, forward_micro_step_id
+            )
 
             # backward
             backward_micro_step_id = micro_step
diff --git a/test/collective/fleet/test_parallel_dygraph_pipeline_parallel.py b/test/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
index 20982a2c146..4ca31781c95 100644
--- a/test/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
+++ b/test/collective/fleet/test_parallel_dygraph_pipeline_parallel.py
@@ -17,6 +17,8 @@ import unittest
 
 from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
 
+import paddle
+
 
 class TestHybridPipeParallel(TestMultipleGpus):
     def test_hybrid_parallel_pp_layer(self):
@@ -55,5 +57,52 @@ class TestHybridPipeParallel(TestMultipleGpus):
         self.run_mnist_2gpu('hybrid_parallel_pp_transformer_unbalanced_data.py')
 
 
+class TestFakeMicroDataSet(unittest.TestCase):
+    def test_fake_micro_data_set(self):
+        import numpy as np
+
+        from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
+            FakeMicroDataset,
+        )
+
+        batch_size = 4
+        micro_batch_size = 2
+        acc_step = 2
+        length = 4
+        x_data = np.random.randint(0, batch_size, size=[batch_size, length])
+        data1 = paddle.to_tensor(x_data)
+        data1.stop_gradient = True
+
+        data2 = [
+            data1[
+                (i * micro_batch_size) : ((i + 1) * micro_batch_size), :
+            ].detach()
+            for i in range(acc_step)
+        ]
+
+        data3 = None
+
+        batch = [(data1, data2, data3), None]
+
+        for micro_batch in FakeMicroDataset(
+            batch, True, False, acc_step, micro_batch_size
+        ):
+            x, y = micro_batch
+            self.assertEqual(len(x), 3)
+            for e in [x[0], x[1]]:
+                self.assertEqual(e.shape[0], micro_batch_size)
+                self.assertEqual(e.shape[1], length)
+            self.assertTrue(x[2] is None)
+            self.assertTrue(y is None)
+
+        # neither the first stage nor the last stage
+        micro_batches = FakeMicroDataset(
+            batch, False, False, acc_step, micro_batch_size
+        )
+        x, y = micro_batches._load_micro_batch(0)
+        self.assertTrue(x is None)
+        self.assertTrue(y is None)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
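
Reviewer note: a minimal usage sketch (not part of the patch) of the micro-batch contract that FakeMicroDataset introduces above. It assumes a working Paddle install; the tensor shapes and variable names below are illustrative only and do not come from the patch.

    # Sketch: how a stage that is both first and last would consume micro batches.
    # Assumption: paddle is installed; the [*, 8] / [*, 1] shapes are made up.
    import paddle

    from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
        FakeMicroDataset,
    )

    micro_batch_size, acc_steps = 2, 4
    # the full batch: batch_size must equal micro_batch_size * acc_steps,
    # which is exactly what _check_data_valid asserts
    features = paddle.rand([micro_batch_size * acc_steps, 8])
    labels = paddle.randint(0, 10, [micro_batch_size * acc_steps, 1])

    dataset = FakeMicroDataset(
        (features, labels),  # the 2-element (inputs, labels) layout _load_micro_batch asserts
        True,  # is_first_stage
        True,  # is_last_stage
        acc_steps,
        micro_batch_size,
    )
    for x, y in dataset:
        # each step yields one contiguous, detached [micro_batch_size, ...] slice
        assert x.shape[0] == micro_batch_size
        assert y.shape[0] == micro_batch_size

This mirrors what _wrap_data builds inside train_batch/eval_batch; stages that are neither first nor last receive (None, None) micro batches instead, as the new unit test checks.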