未验证 提交 a0f4ac54 编写于 作者: L lilong12 提交者: GitHub

Fix the bug in pipeline for dygraph mode (#32716)

* update, test=develop
上级 f4a3f85b
......@@ -108,7 +108,6 @@ class PipelineLayer(Layer):
# construct layer
self.run_function = []
self._build_layer()
self.to(paddle.CUDAPlace(self.device_id))
def _segment_network(self, seg_method):
logger.info("start segment network..")
......
......@@ -22,15 +22,11 @@ from numpy import prod
import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import get_tensor_bytes
from .pp_utils.utils import get_tensor_bytes, is_float_tensor
from .pp_utils import utils
from .parallel_layers.pp_layers import PipelineLayer
FLOAT_TYPES = [
paddle.float16,
paddle.float32,
paddle.float64,
]
from ..utils.hybrid_parallel_util import *
from ..utils.log_util import logger
class PipelineParallel(MetaParallelBase):
......@@ -46,20 +42,18 @@ class PipelineParallel(MetaParallelBase):
'inputs': [],
'labels': [],
'outputs': [],
'backward_tensors': [],
}
self.recv_cache = None
self.grad_tensors = None
self.meta_buffer = None
self.send_meta = True
self.first_gradient_send = True
self.current_loss = paddle.to_tensor(0.0)
self.total_loss = None
def _prepare_for_model(self):
self.use_amp = self._strategy.amp
self.init_loss_scaling = self._strategy.amp_configs['init_loss_scaling']
self.micro_batch_size = self._strategy.pipeline_configs[
'micro_batch_size']
self.accumulate_steps = self._strategy.pipeline_configs[
......@@ -69,9 +63,17 @@ class PipelineParallel(MetaParallelBase):
self.stage_id = self._hcg.get_stage_id()
self.prev_stage_id = self.stage_id - 1
self.next_stage_id = self.stage_id + 1
self._layers = PipelineLayer(
layers=self._layers, num_stages=self.num_stages)
#TODO: init process group
self.pp_group = self._hcg.get_pipe_parallel_group()
logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
self.num_stages, self.stage_id))
if self.use_model_parallel:
logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg)
if self.use_data_parallel:
logger.info("start broadcast mp parameters")
broadcast_dp_parameters(self._layers, self._hcg)
def _allocate_caches(self, num_caches):
if self.num_caches >= num_caches:
......@@ -82,19 +84,19 @@ class PipelineParallel(MetaParallelBase):
for key in self.caches:
self.caches[key].extend([None] * num)
def train_batch(self, data_iter, optimizer):
def train_batch(self, data, optimizer):
self.optimizer = optimizer
assert fluid.framework._dygraph_tracer()._has_grad, (
'Please enable the generation of gradients.')
if self.stage_id == 0 or self.stage_id == self.num_stages - 1:
assert data_iter, (
assert data, (
"For the first and the last stage, the data_iter must be set.")
else:
assert data_iter is None, (
assert data is None, (
"For pipe stages other than the first and the last one, "
"the data_iter must be None.")
self.data_iter = data_iter
self.data = data
self._layers.train()
self.total_loss = None
......@@ -104,39 +106,24 @@ class PipelineParallel(MetaParallelBase):
return self.total_loss
def _train(self, minibatch_cmds):
self._allocate_caches(self.num_stages)
for microbatch_cmds in minibatch_cmds:
for cmd in microbatch_cmds:
if type(cmd) not in self._COMMAND_MAP:
#FIXME:
continue
self._allocate_caches(self.accumulate_steps)
for micro_cmds in minibatch_cmds:
for cmd in micro_cmds:
assert type(cmd) in self._COMMAND_MAP, "unknow cmd: {}".format(
type(cmd))
self._apply_cmd = MethodType(self._COMMAND_MAP[type(cmd)], self)
self._apply_cmd(**cmd.kwargs)
def _allreduce_grads(self):
self._modifying_grad = True
assert self.use_data_parallel <= 1, ("Do not support data parallel "
"with pipeline parallel now.")
self._modifying_grad = False
def _get_data(self):
if self.use_model_parallel:
mp_rank = self._hcg.get_model_parallel_rank()
else:
mp_rank = 0
data = None
# mp rank 0 loads the data and broadcat it to others.
if mp_rank == 0:
data = next(self.data_iter)
if self.use_model_parallel:
data = paddle.distributed.broadcast(
data, group=self._hcg.get_model_parallel_group())
return data
if not self.use_data_parallel: return
fused_allreduce_gradients(list(self._layers.parameters()), self._hcg)
def _forward(self, cache_id):
# load data
self._load_micro_batch(cache_id)
if self.stage_id != 0:
self._recv_activations(cache_id)
if isinstance(self.caches['inputs'][cache_id], tuple):
inputs = tuple(t.clone() for t in self.caches['inputs'][cache_id])
else:
......@@ -144,9 +131,13 @@ class PipelineParallel(MetaParallelBase):
self._clear_grads(inputs)
outputs = self._layers.forward(inputs)
self.caches['outputs'][cache_id] = outputs
if self.stage_id == self.num_stages - 1:
if self._layers._loss_fn is not None:
labels = self.caches['labels'][cache_id]
outputs = self._layers._loss_fn(outputs, labels)
if self.stage_id == self.num_stages - 1:
self.current_loss = outputs
if isinstance(self.current_loss, paddle.Tensor):
......@@ -160,18 +151,28 @@ class PipelineParallel(MetaParallelBase):
]
for idx, v in enumerate(self.current_loss):
self.total_loss[idx] += v.detach()
if self.use_data_parallel:
self.current_loss = self.current_loss / self._hcg.get_data_parallel_world_size(
)
if self.accumulate_steps > 1:
self.current_loss = self.current_loss / self.accumulate_steps
self.caches['outputs'][cache_id] = self.current_loss.clone()
else:
self._send_activations(cache_id)
def _backward(self, cache_id):
assert self.optimizer is not None
if self.stage_id == self.num_stages - 1:
paddle.autograd.backward(self.current_loss)
paddle.autograd.backward(self.caches['outputs'][cache_id])
self._send_gradients(cache_id)
return
self._recv_gradients(cache_id)
outputs = self.caches['outputs'][cache_id]
grad_tensors = self.grad_tensors
if isinstance(outputs, tuple):
out_tensors = [t for t in outputs if t.dtype in FLOAT_TYPES]
out_tensors = [t for t in outputs if is_float_tensor(t)]
assert len(out_tensors) == len(grad_tensors)
paddle.autograd.backward(
tensors=out_tensors, grad_tensors=grad_tensors)
......@@ -179,41 +180,76 @@ class PipelineParallel(MetaParallelBase):
paddle.autograd.backward(
tensors=[outputs], grad_tensors=[grad_tensors])
self.caches['outputs'][cache_id] = None
grad_tensors = None
if self.stage_id != 0: self._send_gradients(cache_id)
self.caches['outputs'][cache_id] = None
#self.caches['backward_tensors'][cache_id] = None
def _get_data(self):
if self.use_model_parallel:
mp_rank = self._hcg.get_model_parallel_rank()
else:
mp_rank = 0
# mp rank 0 loads the data and broadcat it to others.
data = self.data
if self.use_model_parallel and (self.stage_id == 0 or
self.stage_id == self.num_stages - 1):
assert isinstance(data, (tuple, paddle.Tensor))
if isinstance(data, paddle.Tensor):
paddle.distributed.broadcast(
data,
src=self._hcg.get_model_parallel_group_src_rank(),
group=self._hcg.get_model_parallel_group())
else:
data = []
for d in self.data:
assert isinstance(d, paddle.Tensor)
paddle.distributed.broadcast(
d,
src=self._hcg.get_model_parallel_group_src_rank(),
group=self._hcg.get_model_parallel_group())
data.append(d)
data = tuple(data)
return data
def _load_micro_batch(self, cache_id):
inputs = self._get_data()
if self.stage_id == 0:
data = None
if isinstance(inputs[0], paddle.Tensor):
#if isinstance(inputs[0], paddle.Tensor):
if len(inputs) == 1:
assert isinstance(inputs[0], paddle.Tensor)
data = inputs[0].clone().detach()
data.stop_gradient = data.dtype == paddle.float32
#data.stop_gradient = not is_float_tensor(data)
data.stop_gradient = True
else:
assert isinstance(inputs[0], tuple)
# Assume list or tuple
assert isinstance(inputs, tuple)
data = []
for d in inputs[0]:
for d in inputs:
assert isinstance(d, paddle.Tensor)
d = d.clone().detach()
d.stop_gradient = d.dtype == paddle.float32
loaded.append(d)
i = d.clone().detach()
#i.stop_gradient = not is_float_tensor(i)
i.stop_gradient = True
data.append(i)
data = tuple(data)
self.caches['inputs'][cache_id] = data
if self.stage_id == self.num_stages - 1:
label = None
if isinstance(inputs[1], paddle.Tensor):
label = inputs[1]
elif isinstance(data[1], tuple):
label = []
for l in inputs[1]:
assert isinstance(l, paddle.Tensor)
l = l.detach()
label.append(l)
label = tuple(label)
self.caches['labels'][cache_id] = label
labels = None
#if isinstance(inputs[1], paddle.Tensor):
if len(inputs) == 1:
assert isinstance(inputs[0], paddle.Tensor)
labels = inputs[0]
elif isinstance(inputs, tuple):
labels = []
for label in inputs:
assert isinstance(label, paddle.Tensor)
label = label.detach()
labels.append(label)
labels = tuple(labels)
self.caches['labels'][cache_id] = labels
def _send_meta(self, data, peer):
"""
......@@ -225,54 +261,67 @@ class PipelineParallel(MetaParallelBase):
"""
if isinstance(data, paddle.Tensor):
tensor_type = paddle.to_tensor([0])
paddle.distributed.send(tensor_type, peer)
paddle.distributed.send(
tensor_type, peer, use_calc_stream=True, group=self.pp_group)
dims = paddle.to_tensor(len(data.shape))
paddle.distributed.send(dims, peer)
paddle.distributed.send(
dims, peer, use_calc_stream=True, group=self.pp_group)
shape = paddle.to_tensor(data.shape)
paddle.distributed.send(shape, peer)
paddle.distributed.send(
shape, peer, use_calc_stream=True, group=self.pp_group)
elif isinstance(data, tuple):
tensor_type = paddle.to_tensor([1])
paddle.distributed.send(tensor_type, peer)
paddle.distributed.send(
tensor_type, peer, use_calc_stream=True, group=self.pp_group)
nums = paddle.to_tensor(len(data))
paddle.distributed.send(nums, peer)
paddle.distributed.send(
nums, peer, use_calc_stream=True, group=self.pp_group)
for idx, d in enumerate(data):
assert isinstance(d, paddle.Tensor)
dims = paddle.to_tensor(len(d.shape))
paddle.distributed.send(dims, peer)
paddle.distributed.send(
dims, peer, use_calc_stream=True, group=self.pp_group)
shape = paddle.to_tensor(d.shape)
paddle.distributed.send(shape, peer)
paddle.distributed.send(
shape, peer, use_calc_stream=True, group=self.pp_group)
def _recv_meta(self, peer):
tensor_type = paddle.to_tensor([0])
paddle.distributed.recv(tensor_type, peer)
paddle.distributed.recv(
tensor_type, peer, use_calc_stream=True, group=self.pp_group)
tensor_type = tensor_type.numpy()[0]
if tensor_type == 0:
dims = paddle.to_tensor([0])
paddle.distributed.recv(dims, peer)
paddle.distributed.recv(
dims, peer, use_calc_stream=True, group=self.pp_group)
dims = dims.numpy()[0]
shape = paddle.to_tensor([0] * dims)
paddle.distributed.recv(shape, peer)
paddle.distributed.recv(
shape, peer, use_calc_stream=True, group=self.pp_group)
shape = shape.numpy().tolist()
return self._allocate_buffer(
shape, dtype="float32", num_caches=1)[0]
elif tensor_type == 1:
num = paddle.to_tensor([0])
paddle.distributed.recv(num, peer)
paddle.distributed.recv(
num, peer, use_calc_stream=True, group=self.pp_group)
num = num.numpy()[0]
shapes = []
for i in range(num):
dims = paddle.to_tensor([0])
paddle.distributed.recv(dims, peer)
paddle.distributed.recv(
dims, peer, use_calc_stream=True, group=self.pp_group)
dims = dims.numpy()[0]
shape = paddle.to_tensor([0] * dims)
paddle.distributed.recv(shape, peer)
paddle.distributed.recv(
shape, peer, use_calc_stream=True, group=self.pp_group)
shapes.append(shape.numpy().tolist())
dtypes = ["float32"] * len(shapes)
caches = self._allocate_buffers(shapes, dtypes, num_buffers=1)[0]
buffers = tuple(buffers)
return buffers
caches = self._allocate_buffers(shapes, dtypes, num_caches=1)[0]
caches = tuple(caches)
return caches
def _send_activations(self, cache_id):
outputs = self.caches['outputs'][cache_id]
......@@ -282,10 +331,18 @@ class PipelineParallel(MetaParallelBase):
self._send_meta(outputs, self.next_stage_id)
if isinstance(outputs, paddle.Tensor):
paddle.distributed.send(outputs, self.next_stage_id)
paddle.distributed.send(
outputs,
self.next_stage_id,
use_calc_stream=True,
group=self.pp_group)
elif isinstance(outputs, tuple):
for output in outputs:
paddle.distributed.send(output, self.next_stage_id)
paddle.distributed.send(
output,
self.next_stage_id,
use_calc_stream=True,
group=self.pp_group)
def _send_gradients(self, cache_id):
inputs = self.caches['inputs'][cache_id]
......@@ -293,15 +350,22 @@ class PipelineParallel(MetaParallelBase):
if isinstance(inputs, paddle.Tensor):
assert inputs.grad is not None
paddle.distributed.send(
paddle.to_tensor(inputs.grad), self.prev_stage_id)
paddle.to_tensor(inputs.grad),
self.prev_stage_id,
use_calc_stream=True,
group=self.pp_group)
else:
for idx, d in enumerate(inputs):
# Skip tensors that will not produce a grad
if not d.dtype in FLOAT_TYPES:
if not is_float_tensor(d):
assert d.grad is None
continue
assert d.grad is not None
paddle.distributed.send(d.grad, self.prev_stage_id)
paddle.distributed.send(
d.grad,
self.prev_stage_id,
use_calc_stream=True,
group=self.pp_group)
self.caches['inputs'][cache_id] = None
def _recv_activations(self, cache_id):
......@@ -312,22 +376,30 @@ class PipelineParallel(MetaParallelBase):
self.recv_cache = self._recv_meta(self.prev_stage_id)
if isinstance(self.recv_cache, paddle.Tensor):
paddle.distributed.recv(self.recv_cache, self.prev_stage_id)
paddle.distributed.recv(
self.recv_cache,
self.prev_stage_id,
use_calc_stream=True,
group=self.pp_group)
inputs = self.recv_cache.clone().detach()
inputs.stop_gradient = inputs.dtype not in FLOAT_TYPES
inputs.stop_gradient = not is_float_tensor(inputs)
else:
assert isinstance(self.recv_cache, tuple)
inputs = [None] * len(self.recv_cache)
for idx, d in enumerate(self.recv_cache):
assert isinstance(d, paddle.Tensor)
paddle.distributed.recv(d, self.prev_stage_id)
paddle.distributed.recv(
d,
self.prev_stage_id,
use_calc_stream=True,
group=self.pp_group)
inputs[idx] = d.clone().detach()
inputs = tuple(inputs)
for d in inputs:
d.stop_gradient = d.dtype not in FLOAT_TYPES
d.stop_gradient = not is_float_tensor(d)
self.caches['inputs'][cache_id] = inputs
......@@ -336,29 +408,35 @@ class PipelineParallel(MetaParallelBase):
if self.grad_tensors is None:
if isinstance(outputs, paddle.Tensor):
s = list(outputs.shape)
dtype = 'float32'
dtype = 'float16' if self.use_amp else "float32"
self.grad_tensors = self._allocate_buffer(
s, dtype, num_buffers=1)[0]
else:
sizes = [
list(d.shape) for d in outputs if d.dtype in FLOAT_TYPES
]
dtypes = ['float32'] * len(sizes)
sizes = [list(d.shape) for d in outputs if is_float_tensor(d)]
dtypes = ['float16'] * len(
sizes) if self.use_amp else ['float32'] * len(sizes)
self.grad_tensors = self._allocate_buffers(
sizes, dtypes, num_buffers=1)[0]
sizes, dtypes, num_caches=1)[0]
if isinstance(self.grad_tensors, paddle.Tensor):
paddle.distributed.recv(self.grad_tensors, self.next_stage_id)
paddle.distributed.recv(
self.grad_tensors,
self.next_stage_id,
use_calc_stream=True,
group=self.pp_group)
else:
assert isinstance(outputs, tuple)
for d in self.grad_tensors:
paddle.distributed.recv(d, self.next_stage_id)
def _step(self, lr_kwargs=None):
self._modifying_grad = True
paddle.distributed.recv(
d,
self.next_stage_id,
use_calc_stream=True,
group=self.pp_group)
def _step(self):
self._allreduce_grads()
self.optimizer.step()
self.optimizer.clear_gradients()
self._modifying_grad = False
def _clear_grads(self, inputs):
if isinstance(inputs, paddle.Tensor):
......@@ -372,26 +450,24 @@ class PipelineParallel(MetaParallelBase):
def _allocate_zeros(self, shape, dtype):
return paddle.zeros(shape, dtype)
def _allocate_buffer(self, shape, dtype, num_buffers=-1, **kwargs):
buffers = []
if num_buffers == -1:
num_buffers = self.num_caches
for count in range(num_buffers):
buffers.append(self._allocate_zeros(shape, dtype))
return buffers
def _allocate_buffers(self, shapes, dtypes, num_buffers=-1):
buffers = []
if num_buffers == -1:
num_buffers = self.num_caches
for count in range(num_buffers):
buffer = []
def _allocate_buffer(self, shape, dtype, num_caches=-1):
caches = []
if num_caches == -1:
num_caches = self.num_caches
for count in range(num_caches):
caches.append(self._allocate_zeros(shape, dtype))
return caches
def _allocate_buffers(self, shapes, dtypes, num_caches=-1):
caches = []
if num_caches == -1:
num_caches = self.num_caches
for count in range(num_caches):
cache = []
for shape, dtype in zip(shapes, dtypes):
buffer.append(
self._allocate_zeros(
shape, dtype, requires_grad=requires_grad))
buffers.append(buffer)
return buffers
cache.append(self._allocate_zeros(shape, dtype))
caches.append(cache)
return caches
def save_state_dict(self, model_path):
state_dict = self._layers.state_dict()
......@@ -403,25 +479,9 @@ class PipelineParallel(MetaParallelBase):
_COMMAND_MAP = {
utils.Optimize: _step,
#utils.ReduceGrads: _allreduce_grads,
utils.Forward: _forward,
utils.Backward: _backward,
}
def _pre_forward(self, *inputs, **kwargs):
pass
def forward(self, *inputs, **kwargs):
raise RuntimeError("Call train_batch for pipeline instead of forward.")
def _post_forward(self, output):
pass
def _pre_backward(self, loss):
pass
def backward_impl(self, loss, parameters):
pass
def _post_backward(self, loss):
pass
......@@ -16,7 +16,21 @@ import abc
import paddle
from ...utils import hybrid_parallel_util as hp_util
__all__ = ['get_tensor_bytes', ]
__all__ = [
'get_tensor_bytes',
'is_float_tensor',
]
FLOAT_TYPES = [
paddle.float16,
paddle.float32,
paddle.float64,
]
def is_float_tensor(tensor):
"""Is a float tensor"""
return tensor.dtype in FLOAT_TYPES
def get_tensor_bytes(tensor):
......@@ -48,10 +62,6 @@ class Generator():
self.stage_id = stage_id
self.prev_stage = self.stage_id - 1
self.next_stage = self.stage_id + 1
assert self.micro_batches >= self.stages, (
"micro_batches {} "
"must be greater than or equal to {}".format(self.micro_batches,
self.stages))
@abc.abstractmethod
def generate(self):
......@@ -73,18 +83,25 @@ class TrainGenerator(Generator):
cmds = []
forward_steps = 0
backward_steps = 0
while (forward_steps < startup_steps):
cmds.append(Forward)
forward_steps += 1
#while (forward_steps < startup_steps):
# cmds.append(Forward(cache_id=forward_steps))
# forward_steps += 1
#while (forward_steps < self.micro_batches):
# cmds.append(Forward(cache_id=forward_steps))
# forward_steps += 1
# cmds.append(Backward(cache_id=backward_steps))
# backward_steps += 1
#while (backward_steps < self.micro_batches):
# cmds.append(Backward(cache_id=backward_steps))
# backward_steps += 1
#cmds.append(Optimize())
while (forward_steps < self.micro_batches):
cmds.append(Forward)
cmds.append(Forward(cache_id=forward_steps))
forward_steps += 1
cmds.append(Backward)
backward_steps += 1
while (backward_steps < self.micro_batches):
cmds.append(Backward)
cmds.append(Backward(cache_id=backward_steps))
backward_steps += 1
cmds.append(Optimize)
cmds.append(Optimize())
yield cmds
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册