Unverified commit 6b86e966, authored by lilong12 and committed by GitHub

Fix the bug in pipeline for dygraph mode (#32716) (#32728)

* update, test=develop
Parent 4593597d
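The diff below reworks the dygraph pipeline runtime: `train_batch` now takes the batch data directly instead of a data iterator (and only on the first and last stages), model/data-parallel parameter broadcast moves into `PipelineParallel.__init__`, and point-to-point sends/receives run on the calc stream within the pipeline group. For orientation, a hypothetical driver under the new API might look like the sketch below; everything outside the `train_batch(data, optimizer)` call (strategy values, model construction, data loading) is illustrative and not part of this commit.

```python
# Hypothetical dygraph driver for the pipeline runtime changed in this diff.
# Only the train_batch(data, optimizer) call reflects the new API; the setup
# (build_model, train_loader, strategy values) is a placeholder sketch.
import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.pipeline_configs = {"micro_batch_size": 2, "accumulate_steps": 4}
fleet.init(is_collective=True, strategy=strategy)

model = fleet.distributed_model(build_model())  # assumed PipelineLayer-based model
optimizer = fleet.distributed_optimizer(
    paddle.optimizer.SGD(learning_rate=1e-3, parameters=model.parameters()))

for data in train_loader():  # assumed per-stage data source
    # Per the asserts in train_batch: the first and last stages pass their
    # data (inputs and labels respectively); intermediate stages pass None.
    loss = model.train_batch(data, optimizer)
```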
@@ -108,7 +108,6 @@ class PipelineLayer(Layer):
         # construct layer
         self.run_function = []
         self._build_layer()
-        self.to(paddle.CUDAPlace(self.device_id))
 
     def _segment_network(self, seg_method):
         logger.info("start segment network..")
......
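The only change in `PipelineLayer` above is dropping the implicit `self.to(paddle.CUDAPlace(self.device_id))` from `__init__`, so the layer no longer moves itself onto a GPU at construction time. Presumably placement is now left to the caller or the surrounding runtime; a hedged sketch of doing it explicitly follows, where the model constructor is a placeholder and not part of the commit.

```python
import paddle

# Hypothetical explicit placement, since PipelineLayer.__init__ no longer
# calls self.to(paddle.CUDAPlace(self.device_id)) after this change.
device_id = paddle.distributed.ParallelEnv().device_id  # this rank's GPU
model = build_pipeline_model()  # assumed PipelineLayer-based model
model.to(paddle.CUDAPlace(device_id))
```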
@@ -22,15 +22,11 @@ from numpy import prod
 import paddle
 import paddle.fluid as fluid
 from .meta_parallel_base import MetaParallelBase
-from .pp_utils.utils import get_tensor_bytes
+from .pp_utils.utils import get_tensor_bytes, is_float_tensor
 from .pp_utils import utils
 from .parallel_layers.pp_layers import PipelineLayer
+from ..utils.hybrid_parallel_util import *
+from ..utils.log_util import logger
 
-FLOAT_TYPES = [
-    paddle.float16,
-    paddle.float32,
-    paddle.float64,
-]
 
 class PipelineParallel(MetaParallelBase):
@@ -46,20 +42,18 @@ class PipelineParallel(MetaParallelBase):
             'inputs': [],
             'labels': [],
             'outputs': [],
-            'backward_tensors': [],
         }
         self.recv_cache = None
         self.grad_tensors = None
-        self.meta_buffer = None
         self.send_meta = True
-        self.first_gradient_send = True
         self.current_loss = paddle.to_tensor(0.0)
         self.total_loss = None
 
-    def _prepare_for_model(self):
+        self.use_amp = self._strategy.amp
+        self.init_loss_scaling = self._strategy.amp_configs['init_loss_scaling']
         self.micro_batch_size = self._strategy.pipeline_configs[
             'micro_batch_size']
         self.accumulate_steps = self._strategy.pipeline_configs[
@@ -69,9 +63,17 @@ class PipelineParallel(MetaParallelBase):
         self.stage_id = self._hcg.get_stage_id()
         self.prev_stage_id = self.stage_id - 1
         self.next_stage_id = self.stage_id + 1
-        self._layers = PipelineLayer(
-            layers=self._layers, num_stages=self.num_stages)
-        #TODO: init process group
+        self.pp_group = self._hcg.get_pipe_parallel_group()
+        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
+            self.num_stages, self.stage_id))
+
+        if self.use_model_parallel:
+            logger.info("start broadcast mp parameters")
+            broadcast_mp_parameters(self._layers, self._hcg)
+
+        if self.use_data_parallel:
+            logger.info("start broadcast mp parameters")
+            broadcast_dp_parameters(self._layers, self._hcg)
 
     def _allocate_caches(self, num_caches):
         if self.num_caches >= num_caches:
@@ -82,19 +84,19 @@ class PipelineParallel(MetaParallelBase):
         for key in self.caches:
             self.caches[key].extend([None] * num)
 
-    def train_batch(self, data_iter, optimizer):
+    def train_batch(self, data, optimizer):
         self.optimizer = optimizer
         assert fluid.framework._dygraph_tracer()._has_grad, (
             'Please enable the generation of gradients.')
 
         if self.stage_id == 0 or self.stage_id == self.num_stages - 1:
-            assert data_iter, (
+            assert data, (
                 "For the first and the last stage, the data_iter must be set.")
         else:
-            assert data_iter is None, (
+            assert data is None, (
                 "For pipe stages other than the first and the last one, "
                 "the data_iter must be None.")
-        self.data_iter = data_iter
+        self.data = data
         self._layers.train()
         self.total_loss = None
@@ -104,39 +106,24 @@ class PipelineParallel(MetaParallelBase):
         return self.total_loss
 
     def _train(self, minibatch_cmds):
-        self._allocate_caches(self.num_stages)
-        for microbatch_cmds in minibatch_cmds:
-            for cmd in microbatch_cmds:
-                if type(cmd) not in self._COMMAND_MAP:
-                    #FIXME:
-                    continue
+        self._allocate_caches(self.accumulate_steps)
+        for micro_cmds in minibatch_cmds:
+            for cmd in micro_cmds:
+                assert type(cmd) in self._COMMAND_MAP, "unknow cmd: {}".format(
+                    type(cmd))
                 self._apply_cmd = MethodType(self._COMMAND_MAP[type(cmd)], self)
                 self._apply_cmd(**cmd.kwargs)
 
     def _allreduce_grads(self):
-        self._modifying_grad = True
-        assert self.use_data_parallel <= 1, ("Do not support data parallel "
-                                             "with pipeline parallel now.")
-        self._modifying_grad = False
-
-    def _get_data(self):
-        if self.use_model_parallel:
-            mp_rank = self._hcg.get_model_parallel_rank()
-        else:
-            mp_rank = 0
-        data = None
-        # mp rank 0 loads the data and broadcat it to others.
-        if mp_rank == 0:
-            data = next(self.data_iter)
-        if self.use_model_parallel:
-            data = paddle.distributed.broadcast(
-                data, group=self._hcg.get_model_parallel_group())
-        return data
+        if not self.use_data_parallel: return
+        fused_allreduce_gradients(list(self._layers.parameters()), self._hcg)
 
     def _forward(self, cache_id):
+        # load data
+        self._load_micro_batch(cache_id)
+        if self.stage_id != 0:
+            self._recv_activations(cache_id)
+
         if isinstance(self.caches['inputs'][cache_id], tuple):
             inputs = tuple(t.clone() for t in self.caches['inputs'][cache_id])
         else:
@@ -144,9 +131,13 @@ class PipelineParallel(MetaParallelBase):
         self._clear_grads(inputs)
         outputs = self._layers.forward(inputs)
         self.caches['outputs'][cache_id] = outputs
+
+        if self.stage_id == self.num_stages - 1:
+            if self._layers._loss_fn is not None:
+                labels = self.caches['labels'][cache_id]
+                outputs = self._layers._loss_fn(outputs, labels)
+
         if self.stage_id == self.num_stages - 1:
             self.current_loss = outputs
             if isinstance(self.current_loss, paddle.Tensor):
@@ -160,18 +151,28 @@ class PipelineParallel(MetaParallelBase):
                 ]
                 for idx, v in enumerate(self.current_loss):
                     self.total_loss[idx] += v.detach()
+            if self.use_data_parallel:
+                self.current_loss = self.current_loss / self._hcg.get_data_parallel_world_size(
+                )
+            if self.accumulate_steps > 1:
+                self.current_loss = self.current_loss / self.accumulate_steps
+            self.caches['outputs'][cache_id] = self.current_loss.clone()
+        else:
+            self._send_activations(cache_id)
 
     def _backward(self, cache_id):
         assert self.optimizer is not None
         if self.stage_id == self.num_stages - 1:
-            paddle.autograd.backward(self.current_loss)
+            paddle.autograd.backward(self.caches['outputs'][cache_id])
+            self._send_gradients(cache_id)
             return
+        self._recv_gradients(cache_id)
         outputs = self.caches['outputs'][cache_id]
         grad_tensors = self.grad_tensors
         if isinstance(outputs, tuple):
-            out_tensors = [t for t in outputs if t.dtype in FLOAT_TYPES]
+            out_tensors = [t for t in outputs if is_float_tensor(t)]
             assert len(out_tensors) == len(grad_tensors)
             paddle.autograd.backward(
                 tensors=out_tensors, grad_tensors=grad_tensors)
@@ -179,41 +180,76 @@ class PipelineParallel(MetaParallelBase):
             paddle.autograd.backward(
                 tensors=[outputs], grad_tensors=[grad_tensors])
 
+        self.caches['outputs'][cache_id] = None
         grad_tensors = None
-        if self.stage_id != 0:
-            self.caches['outputs'][cache_id] = None
-            #self.caches['backward_tensors'][cache_id] = None
+        self._send_gradients(cache_id)
+
+    def _get_data(self):
+        if self.use_model_parallel:
+            mp_rank = self._hcg.get_model_parallel_rank()
+        else:
+            mp_rank = 0
+
+        # mp rank 0 loads the data and broadcat it to others.
+        data = self.data
+        if self.use_model_parallel and (self.stage_id == 0 or
+                                        self.stage_id == self.num_stages - 1):
+            assert isinstance(data, (tuple, paddle.Tensor))
+            if isinstance(data, paddle.Tensor):
+                paddle.distributed.broadcast(
+                    data,
+                    src=self._hcg.get_model_parallel_group_src_rank(),
+                    group=self._hcg.get_model_parallel_group())
+            else:
+                data = []
+                for d in self.data:
+                    assert isinstance(d, paddle.Tensor)
+                    paddle.distributed.broadcast(
+                        d,
+                        src=self._hcg.get_model_parallel_group_src_rank(),
+                        group=self._hcg.get_model_parallel_group())
+                    data.append(d)
+                data = tuple(data)
+        return data
 
     def _load_micro_batch(self, cache_id):
         inputs = self._get_data()
 
         if self.stage_id == 0:
             data = None
-            if isinstance(inputs[0], paddle.Tensor):
+            #if isinstance(inputs[0], paddle.Tensor):
+            if len(inputs) == 1:
+                assert isinstance(inputs[0], paddle.Tensor)
                 data = inputs[0].clone().detach()
-                data.stop_gradient = data.dtype == paddle.float32
+                #data.stop_gradient = not is_float_tensor(data)
+                data.stop_gradient = True
             else:
-                assert isinstance(inputs[0], tuple)
+                assert isinstance(inputs, tuple)
+                # Assume list or tuple
                 data = []
-                for d in inputs[0]:
+                for d in inputs:
                     assert isinstance(d, paddle.Tensor)
-                    d = d.clone().detach()
-                    d.stop_gradient = d.dtype == paddle.float32
-                    loaded.append(d)
+                    i = d.clone().detach()
+                    #i.stop_gradient = not is_float_tensor(i)
+                    i.stop_gradient = True
+                    data.append(i)
                 data = tuple(data)
             self.caches['inputs'][cache_id] = data
 
         if self.stage_id == self.num_stages - 1:
-            label = None
-            if isinstance(inputs[1], paddle.Tensor):
-                label = inputs[1]
-            elif isinstance(data[1], tuple):
-                label = []
-                for l in inputs[1]:
-                    assert isinstance(l, paddle.Tensor)
-                    l = l.detach()
-                    label.append(l)
-                label = tuple(label)
-            self.caches['labels'][cache_id] = label
+            labels = None
+            #if isinstance(inputs[1], paddle.Tensor):
+            if len(inputs) == 1:
+                assert isinstance(inputs[0], paddle.Tensor)
+                labels = inputs[0]
+            elif isinstance(inputs, tuple):
+                labels = []
+                for label in inputs:
+                    assert isinstance(label, paddle.Tensor)
+                    label = label.detach()
+                    labels.append(label)
+                labels = tuple(labels)
+            self.caches['labels'][cache_id] = labels
 
     def _send_meta(self, data, peer):
         """
@@ -225,54 +261,67 @@ class PipelineParallel(MetaParallelBase):
         """
         if isinstance(data, paddle.Tensor):
             tensor_type = paddle.to_tensor([0])
-            paddle.distributed.send(tensor_type, peer)
+            paddle.distributed.send(
+                tensor_type, peer, use_calc_stream=True, group=self.pp_group)
             dims = paddle.to_tensor(len(data.shape))
-            paddle.distributed.send(dims, peer)
+            paddle.distributed.send(
+                dims, peer, use_calc_stream=True, group=self.pp_group)
             shape = paddle.to_tensor(data.shape)
-            paddle.distributed.send(shape, peer)
+            paddle.distributed.send(
+                shape, peer, use_calc_stream=True, group=self.pp_group)
         elif isinstance(data, tuple):
             tensor_type = paddle.to_tensor([1])
-            paddle.distributed.send(tensor_type, peer)
+            paddle.distributed.send(
+                tensor_type, peer, use_calc_stream=True, group=self.pp_group)
             nums = paddle.to_tensor(len(data))
-            paddle.distributed.send(nums, peer)
+            paddle.distributed.send(
+                nums, peer, use_calc_stream=True, group=self.pp_group)
             for idx, d in enumerate(data):
                 assert isinstance(d, paddle.Tensor)
                 dims = paddle.to_tensor(len(d.shape))
-                paddle.distributed.send(dims, peer)
+                paddle.distributed.send(
+                    dims, peer, use_calc_stream=True, group=self.pp_group)
                 shape = paddle.to_tensor(d.shape)
-                paddle.distributed.send(shape, peer)
+                paddle.distributed.send(
+                    shape, peer, use_calc_stream=True, group=self.pp_group)
 
     def _recv_meta(self, peer):
         tensor_type = paddle.to_tensor([0])
-        paddle.distributed.recv(tensor_type, peer)
+        paddle.distributed.recv(
+            tensor_type, peer, use_calc_stream=True, group=self.pp_group)
         tensor_type = tensor_type.numpy()[0]
         if tensor_type == 0:
             dims = paddle.to_tensor([0])
-            paddle.distributed.recv(dims, peer)
+            paddle.distributed.recv(
+                dims, peer, use_calc_stream=True, group=self.pp_group)
             dims = dims.numpy()[0]
             shape = paddle.to_tensor([0] * dims)
-            paddle.distributed.recv(shape, peer)
+            paddle.distributed.recv(
+                shape, peer, use_calc_stream=True, group=self.pp_group)
             shape = shape.numpy().tolist()
             return self._allocate_buffer(
                 shape, dtype="float32", num_caches=1)[0]
         elif tensor_type == 1:
             num = paddle.to_tensor([0])
-            paddle.distributed.recv(num, peer)
+            paddle.distributed.recv(
+                num, peer, use_calc_stream=True, group=self.pp_group)
             num = num.numpy()[0]
             shapes = []
             for i in range(num):
                 dims = paddle.to_tensor([0])
-                paddle.distributed.recv(dims, peer)
+                paddle.distributed.recv(
+                    dims, peer, use_calc_stream=True, group=self.pp_group)
                 dims = dims.numpy()[0]
                 shape = paddle.to_tensor([0] * dims)
-                paddle.distributed.recv(shape, peer)
+                paddle.distributed.recv(
+                    shape, peer, use_calc_stream=True, group=self.pp_group)
                 shapes.append(shape.numpy().tolist())
 
             dtypes = ["float32"] * len(shapes)
-            caches = self._allocate_buffers(shapes, dtypes, num_buffers=1)[0]
-            buffers = tuple(buffers)
-            return buffers
+            caches = self._allocate_buffers(shapes, dtypes, num_caches=1)[0]
+            caches = tuple(caches)
+            return caches
     def _send_activations(self, cache_id):
         outputs = self.caches['outputs'][cache_id]
@@ -282,10 +331,18 @@ class PipelineParallel(MetaParallelBase):
             self._send_meta(outputs, self.next_stage_id)
 
         if isinstance(outputs, paddle.Tensor):
-            paddle.distributed.send(outputs, self.next_stage_id)
+            paddle.distributed.send(
+                outputs,
+                self.next_stage_id,
+                use_calc_stream=True,
+                group=self.pp_group)
         elif isinstance(outputs, tuple):
             for output in outputs:
-                paddle.distributed.send(output, self.next_stage_id)
+                paddle.distributed.send(
+                    output,
+                    self.next_stage_id,
+                    use_calc_stream=True,
+                    group=self.pp_group)
     def _send_gradients(self, cache_id):
         inputs = self.caches['inputs'][cache_id]
@@ -293,15 +350,22 @@ class PipelineParallel(MetaParallelBase):
         if isinstance(inputs, paddle.Tensor):
             assert inputs.grad is not None
             paddle.distributed.send(
-                paddle.to_tensor(inputs.grad), self.prev_stage_id)
+                paddle.to_tensor(inputs.grad),
+                self.prev_stage_id,
+                use_calc_stream=True,
+                group=self.pp_group)
         else:
             for idx, d in enumerate(inputs):
                 # Skip tensors that will not produce a grad
-                if not d.dtype in FLOAT_TYPES:
+                if not is_float_tensor(d):
                     assert d.grad is None
                     continue
                 assert d.grad is not None
-                paddle.distributed.send(d.grad, self.prev_stage_id)
+                paddle.distributed.send(
+                    d.grad,
+                    self.prev_stage_id,
+                    use_calc_stream=True,
+                    group=self.pp_group)
         self.caches['inputs'][cache_id] = None
     def _recv_activations(self, cache_id):
@@ -312,22 +376,30 @@ class PipelineParallel(MetaParallelBase):
             self.recv_cache = self._recv_meta(self.prev_stage_id)
 
         if isinstance(self.recv_cache, paddle.Tensor):
-            paddle.distributed.recv(self.recv_cache, self.prev_stage_id)
+            paddle.distributed.recv(
+                self.recv_cache,
+                self.prev_stage_id,
+                use_calc_stream=True,
+                group=self.pp_group)
             inputs = self.recv_cache.clone().detach()
-            inputs.stop_gradient = inputs.dtype not in FLOAT_TYPES
+            inputs.stop_gradient = not is_float_tensor(inputs)
         else:
             assert isinstance(self.recv_cache, tuple)
             inputs = [None] * len(self.recv_cache)
             for idx, d in enumerate(self.recv_cache):
                 assert isinstance(d, paddle.Tensor)
-                paddle.distributed.recv(d, self.prev_stage_id)
+                paddle.distributed.recv(
+                    d,
+                    self.prev_stage_id,
+                    use_calc_stream=True,
+                    group=self.pp_group)
                 inputs[idx] = d.clone().detach()
 
             inputs = tuple(inputs)
 
             for d in inputs:
-                d.stop_gradient = d.dtype not in FLOAT_TYPES
+                d.stop_gradient = not is_float_tensor(d)
 
         self.caches['inputs'][cache_id] = inputs
@@ -336,29 +408,35 @@ class PipelineParallel(MetaParallelBase):
         if self.grad_tensors is None:
             if isinstance(outputs, paddle.Tensor):
                 s = list(outputs.shape)
-                dtype = 'float32'
+                dtype = 'float16' if self.use_amp else "float32"
                 self.grad_tensors = self._allocate_buffer(
                     s, dtype, num_buffers=1)[0]
             else:
-                sizes = [
-                    list(d.shape) for d in outputs if d.dtype in FLOAT_TYPES
-                ]
-                dtypes = ['float32'] * len(sizes)
+                sizes = [list(d.shape) for d in outputs if is_float_tensor(d)]
+                dtypes = ['float16'] * len(
+                    sizes) if self.use_amp else ['float32'] * len(sizes)
                 self.grad_tensors = self._allocate_buffers(
-                    sizes, dtypes, num_buffers=1)[0]
+                    sizes, dtypes, num_caches=1)[0]
 
         if isinstance(self.grad_tensors, paddle.Tensor):
-            paddle.distributed.recv(self.grad_tensors, self.next_stage_id)
+            paddle.distributed.recv(
+                self.grad_tensors,
+                self.next_stage_id,
+                use_calc_stream=True,
+                group=self.pp_group)
         else:
             assert isinstance(outputs, tuple)
             for d in self.grad_tensors:
-                paddle.distributed.recv(d, self.next_stage_id)
+                paddle.distributed.recv(
+                    d,
+                    self.next_stage_id,
+                    use_calc_stream=True,
+                    group=self.pp_group)
 
-    def _step(self, lr_kwargs=None):
-        self._modifying_grad = True
+    def _step(self):
+        self._allreduce_grads()
         self.optimizer.step()
         self.optimizer.clear_gradients()
-        self._modifying_grad = False
     def _clear_grads(self, inputs):
         if isinstance(inputs, paddle.Tensor):
@@ -372,26 +450,24 @@ class PipelineParallel(MetaParallelBase):
     def _allocate_zeros(self, shape, dtype):
         return paddle.zeros(shape, dtype)
 
-    def _allocate_buffer(self, shape, dtype, num_buffers=-1, **kwargs):
-        buffers = []
-        if num_buffers == -1:
-            num_buffers = self.num_caches
-        for count in range(num_buffers):
-            buffers.append(self._allocate_zeros(shape, dtype))
-        return buffers
+    def _allocate_buffer(self, shape, dtype, num_caches=-1):
+        caches = []
+        if num_caches == -1:
+            num_caches = self.num_caches
+        for count in range(num_caches):
+            caches.append(self._allocate_zeros(shape, dtype))
+        return caches
 
-    def _allocate_buffers(self, shapes, dtypes, num_buffers=-1):
-        buffers = []
-        if num_buffers == -1:
-            num_buffers = self.num_caches
-        for count in range(num_buffers):
-            buffer = []
+    def _allocate_buffers(self, shapes, dtypes, num_caches=-1):
+        caches = []
+        if num_caches == -1:
+            num_caches = self.num_caches
+        for count in range(num_caches):
+            cache = []
             for shape, dtype in zip(shapes, dtypes):
-                buffer.append(
-                    self._allocate_zeros(
-                        shape, dtype, requires_grad=requires_grad))
-            buffers.append(buffer)
-        return buffers
+                cache.append(self._allocate_zeros(shape, dtype))
+            caches.append(cache)
+        return caches
     def save_state_dict(self, model_path):
         state_dict = self._layers.state_dict()
@@ -403,25 +479,9 @@ class PipelineParallel(MetaParallelBase):
     _COMMAND_MAP = {
         utils.Optimize: _step,
-        #utils.ReduceGrads: _allreduce_grads,
         utils.Forward: _forward,
         utils.Backward: _backward,
     }
 
-    def _pre_forward(self, *inputs, **kwargs):
-        pass
-
     def forward(self, *inputs, **kwargs):
         raise RuntimeError("Call train_batch for pipeline instead of forward.")
-
-    def _post_forward(self, output):
-        pass
-
-    def _pre_backward(self, loss):
-        pass
-
-    def backward_impl(self, loss, parameters):
-        pass
-
-    def _post_backward(self, loss):
-        pass
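The `_send_meta`/`_recv_meta` pair earlier in this file implements a small shape-negotiation protocol: before activations are exchanged, the sender transmits a type flag (0 for a single tensor, 1 for a tuple), then for each tensor its number of dimensions and its shape, so the receiver can pre-allocate float32 caches of matching size. Below is a standalone, single-process sketch of the same encoding and decoding logic; a plain Python list stands in for the paddle.distributed send/recv channel, so none of the helper names here are Paddle APIs.

```python
# Single-process sketch of the meta protocol used by _send_meta/_recv_meta.
# A list plays the role of the communication channel; no paddle calls involved.

def send_meta(data_shapes, channel):
    """data_shapes: one shape (list of ints) or a tuple of shapes."""
    if isinstance(data_shapes, tuple):
        channel.append(1)                 # type flag: tuple of tensors
        channel.append(len(data_shapes))  # number of tensors
        shapes = data_shapes
    else:
        channel.append(0)                 # type flag: single tensor
        shapes = (data_shapes, )
    for shape in shapes:
        channel.append(len(shape))        # rank (number of dims)
        channel.extend(shape)             # the dims themselves

def recv_meta(channel):
    tensor_type = channel.pop(0)
    num = channel.pop(0) if tensor_type == 1 else 1
    shapes = []
    for _ in range(num):
        dims = channel.pop(0)
        shapes.append([channel.pop(0) for _ in range(dims)])
    return shapes[0] if tensor_type == 0 else tuple(shapes)

channel = []
send_meta(([8, 1024], [8, 1]), channel)
print(recv_meta(channel))   # ([8, 1024], [8, 1])
```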
@@ -16,7 +16,21 @@ import abc
 import paddle
 from ...utils import hybrid_parallel_util as hp_util
 
-__all__ = ['get_tensor_bytes', ]
+__all__ = [
+    'get_tensor_bytes',
+    'is_float_tensor',
+]
+
+FLOAT_TYPES = [
+    paddle.float16,
+    paddle.float32,
+    paddle.float64,
+]
+
+
+def is_float_tensor(tensor):
+    """Is a float tensor"""
+    return tensor.dtype in FLOAT_TYPES
 
 
 def get_tensor_bytes(tensor):
@@ -48,10 +62,6 @@ class Generator():
         self.stage_id = stage_id
         self.prev_stage = self.stage_id - 1
         self.next_stage = self.stage_id + 1
-        assert self.micro_batches >= self.stages, (
-            "micro_batches {} "
-            "must be greater than or equal to {}".format(self.micro_batches,
-                                                         self.stages))
 
     @abc.abstractmethod
     def generate(self):
@@ -73,18 +83,25 @@ class TrainGenerator(Generator):
         cmds = []
         forward_steps = 0
         backward_steps = 0
-        while (forward_steps < startup_steps):
-            cmds.append(Forward)
-            forward_steps += 1
+        #while (forward_steps < startup_steps):
+        #    cmds.append(Forward(cache_id=forward_steps))
+        #    forward_steps += 1
+        #while (forward_steps < self.micro_batches):
+        #    cmds.append(Forward(cache_id=forward_steps))
+        #    forward_steps += 1
+        #    cmds.append(Backward(cache_id=backward_steps))
+        #    backward_steps += 1
+        #while (backward_steps < self.micro_batches):
+        #    cmds.append(Backward(cache_id=backward_steps))
+        #    backward_steps += 1
+        #cmds.append(Optimize())
         while (forward_steps < self.micro_batches):
-            cmds.append(Forward)
+            cmds.append(Forward(cache_id=forward_steps))
             forward_steps += 1
-            cmds.append(Backward)
-            backward_steps += 1
         while (backward_steps < self.micro_batches):
-            cmds.append(Backward)
+            cmds.append(Backward(cache_id=backward_steps))
             backward_steps += 1
-        cmds.append(Optimize)
+        cmds.append(Optimize())
         yield cmds
......
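The rewritten `TrainGenerator.generate` above drops the interleaved warm-up schedule (left commented out) and instead emits all forward commands, then all backward commands, then a single optimize step, each command now carrying its `cache_id`. A standalone sketch of the resulting order, using stand-in command classes rather than the real `Forward`/`Backward`/`Optimize` from `pp_utils`, for `micro_batches=3`:

```python
# Stand-in command classes; the real ones live in the pp_utils module above.
class Forward:
    def __init__(self, cache_id):
        self.kwargs = {"cache_id": cache_id}

class Backward:
    def __init__(self, cache_id):
        self.kwargs = {"cache_id": cache_id}

class Optimize:
    def __init__(self):
        self.kwargs = {}

def schedule(micro_batches):
    """Mirror of the new generate(): all forwards, then all backwards, then optimize."""
    cmds = []
    for step in range(micro_batches):
        cmds.append(Forward(cache_id=step))
    for step in range(micro_batches):
        cmds.append(Backward(cache_id=step))
    cmds.append(Optimize())
    return cmds

print([type(c).__name__ for c in schedule(3)])
# ['Forward', 'Forward', 'Forward', 'Backward', 'Backward', 'Backward', 'Optimize']
```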