Unverified commit d4bf8b1a, authored by ShenLiang, committed by GitHub

support unbalanced data for pipeline (#47199) (#47569)

* add unbalanced data

* fix utest
Parent commit: ba4fbe71
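This change lets the data passed to train_batch / eval_batch on the first and last pipeline stages be given either as a single batch tensor that is split evenly into micro batches (the old behavior), or as a list with one tensor per accumulation step, so individual micro batches no longer have to be the same size. A minimal sketch of the two calling conventions, with illustrative sizes (the real test reuses batch_size, micro_batch_size, length and vocab_size from hybrid_parallel_pp_transformer):

import numpy as np
import paddle

# Illustrative values, not taken from the commit.
accumulate_steps = 4
micro_batch_size = 2
length = 8
vocab_size = 128

# Old form: one tensor whose first dimension equals
# micro_batch_size * accumulate_steps; it is sliced evenly per micro batch.
balanced = paddle.to_tensor(
    np.random.randint(
        0, vocab_size, size=[accumulate_steps * micro_batch_size, length]
    )
)

# New form: a Python list with exactly accumulate_steps entries, one tensor
# per micro batch; the entries may have different batch dimensions.
unbalanced = [
    paddle.to_tensor(
        np.random.randint(0, vocab_size, size=[micro_batch_size + (i % 2), length])
    )
    for i in range(accumulate_steps)
]

# Either form is accepted as the (input, label) pair on the first/last stage:
#     loss = model.train_batch([unbalanced, unbalanced], optimizer, scheduler)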
@@ -20,7 +20,10 @@ from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import (
    HybridParallelOptimizer,
    HybridParallelGradScaler,
)
import paddle.fluid.framework as framework
from .pp_utils import p2p_communication as p2p
import paddle.fluid.core as core
@@ -29,27 +32,31 @@ __all__ = []

class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors sent from different hosts are not the same,
        # they shouldn't be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
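For reference, the micro_batch_size, accumulate_steps, enable_partial_send_recv and p2p_cache_shape values read above come from the pipeline_configs of the fleet DistributedStrategy. A minimal sketch of how a training script would typically set them (values are illustrative, not part of this commit):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,  # number of pipeline stages
}
strategy.pipeline_configs = {
    "accumulate_steps": 4,   # micro batches per global batch
    "micro_batch_size": 2,
}
fleet.init(is_collective=True, strategy=strategy)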
@@ -61,16 +68,20 @@ class PipelineParallel(MetaParallelBase):
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        p2p.initialize_p2p_groups(
            hcg, self._using_cache, self._enable_partial_send_recv
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
@@ -122,7 +133,7 @@ class PipelineParallel(MetaParallelBase):
        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
@@ -142,39 +153,46 @@ class PipelineParallel(MetaParallelBase):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        self._layers.allreduce_shared_weight_gradients()
@@ -186,17 +204,20 @@ class PipelineParallel(MetaParallelBase):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            fluid.framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None
@@ -233,7 +254,7 @@ class PipelineParallel(MetaParallelBase):
        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
@@ -253,7 +274,7 @@ class PipelineParallel(MetaParallelBase):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
@@ -282,13 +303,14 @@ class PipelineParallel(MetaParallelBase):
        if self.is_pipeline_last_stage():
            # calculate loss for train
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, core.eager.Tensor)
                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1:
@@ -318,91 +340,113 @@ class PipelineParallel(MetaParallelBase):
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=[t for t in output_tensor_grad],
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            return input_tensor_grad

    def _check_data_vaild(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to be divisible by micro_batch_size. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_vaild(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)
        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_vaild(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            # No data input is required for other stages
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = self.total_loss.detach()
            is_fp32 = (
                paddle.to_tensor(1)
                if loss.dtype == paddle.float32
                else paddle.to_tensor(0)
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.to_tensor(1)
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.numpy()[0]
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
@@ -421,11 +465,12 @@ class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler
    def __init__(self, layers, hcg, strategy):
        super(PipelineParallelWithInterleave, self).__init__(
            layers=layers, hcg=hcg, strategy=strategy
        )
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dygraph_mode()
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
@@ -455,7 +501,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                self.output_tensors[virtual_pp_rank]
            ):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
@@ -484,21 +531,22 @@ class PipelineParallelWithInterleave(PipelineParallel):
        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )
        return input_tensor_grad

    def interleave_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                not forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
@@ -530,15 +578,17 @@ class PipelineParallelWithInterleave(PipelineParallel):
        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )

            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
@@ -552,24 +602,33 @@ class PipelineParallelWithInterleave(PipelineParallel):
            if self.is_pipeline_last_stage():
                output_tensor = None

            if (
                micro_step == (startup_steps - 1)
                and not forward_only
                and not all_startup_steps
            ):
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs on four direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
@@ -581,7 +640,8 @@ class PipelineParallelWithInterleave(PipelineParallel):
            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
@@ -591,14 +651,16 @@ class PipelineParallelWithInterleave(PipelineParallel):
            # last stage doesn't send rst to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None
@@ -607,14 +669,16 @@ class PipelineParallelWithInterleave(PipelineParallel):
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True
                )
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                    next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True
                )

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
@@ -625,53 +689,67 @@ class PipelineParallelWithInterleave(PipelineParallel):
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False,
                )
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                    next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False
                )

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor
                )
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad
                )

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(
                        self.is_pipeline_last_stage(), sync_recv=False
                    )
                )

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False
                if micro_step == (num_steps - 1):
                    recv_next = False
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            self._layers.allreduce_shared_weight_gradients()
......
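To summarize the new loading path: _load_micro_batch_impl dispatches on the type of the input, so a plain tensor keeps the old even-split behavior (validated by _check_data_vaild), while a list must contain exactly accumulate_steps entries and the cache_id-th entry is used as the whole micro batch, which is what allows unbalanced micro-batch sizes. A rough standalone sketch of that selection logic (simplified; tuple inputs and .detach() are omitted here):

def load_micro_batch_sketch(inputs, cache_id, micro_batch_size, accumulate_steps):
    # Simplified model of the dispatch added in this commit; for illustration only.
    if isinstance(inputs, list):
        # One pre-built micro batch per accumulation step; sizes may differ.
        assert len(inputs) == accumulate_steps
        return inputs[cache_id]
    # Single batch tensor: must split evenly into accumulate_steps micro batches.
    assert inputs.shape[0] == micro_batch_size * accumulate_steps
    begin = cache_id * micro_batch_size
    return inputs[begin:begin + micro_batch_size]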
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_transformer import (
    TestDistPPTraning,
    set_random_seed,
    ModelPipe,
    batch_size,
    length,
    micro_batch_size,
    vocab_size,
)


class TestDistPPTraningUnbalancedData(TestDistPPTraning):
    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True
        )
        optimizer = paddle.optimizer.SGD(
            learning_rate=scheduler, parameters=model.parameters()
        )

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

        for step_id in range(5):
            x = []
            for _ in range(batch_size // micro_batch_size):
                size = micro_batch_size
                x_data = np.random.randint(0, vocab_size, size=[size, length])
                x.append(paddle.to_tensor(x_data))
            e_loss = model.eval_batch([x, x], True)
            loss = model.train_batch([x, x], optimizer, scheduler)

            # TODO(shenliang03) add utest for loss
            if pp_id != 0:
                np.testing.assert_allclose(loss.numpy(), e_loss.numpy())


if __name__ == "__main__":
    unittest.main()
@@ -22,13 +22,14 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus

class TestHybridPipeParallel(TestMultipleGpus):
    def test_hybrid_parallel_pp_layer(self):
        self.run_mnist_2gpu(
            os.path.abspath('../../hybrid_parallel_pp_layer.py')
        )
        self.run_mnist_2gpu(
            os.path.abspath('../../hybrid_parallel_pp_layer.py'),
            eager_mode=False,
        )

    def test_hybrid_parallel_pp_tuple_inputs(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_embedding.py')
@@ -36,8 +37,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
    def test_hybrid_parallel_shared_weight(self):
        self.run_mnist_2gpu('hybrid_parallel_shared_weight.py')
        self.run_mnist_2gpu(
            'hybrid_parallel_shared_weight.py', eager_mode=False
        )

    def test_pipeline_parallel_amp(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_amp.py')
@@ -49,8 +51,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
    def test_hybrid_parallel_transformer(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py')
        self.run_mnist_2gpu(
            'hybrid_parallel_pp_transformer.py', eager_mode=False
        )

    def test_hybrid_parallel_save_load(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_save_load.py')
@@ -64,6 +67,13 @@ class TestHybridPipeParallel(TestMultipleGpus):
        self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py')
        self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py', eager_mode=False)

    def test_hybrid_parallel_transformer_unbalanced_data(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_transformer_unbalanced_data.py')
        self.run_mnist_2gpu(
            'hybrid_parallel_pp_transformer_unbalanced_data.py',
            eager_mode=False,
        )


if __name__ == "__main__":
    os.environ["FLAGS_enable_eager_mode"] = "1"
......