From a0f4ac54ee03e8b1197b6c44b43abd5db49c0c78 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Mon, 3 May 2021 21:34:54 +0800 Subject: [PATCH] Fix the bug in pipeline for dygraph mode (#32716) * update, test=develop --- .../parallel_layers/pp_layers.py | 1 - .../fleet/meta_parallel/pipeline_parallel.py | 342 ++++++++++-------- .../fleet/meta_parallel/pp_utils/utils.py | 43 ++- 3 files changed, 231 insertions(+), 155 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 669ed032a34..a9704e38f3f 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -108,7 +108,6 @@ class PipelineLayer(Layer): # construct layer self.run_function = [] self._build_layer() - self.to(paddle.CUDAPlace(self.device_id)) def _segment_network(self, seg_method): logger.info("start segment network..") diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 98a82f2b798..11180054afb 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -22,15 +22,11 @@ from numpy import prod import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import get_tensor_bytes +from .pp_utils.utils import get_tensor_bytes, is_float_tensor from .pp_utils import utils from .parallel_layers.pp_layers import PipelineLayer - -FLOAT_TYPES = [ - paddle.float16, - paddle.float32, - paddle.float64, -] +from ..utils.hybrid_parallel_util import * +from ..utils.log_util import logger class PipelineParallel(MetaParallelBase): @@ -46,20 +42,18 @@ class PipelineParallel(MetaParallelBase): 'inputs': [], 'labels': [], 'outputs': [], - 'backward_tensors': [], } + self.recv_cache = None self.grad_tensors = None - self.meta_buffer = None - self.send_meta = True - self.first_gradient_send = True self.current_loss = paddle.to_tensor(0.0) self.total_loss = None - def _prepare_for_model(self): + self.use_amp = self._strategy.amp + self.init_loss_scaling = self._strategy.amp_configs['init_loss_scaling'] self.micro_batch_size = self._strategy.pipeline_configs[ 'micro_batch_size'] self.accumulate_steps = self._strategy.pipeline_configs[ @@ -69,9 +63,17 @@ class PipelineParallel(MetaParallelBase): self.stage_id = self._hcg.get_stage_id() self.prev_stage_id = self.stage_id - 1 self.next_stage_id = self.stage_id + 1 - self._layers = PipelineLayer( - layers=self._layers, num_stages=self.num_stages) - #TODO: init process group + self.pp_group = self._hcg.get_pipe_parallel_group() + logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format( + self.num_stages, self.stage_id)) + + if self.use_model_parallel: + logger.info("start broadcast mp parameters") + broadcast_mp_parameters(self._layers, self._hcg) + + if self.use_data_parallel: + logger.info("start broadcast mp parameters") + broadcast_dp_parameters(self._layers, self._hcg) def _allocate_caches(self, num_caches): if self.num_caches >= num_caches: @@ -82,19 +84,19 @@ class PipelineParallel(MetaParallelBase): for key in self.caches: self.caches[key].extend([None] * num) - def train_batch(self, data_iter, optimizer): + def train_batch(self, data, optimizer): self.optimizer = optimizer assert fluid.framework._dygraph_tracer()._has_grad, ( 'Please enable the generation of gradients.') if self.stage_id == 0 or self.stage_id == self.num_stages - 1: - assert data_iter, ( + assert data, ( "For the first and the last stage, the data_iter must be set.") else: - assert data_iter is None, ( + assert data is None, ( "For pipe stages other than the first and the last one, " "the data_iter must be None.") - self.data_iter = data_iter + self.data = data self._layers.train() self.total_loss = None @@ -104,39 +106,24 @@ class PipelineParallel(MetaParallelBase): return self.total_loss def _train(self, minibatch_cmds): - self._allocate_caches(self.num_stages) - for microbatch_cmds in minibatch_cmds: - for cmd in microbatch_cmds: - if type(cmd) not in self._COMMAND_MAP: - #FIXME: - continue - + self._allocate_caches(self.accumulate_steps) + for micro_cmds in minibatch_cmds: + for cmd in micro_cmds: + assert type(cmd) in self._COMMAND_MAP, "unknow cmd: {}".format( + type(cmd)) self._apply_cmd = MethodType(self._COMMAND_MAP[type(cmd)], self) self._apply_cmd(**cmd.kwargs) def _allreduce_grads(self): - self._modifying_grad = True - assert self.use_data_parallel <= 1, ("Do not support data parallel " - "with pipeline parallel now.") - self._modifying_grad = False - - def _get_data(self): - if self.use_model_parallel: - mp_rank = self._hcg.get_model_parallel_rank() - else: - mp_rank = 0 - - data = None - - # mp rank 0 loads the data and broadcat it to others. - if mp_rank == 0: - data = next(self.data_iter) - if self.use_model_parallel: - data = paddle.distributed.broadcast( - data, group=self._hcg.get_model_parallel_group()) - return data + if not self.use_data_parallel: return + fused_allreduce_gradients(list(self._layers.parameters()), self._hcg) def _forward(self, cache_id): + # load data + self._load_micro_batch(cache_id) + if self.stage_id != 0: + self._recv_activations(cache_id) + if isinstance(self.caches['inputs'][cache_id], tuple): inputs = tuple(t.clone() for t in self.caches['inputs'][cache_id]) else: @@ -144,9 +131,13 @@ class PipelineParallel(MetaParallelBase): self._clear_grads(inputs) outputs = self._layers.forward(inputs) - self.caches['outputs'][cache_id] = outputs + if self.stage_id == self.num_stages - 1: + if self._layers._loss_fn is not None: + labels = self.caches['labels'][cache_id] + outputs = self._layers._loss_fn(outputs, labels) + if self.stage_id == self.num_stages - 1: self.current_loss = outputs if isinstance(self.current_loss, paddle.Tensor): @@ -160,18 +151,28 @@ class PipelineParallel(MetaParallelBase): ] for idx, v in enumerate(self.current_loss): self.total_loss[idx] += v.detach() + if self.use_data_parallel: + self.current_loss = self.current_loss / self._hcg.get_data_parallel_world_size( + ) + if self.accumulate_steps > 1: + self.current_loss = self.current_loss / self.accumulate_steps + self.caches['outputs'][cache_id] = self.current_loss.clone() + else: + self._send_activations(cache_id) def _backward(self, cache_id): assert self.optimizer is not None if self.stage_id == self.num_stages - 1: - paddle.autograd.backward(self.current_loss) + paddle.autograd.backward(self.caches['outputs'][cache_id]) + self._send_gradients(cache_id) return + self._recv_gradients(cache_id) outputs = self.caches['outputs'][cache_id] grad_tensors = self.grad_tensors if isinstance(outputs, tuple): - out_tensors = [t for t in outputs if t.dtype in FLOAT_TYPES] + out_tensors = [t for t in outputs if is_float_tensor(t)] assert len(out_tensors) == len(grad_tensors) paddle.autograd.backward( tensors=out_tensors, grad_tensors=grad_tensors) @@ -179,41 +180,76 @@ class PipelineParallel(MetaParallelBase): paddle.autograd.backward( tensors=[outputs], grad_tensors=[grad_tensors]) - self.caches['outputs'][cache_id] = None grad_tensors = None + if self.stage_id != 0: self._send_gradients(cache_id) + self.caches['outputs'][cache_id] = None + #self.caches['backward_tensors'][cache_id] = None + + def _get_data(self): + if self.use_model_parallel: + mp_rank = self._hcg.get_model_parallel_rank() + else: + mp_rank = 0 + + # mp rank 0 loads the data and broadcat it to others. + data = self.data + if self.use_model_parallel and (self.stage_id == 0 or + self.stage_id == self.num_stages - 1): + assert isinstance(data, (tuple, paddle.Tensor)) + if isinstance(data, paddle.Tensor): + paddle.distributed.broadcast( + data, + src=self._hcg.get_model_parallel_group_src_rank(), + group=self._hcg.get_model_parallel_group()) + else: + data = [] + for d in self.data: + assert isinstance(d, paddle.Tensor) + paddle.distributed.broadcast( + d, + src=self._hcg.get_model_parallel_group_src_rank(), + group=self._hcg.get_model_parallel_group()) + data.append(d) + data = tuple(data) + return data def _load_micro_batch(self, cache_id): inputs = self._get_data() if self.stage_id == 0: data = None - if isinstance(inputs[0], paddle.Tensor): + #if isinstance(inputs[0], paddle.Tensor): + if len(inputs) == 1: + assert isinstance(inputs[0], paddle.Tensor) data = inputs[0].clone().detach() - data.stop_gradient = data.dtype == paddle.float32 + #data.stop_gradient = not is_float_tensor(data) + data.stop_gradient = True else: - assert isinstance(inputs[0], tuple) - # Assume list or tuple + assert isinstance(inputs, tuple) data = [] - for d in inputs[0]: + for d in inputs: assert isinstance(d, paddle.Tensor) - d = d.clone().detach() - d.stop_gradient = d.dtype == paddle.float32 - loaded.append(d) + i = d.clone().detach() + #i.stop_gradient = not is_float_tensor(i) + i.stop_gradient = True + data.append(i) data = tuple(data) self.caches['inputs'][cache_id] = data if self.stage_id == self.num_stages - 1: - label = None - if isinstance(inputs[1], paddle.Tensor): - label = inputs[1] - elif isinstance(data[1], tuple): - label = [] - for l in inputs[1]: - assert isinstance(l, paddle.Tensor) - l = l.detach() - label.append(l) - label = tuple(label) - self.caches['labels'][cache_id] = label + labels = None + #if isinstance(inputs[1], paddle.Tensor): + if len(inputs) == 1: + assert isinstance(inputs[0], paddle.Tensor) + labels = inputs[0] + elif isinstance(inputs, tuple): + labels = [] + for label in inputs: + assert isinstance(label, paddle.Tensor) + label = label.detach() + labels.append(label) + labels = tuple(labels) + self.caches['labels'][cache_id] = labels def _send_meta(self, data, peer): """ @@ -225,54 +261,67 @@ class PipelineParallel(MetaParallelBase): """ if isinstance(data, paddle.Tensor): tensor_type = paddle.to_tensor([0]) - paddle.distributed.send(tensor_type, peer) + paddle.distributed.send( + tensor_type, peer, use_calc_stream=True, group=self.pp_group) dims = paddle.to_tensor(len(data.shape)) - paddle.distributed.send(dims, peer) + paddle.distributed.send( + dims, peer, use_calc_stream=True, group=self.pp_group) shape = paddle.to_tensor(data.shape) - paddle.distributed.send(shape, peer) + paddle.distributed.send( + shape, peer, use_calc_stream=True, group=self.pp_group) elif isinstance(data, tuple): tensor_type = paddle.to_tensor([1]) - paddle.distributed.send(tensor_type, peer) + paddle.distributed.send( + tensor_type, peer, use_calc_stream=True, group=self.pp_group) nums = paddle.to_tensor(len(data)) - paddle.distributed.send(nums, peer) + paddle.distributed.send( + nums, peer, use_calc_stream=True, group=self.pp_group) for idx, d in enumerate(data): assert isinstance(d, paddle.Tensor) dims = paddle.to_tensor(len(d.shape)) - paddle.distributed.send(dims, peer) + paddle.distributed.send( + dims, peer, use_calc_stream=True, group=self.pp_group) shape = paddle.to_tensor(d.shape) - paddle.distributed.send(shape, peer) + paddle.distributed.send( + shape, peer, use_calc_stream=True, group=self.pp_group) def _recv_meta(self, peer): tensor_type = paddle.to_tensor([0]) - paddle.distributed.recv(tensor_type, peer) + paddle.distributed.recv( + tensor_type, peer, use_calc_stream=True, group=self.pp_group) tensor_type = tensor_type.numpy()[0] if tensor_type == 0: dims = paddle.to_tensor([0]) - paddle.distributed.recv(dims, peer) + paddle.distributed.recv( + dims, peer, use_calc_stream=True, group=self.pp_group) dims = dims.numpy()[0] shape = paddle.to_tensor([0] * dims) - paddle.distributed.recv(shape, peer) + paddle.distributed.recv( + shape, peer, use_calc_stream=True, group=self.pp_group) shape = shape.numpy().tolist() return self._allocate_buffer( shape, dtype="float32", num_caches=1)[0] elif tensor_type == 1: num = paddle.to_tensor([0]) - paddle.distributed.recv(num, peer) + paddle.distributed.recv( + num, peer, use_calc_stream=True, group=self.pp_group) num = num.numpy()[0] shapes = [] for i in range(num): dims = paddle.to_tensor([0]) - paddle.distributed.recv(dims, peer) + paddle.distributed.recv( + dims, peer, use_calc_stream=True, group=self.pp_group) dims = dims.numpy()[0] shape = paddle.to_tensor([0] * dims) - paddle.distributed.recv(shape, peer) + paddle.distributed.recv( + shape, peer, use_calc_stream=True, group=self.pp_group) shapes.append(shape.numpy().tolist()) dtypes = ["float32"] * len(shapes) - caches = self._allocate_buffers(shapes, dtypes, num_buffers=1)[0] - buffers = tuple(buffers) - return buffers + caches = self._allocate_buffers(shapes, dtypes, num_caches=1)[0] + caches = tuple(caches) + return caches def _send_activations(self, cache_id): outputs = self.caches['outputs'][cache_id] @@ -282,10 +331,18 @@ class PipelineParallel(MetaParallelBase): self._send_meta(outputs, self.next_stage_id) if isinstance(outputs, paddle.Tensor): - paddle.distributed.send(outputs, self.next_stage_id) + paddle.distributed.send( + outputs, + self.next_stage_id, + use_calc_stream=True, + group=self.pp_group) elif isinstance(outputs, tuple): for output in outputs: - paddle.distributed.send(output, self.next_stage_id) + paddle.distributed.send( + output, + self.next_stage_id, + use_calc_stream=True, + group=self.pp_group) def _send_gradients(self, cache_id): inputs = self.caches['inputs'][cache_id] @@ -293,15 +350,22 @@ class PipelineParallel(MetaParallelBase): if isinstance(inputs, paddle.Tensor): assert inputs.grad is not None paddle.distributed.send( - paddle.to_tensor(inputs.grad), self.prev_stage_id) + paddle.to_tensor(inputs.grad), + self.prev_stage_id, + use_calc_stream=True, + group=self.pp_group) else: for idx, d in enumerate(inputs): # Skip tensors that will not produce a grad - if not d.dtype in FLOAT_TYPES: + if not is_float_tensor(d): assert d.grad is None continue assert d.grad is not None - paddle.distributed.send(d.grad, self.prev_stage_id) + paddle.distributed.send( + d.grad, + self.prev_stage_id, + use_calc_stream=True, + group=self.pp_group) self.caches['inputs'][cache_id] = None def _recv_activations(self, cache_id): @@ -312,22 +376,30 @@ class PipelineParallel(MetaParallelBase): self.recv_cache = self._recv_meta(self.prev_stage_id) if isinstance(self.recv_cache, paddle.Tensor): - paddle.distributed.recv(self.recv_cache, self.prev_stage_id) + paddle.distributed.recv( + self.recv_cache, + self.prev_stage_id, + use_calc_stream=True, + group=self.pp_group) inputs = self.recv_cache.clone().detach() - inputs.stop_gradient = inputs.dtype not in FLOAT_TYPES + inputs.stop_gradient = not is_float_tensor(inputs) else: assert isinstance(self.recv_cache, tuple) inputs = [None] * len(self.recv_cache) for idx, d in enumerate(self.recv_cache): assert isinstance(d, paddle.Tensor) - paddle.distributed.recv(d, self.prev_stage_id) + paddle.distributed.recv( + d, + self.prev_stage_id, + use_calc_stream=True, + group=self.pp_group) inputs[idx] = d.clone().detach() inputs = tuple(inputs) for d in inputs: - d.stop_gradient = d.dtype not in FLOAT_TYPES + d.stop_gradient = not is_float_tensor(d) self.caches['inputs'][cache_id] = inputs @@ -336,29 +408,35 @@ class PipelineParallel(MetaParallelBase): if self.grad_tensors is None: if isinstance(outputs, paddle.Tensor): s = list(outputs.shape) - dtype = 'float32' + dtype = 'float16' if self.use_amp else "float32" self.grad_tensors = self._allocate_buffer( s, dtype, num_buffers=1)[0] else: - sizes = [ - list(d.shape) for d in outputs if d.dtype in FLOAT_TYPES - ] - dtypes = ['float32'] * len(sizes) + sizes = [list(d.shape) for d in outputs if is_float_tensor(d)] + dtypes = ['float16'] * len( + sizes) if self.use_amp else ['float32'] * len(sizes) self.grad_tensors = self._allocate_buffers( - sizes, dtypes, num_buffers=1)[0] + sizes, dtypes, num_caches=1)[0] if isinstance(self.grad_tensors, paddle.Tensor): - paddle.distributed.recv(self.grad_tensors, self.next_stage_id) + paddle.distributed.recv( + self.grad_tensors, + self.next_stage_id, + use_calc_stream=True, + group=self.pp_group) else: assert isinstance(outputs, tuple) for d in self.grad_tensors: - paddle.distributed.recv(d, self.next_stage_id) - - def _step(self, lr_kwargs=None): - self._modifying_grad = True + paddle.distributed.recv( + d, + self.next_stage_id, + use_calc_stream=True, + group=self.pp_group) + + def _step(self): + self._allreduce_grads() self.optimizer.step() self.optimizer.clear_gradients() - self._modifying_grad = False def _clear_grads(self, inputs): if isinstance(inputs, paddle.Tensor): @@ -372,26 +450,24 @@ class PipelineParallel(MetaParallelBase): def _allocate_zeros(self, shape, dtype): return paddle.zeros(shape, dtype) - def _allocate_buffer(self, shape, dtype, num_buffers=-1, **kwargs): - buffers = [] - if num_buffers == -1: - num_buffers = self.num_caches - for count in range(num_buffers): - buffers.append(self._allocate_zeros(shape, dtype)) - return buffers - - def _allocate_buffers(self, shapes, dtypes, num_buffers=-1): - buffers = [] - if num_buffers == -1: - num_buffers = self.num_caches - for count in range(num_buffers): - buffer = [] + def _allocate_buffer(self, shape, dtype, num_caches=-1): + caches = [] + if num_caches == -1: + num_caches = self.num_caches + for count in range(num_caches): + caches.append(self._allocate_zeros(shape, dtype)) + return caches + + def _allocate_buffers(self, shapes, dtypes, num_caches=-1): + caches = [] + if num_caches == -1: + num_caches = self.num_caches + for count in range(num_caches): + cache = [] for shape, dtype in zip(shapes, dtypes): - buffer.append( - self._allocate_zeros( - shape, dtype, requires_grad=requires_grad)) - buffers.append(buffer) - return buffers + cache.append(self._allocate_zeros(shape, dtype)) + caches.append(cache) + return caches def save_state_dict(self, model_path): state_dict = self._layers.state_dict() @@ -403,25 +479,9 @@ class PipelineParallel(MetaParallelBase): _COMMAND_MAP = { utils.Optimize: _step, - #utils.ReduceGrads: _allreduce_grads, utils.Forward: _forward, utils.Backward: _backward, } - def _pre_forward(self, *inputs, **kwargs): - pass - def forward(self, *inputs, **kwargs): raise RuntimeError("Call train_batch for pipeline instead of forward.") - - def _post_forward(self, output): - pass - - def _pre_backward(self, loss): - pass - - def backward_impl(self, loss, parameters): - pass - - def _post_backward(self, loss): - pass diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 56eef8d7d21..7b426e2c3f7 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -16,7 +16,21 @@ import abc import paddle from ...utils import hybrid_parallel_util as hp_util -__all__ = ['get_tensor_bytes', ] +__all__ = [ + 'get_tensor_bytes', + 'is_float_tensor', +] + +FLOAT_TYPES = [ + paddle.float16, + paddle.float32, + paddle.float64, +] + + +def is_float_tensor(tensor): + """Is a float tensor""" + return tensor.dtype in FLOAT_TYPES def get_tensor_bytes(tensor): @@ -48,10 +62,6 @@ class Generator(): self.stage_id = stage_id self.prev_stage = self.stage_id - 1 self.next_stage = self.stage_id + 1 - assert self.micro_batches >= self.stages, ( - "micro_batches {} " - "must be greater than or equal to {}".format(self.micro_batches, - self.stages)) @abc.abstractmethod def generate(self): @@ -73,18 +83,25 @@ class TrainGenerator(Generator): cmds = [] forward_steps = 0 backward_steps = 0 - while (forward_steps < startup_steps): - cmds.append(Forward) - forward_steps += 1 + #while (forward_steps < startup_steps): + # cmds.append(Forward(cache_id=forward_steps)) + # forward_steps += 1 + #while (forward_steps < self.micro_batches): + # cmds.append(Forward(cache_id=forward_steps)) + # forward_steps += 1 + # cmds.append(Backward(cache_id=backward_steps)) + # backward_steps += 1 + #while (backward_steps < self.micro_batches): + # cmds.append(Backward(cache_id=backward_steps)) + # backward_steps += 1 + #cmds.append(Optimize()) while (forward_steps < self.micro_batches): - cmds.append(Forward) + cmds.append(Forward(cache_id=forward_steps)) forward_steps += 1 - cmds.append(Backward) - backward_steps += 1 while (backward_steps < self.micro_batches): - cmds.append(Backward) + cmds.append(Backward(cache_id=backward_steps)) backward_steps += 1 - cmds.append(Optimize) + cmds.append(Optimize()) yield cmds -- GitLab