Unverified commit 9e0bb91c, authored by ShenLiang, committed by GitHub

[HybridParallel]Support 1f1b for PipelineParallel (#34483)

* support 1f1b for pipeline

* add utest

* add send_partial/recv_partial

* support amp for pp

* fix logger
Parent 3b5fc2ad
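For context on the commit title: 1F1B ("one forward, one backward") interleaves forward and backward micro-batch steps instead of running all forwards before all backwards, which bounds how many activations a stage must keep alive. Below is a minimal plain-Python sketch, illustration only and not code from this commit (the function name is made up), of how each stage's schedule splits into warm-up, steady and cool-down phases, mirroring the startup/steady computation in train_batch() further down:

def one_f_one_b_phases(num_stages, stage_id, accumulate_steps):
    # Warm-up: forward-only steps, so downstream stages get work in flight.
    startup_steps = min(num_stages - stage_id - 1, accumulate_steps)
    # Steady state: alternate one forward with one backward (hence "1F1B").
    steady_steps = accumulate_steps - startup_steps
    # Cool-down: drain the backward passes that are still outstanding.
    return startup_steps, steady_steps, startup_steps

for stage in range(4):
    print(stage, one_f_one_b_phases(num_stages=4, stage_id=stage, accumulate_steps=8))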
...@@ -156,6 +156,10 @@ class HybridCommunicateGroup(object):
self.is_first_stage = (self.stage_id == 0)
self.is_last_stage = (self.stage_id == (self._pp_degree - 1))
# create p2p_groups
if self._pp_degree > 1:
self._set_p2p_group()
debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \ debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \
"sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree, "sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree,
self._sharding_degree, self._pp_degree, self._dp_degree) self._sharding_degree, self._pp_degree, self._dp_degree)
...@@ -164,27 +168,9 @@ class HybridCommunicateGroup(object): ...@@ -164,27 +168,9 @@ class HybridCommunicateGroup(object):
self._dp_group, self._check_group) self._dp_group, self._check_group)
logger.info(debug_str) logger.info(debug_str)
# create p2p_groups and no new group
self._p2p_groups = self._build_p2p_lists()
global _HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP = self
def _build_p2p_lists(self):
comm_lists = self._topo.get_comm_list('pipe')
p2p_lists = []
for rank in range(self.nranks):
for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree
if rank in comm_ranks:
idx = comm_ranks.index(rank)
next_rank = comm_ranks[(idx + 1) % self._pp_degree]
p2p_lists.append([rank, next_rank])
break
assert len(
p2p_lists) == self.nranks, "len(p2p_lists) should be equal nranks"
return p2p_lists
def get_parallel_mode(self):
# there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
# NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
...@@ -236,6 +222,41 @@ class HybridCommunicateGroup(object):
return parallel_group, parallel_comm_group
def _set_p2p_group(self):
comm_lists = self._topo.get_comm_list('pipe')
self.send_next_group = None
self.send_prev_group = None
self.recv_next_group = None
self.recv_prev_group = None
for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree
for idx, rank in enumerate(comm_ranks):
curr_rank = rank
next_rank = comm_ranks[(idx + 1) % self._pp_degree]
prev_rank = comm_ranks[(idx - 1) % self._pp_degree]
next_group = paddle.distributed.new_group(
ranks=[curr_rank, next_rank])
if self.global_rank == curr_rank:
self.send_next_group = next_group
elif self.global_rank == next_rank:
self.recv_prev_group = next_group
prev_group = paddle.distributed.new_group(
ranks=[prev_rank, curr_rank])
if self.global_rank == curr_rank:
self.send_prev_group = prev_group
elif self.global_rank == prev_rank:
self.recv_next_group = prev_group
assert self.send_next_group is not None
assert self.send_prev_group is not None
assert self.recv_next_group is not None
assert self.recv_prev_group is not None
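As a reading aid for _set_p2p_group() above: for every adjacent pair of stages in the pipeline ring, two 2-rank groups are built, one backing the forward-direction traffic (sender's send_next_group, receiver's recv_prev_group) and one backing the backward-direction traffic (sender's send_prev_group, receiver's recv_next_group). A plain-Python sketch of which peer each of the four groups connects a rank to (no paddle calls, the helper name is hypothetical):

def p2p_peers(pp_ranks, my_rank):
    degree = len(pp_ranks)
    idx = pp_ranks.index(my_rank)
    next_rank = pp_ranks[(idx + 1) % degree]
    prev_rank = pp_ranks[(idx - 1) % degree]
    return {
        "send_next_group": next_rank,  # same group as next_rank's recv_prev_group
        "recv_prev_group": prev_rank,  # same group as prev_rank's send_next_group
        "send_prev_group": prev_rank,  # same group as prev_rank's recv_next_group
        "recv_next_group": next_rank,  # same group as next_rank's send_prev_group
    }

print(p2p_peers([0, 1, 2, 3], my_rank=1))
# {'send_next_group': 2, 'recv_prev_group': 0, 'send_prev_group': 0, 'recv_next_group': 2}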
def topology(self):
return self._topo
...@@ -287,6 +308,9 @@ class HybridCommunicateGroup(object):
def get_pipe_parallel_group(self):
return self._pp_comm_group
def get_p2p_groups(self):
return self.send_next_group, self.send_prev_group, self.recv_next_group, self.recv_prev_group
# sharding parallel message:
def _get_sharding_parallel_id(self):
return self._topo.get_coord(self.global_rank).sharding
...@@ -304,9 +328,6 @@ class HybridCommunicateGroup(object):
# TODO should the src rank related to the shard rank for each parameter ?
return self._sharding_comm_group.ranks[0]
def get_p2p_groups(self):
return self._p2p_groups
# check parallel group
def get_check_parallel_group(self):
return self._check_comm_group
...
...@@ -11,19 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import numpy as np
import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype
from .pp_utils.utils import is_float_tensor
from .pp_utils import utils
from .parallel_layers.pp_layers import PipelineLayer
from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
from .pp_utils import p2p_communication as p2p
__all__ = []
...@@ -35,25 +32,9 @@ class PipelineParallel(MetaParallelBase):
raise TypeError(
"The Layer should be a derived class of PipelineLayer.")
super(PipelineParallel, self).__init__(layers, hcg, strategy)
self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1
self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
self.is_pipe_partitioned = self.use_model_parallel
self.num_caches = 0
self.caches = {
'inputs': [],
'labels': [],
'outputs': [],
}
self.recv_cache = None
self.grad_tensors = None
self.send_meta = True
self.current_loss = paddle.to_tensor(0.0)
self.total_loss = None
self.micro_batch_size = self._strategy.pipeline_configs[
...@@ -63,17 +44,14 @@ class PipelineParallel(MetaParallelBase):
self.num_stages = self._hcg.get_pipe_parallel_world_size()
self.stage_id = self._hcg.get_stage_id()
self.prev_stage_id = self.stage_id - 1
self.next_stage_id = self.stage_id + 1
self.pp_group = self._hcg.get_pipe_parallel_group()
p2p.initialize_p2p_groups(hcg)
self.is_first_stage = self.stage_id == 0
self.is_last_stage = (self.stage_id == (self.num_stages - 1))
self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0
self.mp_degree = self._hcg.get_model_parallel_world_size()
self.mp_rank = self._hcg.get_model_parallel_rank()
logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format( logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
self.num_stages, self.stage_id)) self.num_stages, self.stage_id))
...@@ -86,158 +64,160 @@ class PipelineParallel(MetaParallelBase): ...@@ -86,158 +64,160 @@ class PipelineParallel(MetaParallelBase):
logger.info("start broadcast dp parameters") logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg) broadcast_dp_parameters(self._layers, self._hcg)
# --- removed in this commit ---
def _init_caches(self, num_caches):
if self.num_caches >= num_caches:
return
self.num_caches = num_caches - self.num_caches
for key in self.caches:
self.caches[key].extend([None] * self.num_caches)

def _reduce_final_loss(self):
if self.is_last_stage:
assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
loss = self.total_loss.clone() / self.accumulate_steps
paddle.distributed.broadcast(
loss,
src=self.global_rank,
use_calc_stream=True,
group=self.pp_group)
else:
loss = paddle.to_tensor(0.0)
paddle.distributed.broadcast(
loss,
src=self._hcg.get_rank_from_stage(self.num_stages - 1),
use_calc_stream=True,
group=self.pp_group)
return loss

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
assert isinstance(optimizer, HybridParallelOptimizer), (
'optimizer should be HybridParallelOptimizer subclass.')
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.scaler = scaler
assert fluid.framework._dygraph_tracer()._has_grad, (
'Please enable the generation of gradients.')

if self.is_first_stage or self.is_last_stage:
assert data is not None, (
"For the first and the last stage, the data_iter must be set.")
else:
data = None

self.data = data
self._layers.train()

# store total loss of entire batch
self.total_loss = None
self._init_caches(self.accumulate_steps)
startup_steps = self.num_stages - self.stage_id - 1
forward_steps = 0
backward_steps = 0

# forward
while (forward_steps < self.accumulate_steps):
self._forward(cache_id=forward_steps)
forward_steps += 1

# backward
while (backward_steps < self.accumulate_steps):
self._backward(cache_id=backward_steps)
backward_steps += 1

self._layers.allreduce_shared_weight_gradients()

# optimizer
self.train_loss = self._reduce_final_loss()
self._step()
return self.train_loss

def _forward(self, cache_id):
# load data
self._load_micro_batch(cache_id)
if self.stage_id != 0:
self._recv_activations(cache_id)

if isinstance(self.caches['inputs'][cache_id], tuple):
inputs = tuple(t for t in self.caches['inputs'][cache_id])
else:
inputs = self.caches['inputs'][cache_id]

self._clear_grads(inputs)
outputs = self._layers.forward(inputs)
self.caches['outputs'][cache_id] = outputs

if self.is_last_stage:
if self._layers._loss_fn is not None:
labels = self.caches['labels'][cache_id]
outputs = self._layers._loss_fn(outputs, labels)

if self.is_last_stage:
self.current_loss = outputs
if isinstance(self.current_loss, paddle.Tensor):
if self.total_loss is None:
self.total_loss = paddle.zeros_like(self.current_loss)
self.total_loss += self.current_loss.detach()
else:
if self.total_loss is None:
self.total_loss = [
paddle.zeros_like(v) for v in self.current_loss
]
for idx, v in enumerate(self.current_loss):
self.total_loss[idx] += v.detach()
if self.accumulate_steps > 1:
self.current_loss = self.current_loss / self.accumulate_steps
self.caches['outputs'][cache_id] = self.current_loss.clone()
else:
self._send_activations(cache_id)

def _backward(self, cache_id):
if self.is_last_stage:
if self.scaler:
paddle.autograd.backward(
self.scaler.scale(self.caches['outputs'][cache_id]))
else:
paddle.autograd.backward(self.caches['outputs'][cache_id])
self._send_gradients(cache_id)
return
self._recv_gradients(cache_id)

outputs = self.caches['outputs'][cache_id]
grad_tensors = self.grad_tensors
if isinstance(outputs, tuple):
out_tensors = [t for t in outputs if is_float_tensor(t)]
assert len(out_tensors) == len(grad_tensors)
paddle.autograd.backward(
tensors=out_tensors, grad_tensors=grad_tensors)
else:
paddle.autograd.backward(
tensors=[outputs], grad_tensors=[grad_tensors])

grad_tensors = None
if self.stage_id != 0:
self._send_gradients(cache_id)
self.caches['outputs'][cache_id] = None

def _broadcast_data(self, data):
if isinstance(data, paddle.Tensor):
paddle.distributed.broadcast(
data,
src=self._hcg.get_model_parallel_group_src_rank(),
group=self._hcg.get_model_parallel_group())
else:
for d in data:
assert isinstance(d, paddle.Tensor)
paddle.distributed.broadcast(
d,
src=self._hcg.get_model_parallel_group_src_rank(),
group=self._hcg.get_model_parallel_group())
return data

# --- added in this commit ---
def _set_tensor_trainable(self, tensor):
if tensor is None:
return

if isinstance(tensor, tuple):
for t in tensor:
if is_float_tensor(t):
t.stop_gradient = False
else:
if is_float_tensor(tensor):
tensor.stop_gradient = False

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
assert isinstance(optimizer, HybridParallelOptimizer), (
'optimizer should be HybridParallelOptimizer subclass.')
if scaler is not None:
assert isinstance(scaler, HybridParallelGradScaler), (
'scaler should be HybridParallelGradScaler subclass or None.')
assert fluid.framework._dygraph_tracer()._has_grad, (
'Please enable the generation of gradients.')

if self.is_first_stage or self.is_last_stage:
assert data is not None, (
"For the first and the last stage, the data must be set.")
else:
data = None

self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.scaler = scaler
self.data = data
self._layers.train()

# store total loss of entire batch
self.total_loss = None

# store data id for micro_batch
self.micro_batch_id = 0

# Next, use the 1f1b scheduling strategy.
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

startup_steps = (self.num_stages - self.stage_id - 1)
startup_steps = min(startup_steps, self.accumulate_steps)
steady_steps = self.accumulate_steps - startup_steps

input_buffers = []
output_buffers = []

for step_id in range(startup_steps):
input_tensor = p2p.recv_forward()
self._set_tensor_trainable(input_tensor)

output_tensor = self._forward_step(input_tensor)
p2p.send_forward(output_tensor)

input_buffers.append(input_tensor)
output_buffers.append(output_tensor)

if steady_steps > 0:
input_tensor = p2p.recv_forward()

for i in range(steady_steps):
last_iter = (i == (steady_steps - 1))

self._set_tensor_trainable(input_tensor)
output_tensor = self._forward_step(input_tensor)

output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)

input_buffers.append(input_tensor)
output_buffers.append(output_tensor)

input_tensor, output_tensor = input_buffers.pop(
0), output_buffers.pop(0)

input_tensor_grad = self._backward_step(input_tensor, output_tensor,
output_tensor_grad)

if last_iter:
input_tensor = None
p2p.send_backward(input_tensor_grad)
else:
input_tensor = p2p.send_backward_recv_forward(input_tensor_grad)

for i in range(startup_steps):
input_tensor = input_buffers.pop(0)
output_tensor = output_buffers.pop(0)

output_tensor_grad = p2p.recv_backward()

input_tensor_grad = self._backward_step(input_tensor, output_tensor,
output_tensor_grad)
p2p.send_backward(input_tensor_grad)

self._layers.allreduce_shared_weight_gradients()

self.train_loss = self._reduce_final_loss()

# optimizer
self._optimizer_step()
return self.train_loss
def _forward_step(self, input_tensor):
if self.stage_id == 0:
input_tensor = self._load_micro_batch(self.micro_batch_id)
output_tensor = self._layers.forward(input_tensor)
if self.is_last_stage:
labels = self._load_micro_batch(self.micro_batch_id)
output_tensor = self._layers._loss_fn(output_tensor, labels)
assert isinstance(
output_tensor, paddle.
Tensor), "Currently, loss_fn should obtain Paddle.Tensor dtype"
if self.accumulate_steps > 1:
output_tensor = output_tensor / self.accumulate_steps
if self.total_loss is None:
self.total_loss = paddle.zeros_like(output_tensor)
self.total_loss += output_tensor.detach()
self.micro_batch_id += 1
return output_tensor
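A note on the division by accumulate_steps in _forward_step() above: scaling each micro-batch loss down before backward means the gradients accumulated over all micro-steps match a single pass over the full batch, and total_loss ends up as the full-batch mean. A toy arithmetic check in plain Python (the loss values are made up):

# Illustration only: per-micro-batch scaling keeps the accumulated loss equal
# to the mean over the whole batch.
micro_losses = [0.9, 1.1, 1.0, 1.2]          # hypothetical per-micro-batch means
accumulate_steps = len(micro_losses)
accumulated = sum(l / accumulate_steps for l in micro_losses)
full_batch_mean = sum(micro_losses) / accumulate_steps
assert abs(accumulated - full_batch_mean) < 1e-12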
def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
if self.is_last_stage:
assert output_tensor_grad is None
if self.scaler:
paddle.autograd.backward(self.scaler.scale(output_tensor))
else:
paddle.autograd.backward(output_tensor)
else:
if isinstance(output_tensor, tuple):
outputs = [t for t in output_tensor if not t.stop_gradient]
assert len(outputs) == len(output_tensor_grad)
paddle.autograd.backward(
tensors=outputs,
grad_tensors=[t for t in output_tensor_grad])
else:
paddle.autograd.backward(
tensors=[output_tensor], grad_tensors=[output_tensor_grad])
input_tensor_grad = None
if input_tensor is not None:
if isinstance(input_tensor, tuple):
input_tensor_grad = tuple(
[t.grad for t in input_tensor if not t.stop_gradient])
else:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
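To visualise what the three loops in the new train_batch() produce, here is a plain-Python sketch (not Paddle code; the helper name and printout are illustrative) of the per-stage forward/backward order:

def stage_schedule(num_stages, stage_id, accumulate_steps):
    # F = one forward micro-step, B = one backward micro-step.
    startup = min(num_stages - stage_id - 1, accumulate_steps)
    steady = accumulate_steps - startup
    events = ["F"] * startup
    for _ in range(steady):
        events += ["F", "B"]
    events += ["B"] * startup
    return "".join(events)

for stage in range(4):
    print("stage", stage, stage_schedule(4, stage, accumulate_steps=8))
# The last stage prints "FBFBFBFBFBFBFBFB": it holds at most one activation at
# a time, instead of all 8 as in the previous all-forward-then-all-backward run.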
def _load_micro_batch(self, cache_id):
inputs = self.data
...@@ -246,8 +226,6 @@ class PipelineParallel(MetaParallelBase):
if self.is_first_stage:
assert len(inputs) == 2, "length of input should be 2"
if self.use_model_parallel:
inputs[0] = self._broadcast_data(inputs[0])
if isinstance(inputs[0], tuple):
batch_size = inputs[0][0].shape[0]
assert self.micro_batch_size * self.accumulate_steps == batch_size, (
...@@ -255,332 +233,51 @@ class PipelineParallel(MetaParallelBase):
"batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
%
(batch_size, self.micro_batch_size, self.accumulate_steps))
data = [
input[begin:end, :].clone().detach() for input in inputs[0]
]
self.caches['inputs'][cache_id] = tuple(data)
data = [input[begin:end, :].detach() for input in inputs[0]]
return tuple(data)
else:
batch_size = inputs[0].shape[0]
assert self.micro_batch_size * self.accumulate_steps == batch_size
self.caches['inputs'][cache_id] = inputs[0][begin:end, :].clone(
).detach()
return inputs[0][begin:end, :].detach()
elif self.is_last_stage:
assert len(inputs) == 2, "length of input should be 2"
if self.use_model_parallel:
inputs[1] = self._broadcast_data(inputs[1])
if isinstance(inputs[1], tuple):
batch_size = inputs[1][0].shape[0]
assert self.micro_batch_size * self.accumulate_steps == batch_size
data = [
input[begin:end, :].clone().detach() for input in inputs[1]
]
self.caches['labels'][cache_id] = tuple(data)
data = [input[begin:end, :].detach() for input in inputs[1]]
return tuple(data)
else:
batch_size = inputs[1].shape[0]
assert self.micro_batch_size * self.accumulate_steps == batch_size
self.caches['labels'][cache_id] = inputs[1][begin:end, :].clone(
).detach()
return inputs[1][begin:end, :].detach()
else:
# No data input is required for other stages
inputs = None
# --- added in this commit ---
def _reduce_final_loss(self):
if self.is_last_stage:
assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
loss = self.total_loss.detach()
paddle.distributed.broadcast(
loss,
src=self.global_rank,
use_calc_stream=True,
group=self.pp_group)
else:
loss = paddle.zeros(shape=[1], dtype="float32")
paddle.distributed.broadcast(
loss,
src=self._hcg.get_rank_from_stage(self.num_stages - 1),
use_calc_stream=True,
group=self.pp_group)
return loss

# --- removed in this commit ---
def _send_meta(self, data, peer):
if isinstance(data, paddle.Tensor):
tensor_type = paddle.to_tensor([0])
# send tensor type
p2p.send(tensor_type, self.next_stage_id)
# send len(shape)
dims = paddle.to_tensor(len(data.shape))
p2p.send(dims, self.next_stage_id)
# send shape
shape = paddle.to_tensor(data.shape)
p2p.send(shape, self.next_stage_id)
# send dtype
dtype = paddle.to_tensor(paddle_2_number(data.dtype))
p2p.send(dtype, self.next_stage_id)
elif isinstance(data, tuple):
tensor_type = paddle.to_tensor([1])
p2p.send(tensor_type, self.next_stage_id)
nums = paddle.to_tensor(len(data))
p2p.send(nums, self.next_stage_id)
for idx, d in enumerate(data):
assert isinstance(d, paddle.Tensor)
# send len(shape)
dims = paddle.to_tensor(len(d.shape))
p2p.send(dims, self.next_stage_id)
# send shape
shape = paddle.to_tensor(d.shape)
p2p.send(shape, self.next_stage_id)
# send dtype
dtype = paddle.to_tensor(paddle_2_number(d.dtype))
p2p.send(dtype, self.next_stage_id)
def _recv_meta(self, peer):
tensor_type = paddle.to_tensor([0])
p2p.recv(tensor_type, self.prev_stage_id)
tensor_type = tensor_type.item()
if tensor_type == 0:
# recv len(shape)
dims = paddle.to_tensor([0])
p2p.recv(dims, self.prev_stage_id)
dims = dims.item()
# recv shape
shape = paddle.to_tensor([0] * dims)
p2p.recv(shape, self.prev_stage_id)
shape = shape.numpy().tolist()
# recv dtype
dtype = paddle.to_tensor([0])
p2p.recv(dtype, self.prev_stage_id)
return self._allocate_cache(
shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0]
elif tensor_type == 1:
num = paddle.to_tensor([0])
p2p.recv(num, self.prev_stage_id)
num = num.item()
shapes = []
dtypes = []
for i in range(num):
# recv len(shape)
dims = paddle.to_tensor([0])
p2p.recv(dims, self.prev_stage_id)
# recv shape
dims = dims.item()
shape = paddle.to_tensor([0] * dims)
p2p.recv(shape, self.prev_stage_id)
shapes.append(shape.numpy().tolist())
# recv dtype
dtype = paddle.to_tensor([0])
p2p.recv(dtype, self.prev_stage_id)
dtypes.append(number_2_dtype(dtype.item()))
caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0]
caches = tuple(caches)
return caches
def _is_valid_send_recv(self, tensor):
tensor_numel = np.prod(tensor.shape)
assert tensor_numel != 0, "can't send/recv zero element"
return tensor_numel % self.mp_degree == 0
def _send_activations(self, cache_id):
outputs = self.caches['outputs'][cache_id]
if self.send_meta:
self.send_meta = False
self._send_meta(outputs, self.next_stage_id)
if isinstance(outputs, paddle.Tensor):
if self.is_pipe_partitioned and self._is_valid_send_recv(outputs):
p2p.send_partial(
outputs.detach(),
self.next_stage_id,
mp_degree=self.mp_degree,
mp_rank=self.mp_rank)
else:
p2p.send(outputs.detach(), self.next_stage_id)
elif isinstance(outputs, tuple):
for output in outputs:
if self.is_pipe_partitioned and self._is_valid_send_recv(
output):
p2p.send_partial(
output.detach(),
self.next_stage_id,
mp_degree=self.mp_degree,
mp_rank=self.mp_rank)
else:
p2p.send(output.detach(), self.next_stage_id)
def _send_gradients(self, cache_id):
inputs = self.caches['inputs'][cache_id]
if isinstance(inputs, paddle.Tensor):
assert inputs.grad is not None
if self.is_pipe_partitioned and self._is_valid_send_recv(
inputs.grad):
grad = p2p.send_partial(
inputs.grad,
self.prev_stage_id,
mp_degree=self.mp_degree,
mp_rank=self.mp_rank)
else:
p2p.send(inputs.grad, self.prev_stage_id)
else:
for idx, d in enumerate(inputs):
# Skip tensors that will not produce a grad
if not is_float_tensor(d):
assert d.grad is None
continue
if self.is_pipe_partitioned and self._is_valid_send_recv(
d.grad):
grad = p2p.send_partial(
d.grad,
self.prev_stage_id,
mp_degree=self.mp_degree,
mp_rank=self.mp_rank)
else:
p2p.send(d.grad, self.prev_stage_id)
self.caches['inputs'][cache_id] = None
def _recv_activations(self, cache_id):
inputs = None
if self.recv_cache is None:
self.recv_cache = self._recv_meta(self.prev_stage_id)
if isinstance(self.recv_cache, paddle.Tensor):
if self.is_pipe_partitioned and self._is_valid_send_recv(
self.recv_cache):
p2p.recv_partial(self.recv_cache, self.prev_stage_id,
self.mp_degree, self.mp_rank)
p2p.partial_allgather_operator(
self.recv_cache,
mp_ranks=self.mp_degree,
mp_rank_id=self.mp_rank,
group=self._hcg.get_model_parallel_group(),
use_calc_stream=True)
else:
p2p.recv(self.recv_cache, self.prev_stage_id)
inputs = self.recv_cache.clone().detach()
inputs.stop_gradient = not is_float_tensor(inputs)
else:
assert isinstance(self.recv_cache, tuple)
inputs = [None] * len(self.recv_cache)
for idx, d in enumerate(self.recv_cache):
if self.is_pipe_partitioned and self._is_valid_send_recv(d):
assert isinstance(d, paddle.Tensor)
p2p.recv_partial(d, self.prev_stage_id, self.mp_degree,
self.mp_rank)
p2p.partial_allgather_operator(
d,
mp_ranks=self.mp_degree,
mp_rank_id=self.mp_rank,
group=self._hcg.get_model_parallel_group(),
use_calc_stream=True)
else:
assert isinstance(d, paddle.Tensor)
p2p.recv(d, self.prev_stage_id)
inputs[idx] = d.clone().detach()
inputs = tuple(inputs)
for d in inputs:
d.stop_gradient = not is_float_tensor(d)
self.caches['inputs'][cache_id] = inputs
def _recv_gradients(self, cache_id):
outputs = self.caches['outputs'][cache_id]
if self.grad_tensors is None:
if isinstance(outputs, paddle.Tensor):
s = list(outputs.shape)
dtype = get_tensor_dtype(outputs.dtype)
self.grad_tensors = self._allocate_cache(
s, dtype, num_caches=1)[0]
else:
sizes = [list(d.shape) for d in outputs if is_float_tensor(d)]
dtypes = [
get_tensor_dtype(d.dtype) for d in outputs
if is_float_tensor(d)
]
self.grad_tensors = self._allocate_caches(
sizes, dtypes, num_caches=1)[0]
if isinstance(self.grad_tensors, paddle.Tensor):
if self.is_pipe_partitioned and self._is_valid_send_recv(
self.grad_tensors):
p2p.recv_partial(self.grad_tensors, self.next_stage_id,
self.mp_degree, self.mp_rank)
p2p.partial_allgather_operator(
self.grad_tensors,
mp_ranks=self.mp_degree,
mp_rank_id=self.mp_rank,
group=self._hcg.get_model_parallel_group(),
use_calc_stream=True)
else:
p2p.recv(self.grad_tensors, self.next_stage_id)
else:
assert isinstance(outputs, tuple)
for d in self.grad_tensors:
if self.is_pipe_partitioned and self._is_valid_send_recv(d):
p2p.recv_partial(d, self.next_stage_id, self.mp_degree,
self.mp_rank)
p2p.partial_allgather_operator(
d,
mp_ranks=self.mp_degree,
mp_rank_id=self.mp_rank,
group=self._hcg.get_model_parallel_group(),
use_calc_stream=True)
else:
p2p.recv(d, self.next_stage_id)
# --- removed in this commit ---
def _step(self):
if self.scaler:
self.scaler.minimize(self.optimizer, self.train_loss)
else:
self.optimizer.step()
self.optimizer.clear_grad()
if self.lr_scheduler:
self.lr_scheduler.step()

# --- added in this commit ---
def _optimizer_step(self):
if self.scaler:
self.scaler.minimize(self.optimizer, self.train_loss)
else:
self.optimizer.step()
self.optimizer.clear_grad()
if self.lr_scheduler:
self.lr_scheduler.step()
def _clear_grads(self, inputs):
if isinstance(inputs, paddle.Tensor):
if inputs.grad is not None:
inputs.clear_gradient()
else:
for d in inputs:
if d.grad is not None:
d.clear_gradient()
def _allocate_zeros(self, shape, dtype):
return paddle.zeros(shape, dtype)
def _allocate_cache(self, shape, dtype, num_caches=-1):
caches = []
if num_caches == -1:
num_caches = self.num_caches
for count in range(num_caches):
caches.append(self._allocate_zeros(shape, dtype))
return caches
def _allocate_caches(self, shapes, dtypes, num_caches=-1):
caches = []
if num_caches == -1:
num_caches = self.num_caches
for count in range(num_caches):
cache = []
for shape, dtype in zip(shapes, dtypes):
cache.append(self._allocate_zeros(shape, dtype))
caches.append(cache)
return caches
def save_state_dict(self, model_path):
state_dict = self._layers.state_dict()
paddle.save(state_dict, model_path)
def load_state_dict(self, model_path):
state_dict = paddle.load(self.model_path)
self._layers.set_state_dict(state_dict)
def forward(self, *inputs, **kwargs):
raise RuntimeError("Call train_batch for pipeline instead of forward.")
...@@ -13,131 +13,388 @@
# limitations under the License.
import paddle
from .utils import paddle_2_number, number_2_dtype
from ...utils.log_util import logger
_groups = None
_hcg = None

# --- removed in this commit ---
def initialize_p2p_groups(hcg):
global _groups, _hcg
_groups = [
paddle.distributed.new_group(ranks=group)
for group in hcg.get_p2p_groups()
]
_hcg = hcg

# --- added in this commit ---
def initialize_p2p_groups(hcg):
global _hcg
_hcg = hcg
send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
)
debug_str = "P2pInfo: send_next_group: %s, send_prev_group: %s, " \
"recv_next_group: %s, recv_prev_group: %s" % (repr(send_next_group),
repr(send_prev_group),repr(recv_next_group), repr(recv_prev_group))
logger.info(debug_str)
def _is_valid_communciate(src_stage, dest_stage):
first_stage = 0
last_stage = _hcg.get_pipe_parallel_world_size() - 1
assert abs(src_stage-dest_stage) == 1 or \
(src_stage == first_stage and dest_stage == last_stage) or \
(src_stage == last_stage and dest_stage == first_stage)
class SendRecvMeta:
"""Mainly used to help p2p communication context information"""
def __init__(self):
self.send_shape_message = None
self.send_dtype_message = None

self.recv_shape_message = None
self.recv_dtype_message = None

self.has_send_meta = False
self.has_recv_meta = False
def _recv_shape_dtype(self, group):
# recv len(shape)
dims = paddle.to_tensor([0])
paddle.distributed.recv(dims, src=0, group=group)
dims = dims.item()
# recv shape
shape = paddle.to_tensor([0] * dims)
paddle.distributed.recv(shape, src=0, group=group)
# recv dtype
dtype = paddle.to_tensor([0])
paddle.distributed.recv(dtype, src=0, group=group)
return shape.numpy().tolist(), dtype.item()
def recv_meta(self, group):
tensor_type = paddle.to_tensor([0])
paddle.distributed.recv(tensor_type, src=0, group=group)
tensor_type = tensor_type.item()
if tensor_type == 0:
shape, dtype = self._recv_shape_dtype(group)
self.recv_shape_message = shape
self.recv_dtype_message = dtype
elif tensor_type == 1:
num = paddle.to_tensor([0])
paddle.distributed.recv(num, src=0, group=group)
num = num.item()
shapes = []
dtypes = []
for i in range(num):
shape, dtype = self._recv_shape_dtype(group)
shapes.append(shape)
dtypes.append(dtype)
self.recv_shape_message = tuple(shapes)
self.recv_dtype_message = tuple(dtypes)
def _send_dims_shape_dtype(self, tensor, group):
# send len(shape)
dims = paddle.to_tensor(len(tensor.shape))
paddle.distributed.send(dims, dst=1, group=group)
# send shape
shape = paddle.to_tensor(tensor.shape)
paddle.distributed.send(shape, dst=1, group=group)
# send dtype
dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
paddle.distributed.send(dtype, dst=1, group=group)
def send_meta(self, tensor, group):
if isinstance(tensor, paddle.Tensor):
tensor_type = paddle.to_tensor([0])
# send tensor type
paddle.distributed.send(tensor_type, dst=1, group=group)
self._send_dims_shape_dtype(tensor, group)
elif isinstance(tensor, tuple):
tensor_type = paddle.to_tensor([1])
# send tensor type
paddle.distributed.send(tensor_type, dst=1, group=group)
nums = paddle.to_tensor(len(tensor))
paddle.distributed.send(nums, dst=1, group=group)
for d in tensor:
assert isinstance(d, paddle.Tensor)
self._send_dims_shape_dtype(d, group=group)
def set_send_message(self, tensor):
if isinstance(tensor, paddle.Tensor):
self.send_shape_message = tensor.shape
self.send_dtype_message = paddle_2_number(tensor.dtype)
elif isinstance(tensor, tuple):
self.send_shape_message = tuple(
[d.shape for d in tensor if not d.stop_gradient])
self.send_dtype_message = tuple(
[paddle_2_number(d.dtype) for d in tensor])
_send_recv_meta = SendRecvMeta()
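The handshake above boils down to a small integer message per tensor: a type flag (0 for a single tensor, 1 for a tuple plus a count), then ndim, the shape, and a dtype code. A sketch of that payload in plain Python; the dtype codes here are placeholders, the real mapping lives in pp_utils.utils.paddle_2_number:

def pack_meta(shapes, dtype_codes):
    # Mirrors the order of sends in send_meta()/recv_meta(); illustration only.
    if len(shapes) == 1:
        msg = [0]                      # tensor_type 0: a single tensor follows
    else:
        msg = [1, len(shapes)]         # tensor_type 1: a tuple, then its length
    for shape, dtype_code in zip(shapes, dtype_codes):
        msg += [len(shape)] + list(shape) + [dtype_code]
    return msg

print(pack_meta([(2, 8, 16), (2, 8)], dtype_codes=[0, 1]))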
# --- removed in this commit ---
def partial_send_operator(tensor,
dst=0,
mp_ranks=1,
mp_rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_send(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
dst, 'num', mp_ranks, 'id', mp_rank_id)

# --- added in this commit ---
def send_partial(tensor,
dst=0,
nranks=1,
rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_send(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
dst, 'num', nranks, 'id', rank_id)
# --- removed in this commit ---
def partial_recv_operator(tensor,
src=0,
mp_ranks=1,
mp_rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_recv(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
src, 'num', mp_ranks, 'id', mp_rank_id, 'dtype', tensor.dtype,
'out_shape', tensor.shape)

# --- added in this commit ---
def recv_partial(tensor,
src=0,
nranks=1,
rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
paddle.fluid.core.ops.partial_recv(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape)
# --- removed in this commit ---
def partial_allgather_operator(tensor,
mp_ranks=1,
mp_rank_id=0,
group=None,
use_calc_stream=True):
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_allgather_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
'nranks', mp_ranks, 'rank', mp_rank_id)

# --- added in this commit ---
def allgather_partial(tensor,
nranks=1,
rank_id=0,
group=None,
use_calc_stream=True):
if nranks == 1:
return tensor
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_allgather_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
'nranks', nranks, 'rank', rank_id)
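When tensor (model) parallelism is enabled, each rank only transmits its 1/nranks slice of the activation via send_partial, and the receiving side restores the full tensor with allgather_partial over the model-parallel group. Below is a NumPy stand-in for that data movement, assuming the element count divides evenly by nranks; the helper names are made up and none of this calls Paddle:

import numpy as np

def split_for_partial_send(flat, nranks, rank_id):
    # Each tensor-parallel rank transmits only its contiguous 1/nranks slice.
    chunk = flat.size // nranks
    return flat[rank_id * chunk:(rank_id + 1) * chunk]

def allgather_chunks(chunks):
    # The receiving side glues the slices back into the full tensor.
    return np.concatenate(chunks)

full = np.arange(8, dtype=np.float32)
chunks = [split_for_partial_send(full, nranks=2, rank_id=r) for r in range(2)]
assert np.array_equal(allgather_chunks(chunks), full)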
# --- removed in this commit ---
def send(tensor, dest_stage):
global _groups, _hcg
src_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return paddle.distributed.send(
tensor, dst=1 if dest_stage > src_stage else 0, group=group)

def recv(tensor, src_stage):
global _groups, _hcg
dest_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return paddle.distributed.recv(
tensor, src=0 if dest_stage > src_stage else 1, group=group)

def send_partial(tensor, dest_stage, mp_degree, mp_rank):
global _groups, _hcg
src_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return partial_send_operator(
tensor,
dst=1 if dest_stage > src_stage else 0,
mp_ranks=mp_degree,
mp_rank_id=mp_rank,
group=group)

def recv_partial(tensor, src_stage, mp_degree, mp_rank):
global _groups, _hcg
dest_stage = _hcg.get_stage_id()
_is_valid_communciate(src_stage, dest_stage)
group = _get_send_recv_group(src_stage, dest_stage)
return partial_recv_operator(
tensor,
src=0 if dest_stage > src_stage else 1,
mp_ranks=mp_degree,
mp_rank_id=mp_rank,
group=group)

def _get_send_recv_group(src_stage, dest_stage):
global _groups, _hcg
stage_id = None
first_stage = 0
last_stage = _hcg.get_pipe_parallel_world_size() - 1
if (src_stage == first_stage and dest_stage == last_stage) or \
(dest_stage == first_stage and src_stage == last_stage):
stage_id = last_stage
elif src_stage > dest_stage:
stage_id = dest_stage
else:
stage_id = src_stage
group_id = _hcg.get_rank_from_stage(stage_id=stage_id)
return _groups[group_id]

# --- added in this commit ---
def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
global _hcg

tensor_recv_prev = None
tensor_recv_next = None

# send / recv message
recv_shape_msg = _send_recv_meta.recv_shape_message
recv_dtype_msg = _send_recv_meta.recv_dtype_message
send_shape_msg = _send_recv_meta.send_shape_message
send_dtype_msg = _send_recv_meta.send_dtype_message

# model parallel message
mp_group = _hcg.get_model_parallel_group()
mp_degree = _hcg.get_model_parallel_world_size()
mp_rank = _hcg.get_model_parallel_rank()

if recv_prev:
if isinstance(recv_shape_msg, tuple):
tensor_recv_prev = []
for idx, shape in enumerate(recv_shape_msg):
tensor_recv_prev.append(
paddle.empty(
shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])))
tensor_recv_prev = tuple(tensor_recv_prev)
else:
tensor_recv_prev = paddle.empty(
shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg))

if recv_next:
if isinstance(send_shape_msg, tuple):
tensor_recv_next = []
for idx, shape in enumerate(send_shape_msg):
tensor_recv_next.append(
paddle.empty(
shape=shape, dtype=number_2_dtype(send_dtype_msg[idx])))
tensor_recv_next = tuple(tensor_recv_next)
else:
tensor_recv_next = paddle.empty(
shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg))

# start to p2p communicate
if tensor_send_prev is not None:
if isinstance(tensor_send_prev, tuple):
for d in tensor_send_prev:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
d,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False)
else:
paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
send_partial(
tensor_send_prev,
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_prev_group,
use_calc_stream=False)
if tensor_recv_prev is not None:
if isinstance(tensor_recv_prev, tuple):
for d in tensor_recv_prev:
recv_partial(
d,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(
tensor_recv_prev,
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
tensor_recv_prev,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
if tensor_send_next is not None:
if isinstance(tensor_send_next, tuple):
for d in tensor_send_next:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
d,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False)
else:
paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
send_partial(
tensor_send_next,
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.send_next_group,
use_calc_stream=False)
if tensor_recv_next is not None:
if isinstance(tensor_recv_next, tuple):
for d in tensor_recv_next:
recv_partial(
d,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(
d,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(
tensor_recv_next,
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(
tensor_recv_next,
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
return tensor_recv_prev, tensor_recv_next
def recv_forward():
if _hcg.is_first_stage:
input_tensor = None
else:
if not _send_recv_meta.has_recv_meta:
_send_recv_meta.recv_meta(_hcg.recv_prev_group)
_send_recv_meta.has_recv_meta = True
input_tensor, _ = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False)
return input_tensor
def recv_backward():
if _hcg.is_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
return output_tensor_grad
def send_forward(output_tensor):
if not _hcg.is_last_stage:
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
_send_recv_meta.has_send_meta = True
_p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False)
def send_backward(input_tensor_grad):
if not _hcg.is_first_stage:
_p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False)
def send_forward_recv_backward(output_tensor):
if _hcg.is_last_stage:
output_tensor_grad = None
else:
_, output_tensor_grad = _p2p_helper(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad):
if _hcg.is_first_stage:
input_tensor = None
else:
input_tensor, _ = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False)
return input_tensor
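Taken together with the new train_batch(), each scheduling phase maps onto a fixed set of these primitives. The summary below is illustrative (a middle stage's view) rather than additional library code:

phase_to_calls = {
    "warmup":   ["recv_forward", "send_forward"],
    "steady":   ["recv_forward (first step only)",
                 "send_forward_recv_backward",
                 "send_backward_recv_forward"],
    "cooldown": ["recv_backward", "send_backward"],
}
for phase, calls in phase_to_calls.items():
    print(phase, "->", "; ".join(calls))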
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from paddle.fluid import layers
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc
from paddle.fluid.dygraph.layers import Layer
import paddle.nn as nn
def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + dp_id)
batch_size = 8
length = 8
micro_batch_size = 2
vocab_size = 128
hidden_size = 16
d_model = hidden_size
dim_feedforward = 4 * d_model
class EmbeddingNet(Layer):
def __init__(self):
super(EmbeddingNet, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(vocab_size, hidden_size)
def forward(self, x):
attention_mask = paddle.tensor.triu(
(paddle.ones(
(length, length), dtype="float32") * -1e9), 1)
attention_mask.stop_gradient = True
w_emb = self.word_embeddings(x)
p_emb = self.position_embeddings(x)
w_emb = w_emb + p_emb
# need to fix bug of backward()
return w_emb, attention_mask
class TransformerNet(Layer):
def __init__(self):
super(TransformerNet, self).__init__()
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, x, mask):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5)
weights = F.softmax(product + mask)
weights = F.dropout(weights, 0.2)
tgt = layers.matmul(weights, v)
residual = tgt
tgt = self.norm1(tgt)
tgt = residual + tgt
out = self.linear2(F.gelu(self.linear1(tgt), approximate=True))
return out
class EmbeddingPipe(EmbeddingNet):
def forward(self, x):
return super().forward(x)
class TransformerNetPipe(TransformerNet):
def forward(self, args):
x, mask = args[0], args[1]
output = super().forward(x, mask)
output = output
mask.stop_gradient = True
return output, mask
class CriterionPipe(Layer):
def __init__(self):
super(CriterionPipe, self).__init__()
def forward(self, out, label):
loss = out.mean()
return loss
class ModelPipe(PipelineLayer):
def __init__(self, topology):
self.descs = []
self.descs.append(LayerDesc(EmbeddingPipe))
for x in range(5):
self.descs.append(LayerDesc(TransformerNetPipe))
self.descs.append(lambda x: x[0])
super().__init__(
layers=self.descs, loss_fn=CriterionPipe(), topology=topology)
class TestDistPPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
}
fleet.init(is_collective=True, strategy=strategy)
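For reference, the pipeline_configs above imply four micro-batches per train_batch() call; a trivial check:

# batch_size = 8 and micro_batch_size = 2, so each train_batch() is split into
# 8 // 2 = 4 micro-batches, which is the accumulate_steps configured above.
batch_size, micro_batch_size = 8, 2
accumulate_steps = batch_size // micro_batch_size
assert accumulate_steps == 4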
def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
topology = hcg.topology()
set_random_seed(1024, dp_id, rank_id)
model = ModelPipe(topology)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
parameters=model.parameters())
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
for step_id in range(5):
x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
x = paddle.to_tensor(x_data)
x.stop_gradient = True
loss = model.train_batch([x, x], optimizer, scheduler)
# TODO(shenliang03) add utest for loss
if __name__ == "__main__":
unittest.main()
...@@ -33,6 +33,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
def test_pipeline_parallel(self):
self.run_mnist_2gpu('hybrid_parallel_pp_amp.py')
def test_hybrid_parallel_transformer(self):
self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()