diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 04525977192be2ae4f56223807d419b5caa45b0a..04d8417fdcbf3f1ef23db09caf1cc417672b8358 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -253,3 +253,8 @@ class HybridCommunicateGroup(object): # check parallel group def get_check_parallel_group(self): return self._check_comm_group + + def get_rank_from_stage(self, stage_id): + coord = self._topo.get_coord(self.global_rank) + tf = coord._replace(pipe=stage_id)._asdict() + return self._topo.get_rank(**tf) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index 00ac019c0d18829e4603c1ffc54efb53b9dda809..c2d79a62c7663a01d5cd1e7ca9ac705612e1db03 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -89,12 +89,14 @@ class HybridParallelOptimizer: self._inner_opt = optimizer self._strategy = strategy self._hcg = hcg - self._is_mp = ( - self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL) + + self._use_dp_mode = ( + self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL) + self._need_dp = (self._hcg.get_data_parallel_world_size() > 1) if isinstance(self._inner_opt._grad_clip, - ClipGradByGlobalNorm) and self._is_mp: + ClipGradByGlobalNorm) and not self._use_dp_mode: logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \ "optmizer'grad clip will be changed.") self._inner_opt._grad_clip = HybridParallelClipGrad( @@ -103,7 +105,7 @@ class HybridParallelOptimizer: @imperative_base.no_grad @framework.dygraph_only def step(self): - if self._is_mp and self._need_dp: + if not self._use_dp_mode and self._need_dp: fused_allreduce_gradients( list(self._inner_opt._parameter_list), self._hcg) self._inner_opt.step() @@ -119,7 +121,7 @@ class HybridParallelOptimizer: parameter_list = parameters if parameters \ else self._parameter_list - if self._is_mp and self._need_dp: + if not self._use_dp_mode and self._need_dp: fused_allreduce_gradients(list(parameter_list), self._hcg) return self._inner_opt.minimize(loss, startup_program, parameters, diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 79e5bc2ffeda06d62b24aec2e10ae3ad071d856a..54324b389336d0f423a79e177a04a0381d2779b9 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -11,39 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -import time -import copy -import os - from types import MethodType -from numpy import prod - import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import get_tensor_bytes, is_float_tensor +from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype from .pp_utils import utils from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters from ..utils.hybrid_parallel_util import broadcast_dp_parameters -from ..utils.hybrid_parallel_util import fused_allreduce_gradients from ..utils.log_util import logger +from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer __all__ = [] -FLOAT_TYPES = [ - paddle.float16, - paddle.float32, - paddle.float64, -] - class PipelineParallel(MetaParallelBase): def __init__(self, layers, hcg, strategy): + if not isinstance(layers, PipelineLayer): + raise TypeError( + "The Layer should be a derived class of PipelineLayer.") super(PipelineParallel, self).__init__(layers, hcg, strategy) - self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1 self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1 self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1 @@ -63,8 +53,6 @@ class PipelineParallel(MetaParallelBase): self.current_loss = paddle.to_tensor(0.0) self.total_loss = None - self.use_amp = self._strategy.amp - self.init_loss_scaling = self._strategy.amp_configs['init_loss_scaling'] self.micro_batch_size = self._strategy.pipeline_configs[ 'micro_batch_size'] self.accumulate_steps = self._strategy.pipeline_configs[ @@ -75,6 +63,11 @@ class PipelineParallel(MetaParallelBase): self.prev_stage_id = self.stage_id - 1 self.next_stage_id = self.stage_id + 1 self.pp_group = self._hcg.get_pipe_parallel_group() + + self.is_first_stage = self.stage_id == 0 + self.is_last_stage = (self.stage_id == (self.num_stages - 1)) + self.global_rank = self._hcg.get_global_rank() + logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format( self.num_stages, self.stage_id)) @@ -83,51 +76,72 @@ class PipelineParallel(MetaParallelBase): broadcast_mp_parameters(self._layers, self._hcg) if self.use_data_parallel: - logger.info("start broadcast mp parameters") + logger.info("start broadcast dp parameters") broadcast_dp_parameters(self._layers, self._hcg) - def _allocate_caches(self, num_caches): + def _init_caches(self, num_caches): if self.num_caches >= num_caches: return - - num = num_caches - self.num_caches - self.num_caches = num_caches + self.num_caches = num_caches - self.num_caches for key in self.caches: - self.caches[key].extend([None] * num) + self.caches[key].extend([None] * self.num_caches) + + def _reduce_final_loss(self): + if self.is_last_stage: + assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" + loss = self.total_loss.clone() / self.accumulate_steps + paddle.distributed.broadcast( + loss, + src=self.global_rank, + use_calc_stream=True, + group=self.pp_group) + else: + loss = paddle.to_tensor(0.0) + paddle.distributed.broadcast( + loss, + src=self._hcg.get_rank_from_stage(self.num_stages - 1), + use_calc_stream=True, + group=self.pp_group) + return loss - def train_batch(self, data, optimizer): + def train_batch(self, data, optimizer, lr_scheduler=None): + assert isinstance(optimizer, HybridParallelOptimizer), ( + 'optimizer should be HybridParallelOptimizer subclass.') self.optimizer = optimizer + self.lr_scheduler = lr_scheduler assert fluid.framework._dygraph_tracer()._has_grad, ( 'Please enable the generation of gradients.') - if self.stage_id == 0 or self.stage_id == self.num_stages - 1: - assert data, ( + if self.is_first_stage or self.is_last_stage: + assert data is not None, ( "For the first and the last stage, the data_iter must be set.") else: - assert data is None, ( - "For pipe stages other than the first and the last one, " - "the data_iter must be None.") + data = None + self.data = data self._layers.train() - self.total_loss = None - - minibatch_cmds = utils.TrainGenerator(self.accumulate_steps, - self.num_stages, self.stage_id) - self._train(minibatch_cmds) - return self.total_loss - def _train(self, minibatch_cmds): - self._allocate_caches(self.accumulate_steps) - for micro_cmds in minibatch_cmds: - for cmd in micro_cmds: - assert type(cmd) in self._COMMAND_MAP, "unknow cmd: {}".format( - type(cmd)) - self._apply_cmd = MethodType(self._COMMAND_MAP[type(cmd)], self) - self._apply_cmd(**cmd.kwargs) - - def _allreduce_grads(self): - if not self.use_data_parallel: return - fused_allreduce_gradients(list(self._layers.parameters()), self._hcg) + # store total loss of entire batch + self.total_loss = None + self._init_caches(self.accumulate_steps) + startup_steps = self.num_stages - self.stage_id - 1 + forward_steps = 0 + backward_steps = 0 + + # forward + while (forward_steps < self.accumulate_steps): + self._forward(cache_id=forward_steps) + forward_steps += 1 + + # backward + while (backward_steps < self.accumulate_steps): + self._backward(cache_id=backward_steps) + backward_steps += 1 + + # optimizer + self._step() + self.train_loss = self._reduce_final_loss() + return self.train_loss def _forward(self, cache_id): # load data @@ -140,16 +154,17 @@ class PipelineParallel(MetaParallelBase): else: inputs = self.caches['inputs'][cache_id] - self._clear_grads(inputs) outputs = self._layers.forward(inputs) + self._clear_grads(inputs) + self.caches['outputs'][cache_id] = outputs - if self.stage_id == self.num_stages - 1: + if self.is_last_stage: if self._layers._loss_fn is not None: labels = self.caches['labels'][cache_id] outputs = self._layers._loss_fn(outputs, labels) - if self.stage_id == self.num_stages - 1: + if self.is_last_stage: self.current_loss = outputs if isinstance(self.current_loss, paddle.Tensor): if self.total_loss is None: @@ -162,18 +177,17 @@ class PipelineParallel(MetaParallelBase): ] for idx, v in enumerate(self.current_loss): self.total_loss[idx] += v.detach() - if self.use_data_parallel: - self.current_loss = self.current_loss / self._hcg.get_data_parallel_world_size( - ) + if self.accumulate_steps > 1: self.current_loss = self.current_loss / self.accumulate_steps + self.caches['outputs'][cache_id] = self.current_loss.clone() + else: self._send_activations(cache_id) def _backward(self, cache_id): - assert self.optimizer is not None - if self.stage_id == self.num_stages - 1: + if self.is_last_stage: paddle.autograd.backward(self.caches['outputs'][cache_id]) self._send_gradients(cache_id) return @@ -194,92 +208,89 @@ class PipelineParallel(MetaParallelBase): grad_tensors = None if self.stage_id != 0: self._send_gradients(cache_id) self.caches['outputs'][cache_id] = None - #self.caches['backward_tensors'][cache_id] = None - def _get_data(self): - if self.use_model_parallel: - mp_rank = self._hcg.get_model_parallel_rank() + def _broadcast_data(self, data): + if isinstance(data, paddle.Tensor): + paddle.distributed.broadcast( + data, + src=self._hcg.get_model_parallel_group_src_rank(), + group=self._hcg.get_model_parallel_group()) else: - mp_rank = 0 - - # mp rank 0 loads the data and broadcat it to others. - data = self.data - if self.use_model_parallel and (self.stage_id == 0 or - self.stage_id == self.num_stages - 1): - assert isinstance(data, (tuple, paddle.Tensor)) - if isinstance(data, paddle.Tensor): + for d in data: + assert isinstance(d, paddle.Tensor) paddle.distributed.broadcast( - data, + d, src=self._hcg.get_model_parallel_group_src_rank(), group=self._hcg.get_model_parallel_group()) - else: - data = [] - for d in self.data: - assert isinstance(d, paddle.Tensor) - paddle.distributed.broadcast( - d, - src=self._hcg.get_model_parallel_group_src_rank(), - group=self._hcg.get_model_parallel_group()) - data.append(d) - data = tuple(data) return data def _load_micro_batch(self, cache_id): - inputs = self._get_data() - - if self.stage_id == 0: - data = None - #if isinstance(inputs[0], paddle.Tensor): - if len(inputs) == 1: - assert isinstance(inputs[0], paddle.Tensor) - data = inputs[0].clone().detach() - #data.stop_gradient = not is_float_tensor(data) - data.stop_gradient = True + inputs = self.data + begin = cache_id * self.micro_batch_size + end = begin + self.micro_batch_size + + if self.is_first_stage: + assert len(inputs) == 2, "length of input should be 2" + if self.use_model_parallel: + inputs[0] = self._broadcast_data(inputs[0]) + if isinstance(inputs[0], tuple): + batch_size = inputs[0][0].shape[0] + assert self.micro_batch_size * self.accumulate_steps == batch_size, ( + "batch_size needs to be divisible by micro_batch_size. Currently, " + "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d." + % + (batch_size, self.micro_batch_size, self.accumulate_steps)) + data = [ + input[begin:end, :].clone().detach() for input in inputs[0] + ] + self.caches['inputs'][cache_id] = tuple(data) + else: + batch_size = inputs[0].shape[0] + assert self.micro_batch_size * self.accumulate_steps == batch_size + self.caches['inputs'][cache_id] = inputs[0][begin:end, :].clone( + ).detach() + elif self.is_last_stage: + assert len(inputs) == 2, "length of input should be 2" + if self.use_model_parallel: + inputs[1] = self._broadcast_data(inputs[1]) + if isinstance(inputs[1], tuple): + batch_size = inputs[1][0].shape[0] + assert self.micro_batch_size * self.accumulate_steps == batch_size + data = [ + input[begin:end, :].clone().detach() for input in inputs[1] + ] + self.caches['labels'][cache_id] = tuple(data) else: - assert isinstance(inputs, tuple) - data = [] - for d in inputs: - assert isinstance(d, paddle.Tensor) - i = d.clone().detach() - #i.stop_gradient = not is_float_tensor(i) - i.stop_gradient = True - data.append(i) - data = tuple(data) - self.caches['inputs'][cache_id] = data - - if self.stage_id == self.num_stages - 1: - labels = None - #if isinstance(inputs[1], paddle.Tensor): - if len(inputs) == 1: - assert isinstance(inputs[0], paddle.Tensor) - labels = inputs[0] - elif isinstance(inputs, tuple): - labels = [] - for label in inputs: - assert isinstance(label, paddle.Tensor) - label = label.detach() - labels.append(label) - labels = tuple(labels) - self.caches['labels'][cache_id] = labels + batch_size = inputs[1].shape[0] + assert self.micro_batch_size * self.accumulate_steps == batch_size + self.caches['labels'][cache_id] = inputs[1][begin:end, :].clone( + ).detach() + else: + # No data input is required for other stages + inputs = None def _send_meta(self, data, peer): - """ - % type (0: tensor, 1: tuple) - % num_tensors if type=tuple - foreach tensor: - % ndims - % shape - """ if isinstance(data, paddle.Tensor): tensor_type = paddle.to_tensor([0]) + # send tensor type paddle.distributed.send( tensor_type, peer, use_calc_stream=True, group=self.pp_group) + + # send len(shape) dims = paddle.to_tensor(len(data.shape)) paddle.distributed.send( dims, peer, use_calc_stream=True, group=self.pp_group) + + # send shape shape = paddle.to_tensor(data.shape) paddle.distributed.send( shape, peer, use_calc_stream=True, group=self.pp_group) + + # send dtype + dtype = paddle.to_tensor(paddle_2_number(data.dtype)) + paddle.distributed.send( + dtype, peer, use_calc_stream=True, group=self.pp_group) + elif isinstance(data, tuple): tensor_type = paddle.to_tensor([1]) paddle.distributed.send( @@ -289,48 +300,73 @@ class PipelineParallel(MetaParallelBase): nums, peer, use_calc_stream=True, group=self.pp_group) for idx, d in enumerate(data): assert isinstance(d, paddle.Tensor) + # send len(shape) dims = paddle.to_tensor(len(d.shape)) paddle.distributed.send( dims, peer, use_calc_stream=True, group=self.pp_group) + + # send shape shape = paddle.to_tensor(d.shape) paddle.distributed.send( shape, peer, use_calc_stream=True, group=self.pp_group) + # send dtype + dtype = paddle.to_tensor(paddle_2_number(d.dtype)) + paddle.distributed.send( + dtype, peer, use_calc_stream=True, group=self.pp_group) + def _recv_meta(self, peer): tensor_type = paddle.to_tensor([0]) paddle.distributed.recv( tensor_type, peer, use_calc_stream=True, group=self.pp_group) - tensor_type = tensor_type.numpy()[0] + tensor_type = tensor_type.item() if tensor_type == 0: + # recv len(shape) dims = paddle.to_tensor([0]) paddle.distributed.recv( dims, peer, use_calc_stream=True, group=self.pp_group) - dims = dims.numpy()[0] + dims = dims.item() + + # recv shape shape = paddle.to_tensor([0] * dims) paddle.distributed.recv( shape, peer, use_calc_stream=True, group=self.pp_group) shape = shape.numpy().tolist() - return self._allocate_buffer( - shape, dtype="float32", num_caches=1)[0] + + # recv dtype + dtype = paddle.to_tensor([0]) + paddle.distributed.recv( + dtype, peer, use_calc_stream=True, group=self.pp_group) + return self._allocate_cache( + shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0] elif tensor_type == 1: num = paddle.to_tensor([0]) paddle.distributed.recv( num, peer, use_calc_stream=True, group=self.pp_group) - num = num.numpy()[0] + num = num.item() shapes = [] + dtypes = [] for i in range(num): + # recv len(shape) dims = paddle.to_tensor([0]) paddle.distributed.recv( dims, peer, use_calc_stream=True, group=self.pp_group) - dims = dims.numpy()[0] + + # recv shape + dims = dims.item() shape = paddle.to_tensor([0] * dims) paddle.distributed.recv( shape, peer, use_calc_stream=True, group=self.pp_group) shapes.append(shape.numpy().tolist()) - dtypes = ["float32"] * len(shapes) - caches = self._allocate_buffers(shapes, dtypes, num_caches=1)[0] + # recv dtype + dtype = paddle.to_tensor([0]) + paddle.distributed.recv( + dtype, peer, use_calc_stream=True, group=self.pp_group) + dtypes.append(number_2_dtype(dtype.item())) + + caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0] caches = tuple(caches) return caches @@ -357,7 +393,6 @@ class PipelineParallel(MetaParallelBase): def _send_gradients(self, cache_id): inputs = self.caches['inputs'][cache_id] - if isinstance(inputs, paddle.Tensor): assert inputs.grad is not None paddle.distributed.send( @@ -371,7 +406,6 @@ class PipelineParallel(MetaParallelBase): if not is_float_tensor(d): assert d.grad is None continue - assert d.grad is not None paddle.distributed.send( d.grad, self.prev_stage_id, @@ -381,8 +415,6 @@ class PipelineParallel(MetaParallelBase): def _recv_activations(self, cache_id): inputs = None - - # Allocate the buffer if necessary if self.recv_cache is None: self.recv_cache = self._recv_meta(self.prev_stage_id) @@ -419,14 +451,16 @@ class PipelineParallel(MetaParallelBase): if self.grad_tensors is None: if isinstance(outputs, paddle.Tensor): s = list(outputs.shape) - dtype = 'float16' if self.use_amp else "float32" - self.grad_tensors = self._allocate_buffer( - s, dtype, num_buffers=1)[0] + dtype = get_tensor_dtype(outputs.dtype) + self.grad_tensors = self._allocate_cache( + s, dtype, num_caches=1)[0] else: sizes = [list(d.shape) for d in outputs if is_float_tensor(d)] - dtypes = ['float16'] * len( - sizes) if self.use_amp else ['float32'] * len(sizes) - self.grad_tensors = self._allocate_buffers( + dtypes = [ + get_tensor_dtype(d.dtype) for d in outputs + if is_float_tensor(d) + ] + self.grad_tensors = self._allocate_caches( sizes, dtypes, num_caches=1)[0] if isinstance(self.grad_tensors, paddle.Tensor): @@ -445,9 +479,10 @@ class PipelineParallel(MetaParallelBase): group=self.pp_group) def _step(self): - self._allreduce_grads() self.optimizer.step() - self.optimizer.clear_gradients() + self.optimizer.clear_grad() + if self.lr_scheduler: + self.lr_scheduler.step() def _clear_grads(self, inputs): if isinstance(inputs, paddle.Tensor): @@ -461,7 +496,7 @@ class PipelineParallel(MetaParallelBase): def _allocate_zeros(self, shape, dtype): return paddle.zeros(shape, dtype) - def _allocate_buffer(self, shape, dtype, num_caches=-1): + def _allocate_cache(self, shape, dtype, num_caches=-1): caches = [] if num_caches == -1: num_caches = self.num_caches @@ -469,7 +504,7 @@ class PipelineParallel(MetaParallelBase): caches.append(self._allocate_zeros(shape, dtype)) return caches - def _allocate_buffers(self, shapes, dtypes, num_caches=-1): + def _allocate_caches(self, shapes, dtypes, num_caches=-1): caches = [] if num_caches == -1: num_caches = self.num_caches @@ -488,11 +523,5 @@ class PipelineParallel(MetaParallelBase): state_dict = paddle.load(self.model_path) self._layers.set_state_dict(state_dict) - _COMMAND_MAP = { - utils.Optimize: _step, - utils.Forward: _forward, - utils.Backward: _backward, - } - def forward(self, *inputs, **kwargs): raise RuntimeError("Call train_batch for pipeline instead of forward.") diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index e5c5709f98d9577d742ee7eabac259459eef79b1..8c204820b16615db10968928da2c7c1867b0e6bf 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -14,20 +14,51 @@ import abc import paddle -from ...utils import hybrid_parallel_util as hp_util +from ...utils import log_util as hp_util __all__ = [] -FLOAT_TYPES = [ - paddle.float16, - paddle.float32, - paddle.float64, -] +FLOAT_TYPE_DICT = { + paddle.float16: "float16", + paddle.float32: "float32", + paddle.float64: "float64", +} + +PADDLE_TO_NUMBER = { + paddle.float16: 0, + paddle.float32: 1, + paddle.float64: 2, + paddle.int32: 3, + paddle.int64: 4 +} + +NUMBER_TO_DTYPE = { + 0: "float16", + 1: "float32", + 2: "float64", + 3: "int32", + 4: "int64" +} def is_float_tensor(tensor): """Is a float tensor""" - return tensor.dtype in FLOAT_TYPES + return tensor.dtype in FLOAT_TYPE_DICT.keys() + + +def get_tensor_dtype(dtype): + assert dtype in FLOAT_TYPE_DICT.keys() + return FLOAT_TYPE_DICT[dtype] + + +def paddle_2_number(dtype): + assert dtype in PADDLE_TO_NUMBER.keys() + return PADDLE_TO_NUMBER[dtype] + + +def number_2_dtype(number): + assert number in NUMBER_TO_DTYPE.keys() + return NUMBER_TO_DTYPE[number] def get_tensor_bytes(tensor): @@ -48,78 +79,3 @@ def get_tensor_bytes(tensor): else: raise ValueError("unknown data type: {}".format(tensor.dtype)) return tensor.numel() * elem_size - - -class Generator(): - def __init__(self, micro_batches, stages, stage_id): - __metaclass__ = abc.ABCMeta - - self.micro_batches = micro_batches - self.stages = stages - self.stage_id = stage_id - self.prev_stage = self.stage_id - 1 - self.next_stage = self.stage_id + 1 - - @abc.abstractmethod - def generate(self): - pass - - def __iter__(self): - self.iter = None - return self - - def __next__(self): - if self.iter is None: - self.iter = self.generate() - return next(self.iter) - - -class TrainGenerator(Generator): - def generate(self): - startup_steps = self.stages - self.stage_id - 1 - cmds = [] - forward_steps = 0 - backward_steps = 0 - #while (forward_steps < startup_steps): - # cmds.append(Forward(cache_id=forward_steps)) - # forward_steps += 1 - #while (forward_steps < self.micro_batches): - # cmds.append(Forward(cache_id=forward_steps)) - # forward_steps += 1 - # cmds.append(Backward(cache_id=backward_steps)) - # backward_steps += 1 - #while (backward_steps < self.micro_batches): - # cmds.append(Backward(cache_id=backward_steps)) - # backward_steps += 1 - #cmds.append(Optimize()) - while (forward_steps < self.micro_batches): - cmds.append(Forward(cache_id=forward_steps)) - forward_steps += 1 - while (backward_steps < self.micro_batches): - cmds.append(Backward(cache_id=backward_steps)) - backward_steps += 1 - cmds.append(Optimize()) - yield cmds - - -class Command: - def __init__(self, **kwargs): - self.name = self.__class__.__name__ - self.kwargs = kwargs - for key, val in kwargs.items(): - setattr(self, key, val) - - def __repr__(self): - return hp_util.call_to_str(self.name, **self.kwargs) - - -class Optimize(Command): - pass - - -class Forward(Command): - pass - - -class Backward(Command): - pass diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 4de369fc1ca80bbb3d8affc30308574128203009..c4a256f0e193d750516baffb6184e273ce1ba246 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -23,7 +23,8 @@ list(APPEND DIST_TEST_OPS test_gen_nccl_id_op) list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables) list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow) list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel) -list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_layer) +list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel) +list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests. @@ -179,7 +180,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel) - list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_layer) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single) @@ -558,7 +560,7 @@ if(WITH_DISTRIBUTE) set(dist_ut_port 20001) foreach(TEST_OP ${DIST_TEST_OPS}) bash_test_modules(${TEST_OP} START_BASH dist_test.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}") - MATH(EXPR dist_ut_port "${dist_ut_port}+40") + MATH(EXPR dist_ut_port "${dist_ut_port}+35") if(dist_ut_port GREATER_EQUAL 22998) message(FATAL_ERROR "available ports have been exhausted:${dist_ut_port}") endif() @@ -866,7 +868,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120) - set_tests_properties(test_parallel_dygraph_pipeline_layer PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_model.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_model.py index 767bf5d57e74aff64d13170267785c6a8ed4347b..a9f251f3079cef2860c2599f6a8d33abf8da5fb8 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_model.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_model.py @@ -37,6 +37,7 @@ hidden_size = 10 inner_size = 8 output_size = 2 seq_length = 2 +batch_size = 4 class SimpleMPNet(fluid.dygraph.Layer): @@ -130,18 +131,6 @@ class SimpleDPNet(fluid.dygraph.Layer): return x -class TrainDataset(Dataset): - def __init__(self, length): - self.length = length - - def __len__(self): - return self.length - - def __getitem__(self, index): - np_input_data = np.random.randint(0, vocab_size, (seq_length, )) - return np_input_data - - class TestDistMPTraning(unittest.TestCase): def setUp(self): strategy = fleet.DistributedStrategy() @@ -178,20 +167,6 @@ class TestDistMPTraning(unittest.TestCase): np_fc1 = np.random.random_sample((hidden_size, inner_size)) np_fc2 = np.random.random_sample((inner_size, hidden_size)) - train_data = TrainDataset(length=10000) - - train_batch_sampler = paddle.io.DistributedBatchSampler( - train_data, - batch_size=4, - shuffle=False, - num_replicas=self.data_parallel_size, - rank=dp_id) - train_data_loader = DataLoader( - dataset=train_data, - batch_sampler=train_batch_sampler, - num_workers=0, - return_list=True) - model_a = SimpleMPNet(vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2, mp_id) optimizer_a = self.build_optimizer(model_a) @@ -202,16 +177,17 @@ class TestDistMPTraning(unittest.TestCase): np_fc1, np_fc2) optimizer_b = self.build_optimizer(model_b) - return model_a, optimizer_a, model_b, optimizer_b, train_data_loader + return model_a, optimizer_a, model_b, optimizer_b def test_mp_model(self): - model_a, optimizer_a, model_b, optimizer_b, train_data_loader = self.build_model_optimizer( + model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( ) - for step, batch in enumerate(train_data_loader): - if step > 5: - return - + for _ in range(5): + np_data = np.random.randint(0, vocab_size, ( + batch_size, + seq_length, )) + batch = paddle.to_tensor(np_data) loss_a = self.train_batch(batch, model_a, optimizer_a, True) loss_b = self.train_batch(batch, model_b, optimizer_b, False) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..14d7e960f4a68cfcc7101b353eb379fee0739074 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 4 +micro_batch_size = 2 + + +class TestDistPPTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + set_random_seed(1024, dp_id, rank_id) + + #construct model a + model_a = AlexNet(10) + scheduler_a = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a, + parameters=model_a.parameters()) + + param_len = len(model_a.parameters()) + + parameters = [] + for param in model_a.parameters(): + parameters.append(param.numpy()) + + # construct model b + model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size) + scheduler_b = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b, + parameters=model_b.parameters()) + model_b = fleet.distributed_model(model_b) + optimizer_b = fleet.distributed_optimizer(optimizer_b) + + for idx, param in enumerate(model_b.parameters()): + param.set_value(parameters[idx + pp_id * (param_len // 2)]) + + # construct reader + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True) + + for step_id, data in enumerate(train_reader()): + x_data = np.array([x[0] for x in data]).astype('float32').reshape( + batch_size, 1, 28, 28) + y_data = np.array([x[1] for x in data]).astype('int64').reshape( + batch_size, 1) + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + img.stop_gradient = True + label.stop_gradient = True + + if step_id >= 5: + return True + + loss_a = model_a(img, label) + loss_a.backward() + optimizer_a.step() + optimizer_a.clear_grad() + scheduler_a.step() + + loss_b = model_b.train_batch([img, label], optimizer_b, scheduler_b) + + print("loss: ", loss_a.numpy(), loss_b.numpy()) + np.testing.assert_allclose( + loss_a.numpy(), loss_b.numpy(), rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_embedding.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d2be0cb80722b4c540b10f05027de4514dbb5a4f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_embedding.py @@ -0,0 +1,208 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from paddle.fluid.dygraph.container import Sequential +from paddle.distributed.fleet.meta_parallel import PipelineLayer +from paddle.fluid.dygraph.layers import Layer +import paddle.nn as nn +import paddle.fluid as fluid + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 16 +micro_batch_size = 4 +vocab_size = 128 +hidden_size = 8 + + +class SimpleNet(Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + + self.softmax_weight = self.create_parameter( + shape=[hidden_size, vocab_size]) + self.softmax_bias = self.create_parameter( + shape=[vocab_size], is_bias=False) + + def forward(self, x1, x2, y1): + x_emb = self.word_embeddings(x1) + fc = fluid.layers.matmul(x_emb, self.softmax_weight) + fc = fluid.layers.elementwise_add(fc, self.softmax_bias) + projection = fluid.layers.reshape(fc, shape=[-1, vocab_size]) + loss = fluid.layers.softmax_with_cross_entropy( + logits=projection, label=y1, soft_label=False) + return loss.mean() + + +class EmbeddingNet(Layer): + def __init__(self): + super(EmbeddingNet, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + + def forward(self, args): + x1, x2 = args + x_emb = self.word_embeddings(x1) + return x_emb, x2 + + +class MatmulNet(Layer): + def __init__(self): + super(MatmulNet, self).__init__() + self.softmax_weight = self.create_parameter( + shape=[hidden_size, vocab_size]) + + def forward(self, args): + x1, x2 = args + fc = fluid.layers.matmul(x1, self.softmax_weight) + + return fc, x2 + + +class BiasNet(Layer): + def __init__(self): + super(BiasNet, self).__init__() + self.softmax_bias = self.create_parameter(shape=[vocab_size]) + + def forward(self, args): + fc, x2 = args + fc = fluid.layers.elementwise_add(fc, self.softmax_bias) + projection = fluid.layers.reshape(fc, shape=[-1, vocab_size]) + return projection, x2 + + +class LossNet(Layer): + def __init__(self): + super(LossNet, self).__init__() + + def forward(self, args, y1): + projection, x2 = args + loss = fluid.layers.softmax_with_cross_entropy( + logits=projection, label=y1[0], soft_label=False) + return loss.mean() + + +class SimpleNetPipe(Layer): + def __init__(self): + super(SimpleNetPipe, self).__init__() + self.features = Sequential(EmbeddingNet(), MatmulNet(), BiasNet()) + + def to_layers(self): + feat = [self.features[i] for i in range(len(self.features))] + return feat + + +class TestDistEmbeddingTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + set_random_seed(1024, dp_id, rank_id) + + #construct model a + model_a = SimpleNet() + scheduler_a = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04], verbose=True) + optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a, + parameters=model_a.parameters()) + + init_net = SimpleNetPipe() + model_b = PipelineLayer( + layers=init_net.to_layers(), + num_stages=self.pipeline_parallel_size, + loss_fn=LossNet()) + + scheduler_b = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04], verbose=True) + optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b, + parameters=model_b.parameters()) + model_b = fleet.distributed_model(model_b) + optimizer_b = fleet.distributed_optimizer(optimizer_b) + + param_len = len(model_a.parameters()) + + parameters = [] + for param in model_a.parameters(): + print(param.name, param.shape) + parameters.append(param.numpy()) + + model_b_params = model_b.parameters() + if pp_id == 0: + model_b_params[0].set_value(parameters[2]) + else: + model_b_params[0].set_value(parameters[0]) + model_b_params[1].set_value(parameters[1]) + + for step in range(5): + x1_data = np.random.randint(0, vocab_size, size=[batch_size, 1]) + x2_data = np.random.randint(0, vocab_size, size=[batch_size, 1]) + y1_data = np.random.randint(0, 10, size=[batch_size, 1]) + + x1 = paddle.to_tensor(x1_data) + x2 = paddle.to_tensor(x2_data) + y1 = paddle.to_tensor(y1_data) + + x1.stop_gradient = True + x2.stop_gradient = True + y1.stop_gradient = True + + loss_a = model_a(x1, x2, y1) + loss_a.backward() + optimizer_a.step() + optimizer_a.clear_grad() + scheduler_a.step() + + loss_b = model_b.train_batch([(x1, x2), (y1, )], optimizer_b, + scheduler_b) + + print("loss", loss_a.numpy(), loss_b.numpy()) + np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py index 3130cbf458467acfc70d38a438aa845c40584469..b30df0e9a2f21ba6d9dea0624d3129eae9b32d74 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py @@ -12,17 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest import numpy as np import os import paddle from paddle.distributed import fleet -import copy from paddle.fluid.dygraph.container import Sequential import paddle.nn as nn from paddle.fluid.dygraph.layers import Layer from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer import paddle.nn.functional as F -import unittest + + +class ReshapeHelp(Layer): + def __init__(self, shape): + super(ReshapeHelp, self).__init__() + self.shape = shape + + def forward(self, x): + return x.reshape(shape=self.shape) class AlexNet(Layer): @@ -30,7 +38,7 @@ class AlexNet(Layer): super(AlexNet, self).__init__() self.features = Sequential( nn.Conv2D( - 3, 64, kernel_size=11, stride=4, padding=5), + 1, 64, kernel_size=11, stride=4, padding=5), nn.ReLU(), nn.MaxPool2D( kernel_size=2, stride=2), @@ -50,13 +58,14 @@ class AlexNet(Layer): nn.ReLU(), nn.MaxPool2D( kernel_size=2, stride=2), ) + + self.reshape_layer = ReshapeHelp(shape=[-1, 256]) self.classifier = nn.Linear(256, num_classes) self.loss_fn = nn.loss.CrossEntropyLoss() def forward(self, x, y): x = self.features(x) - x.flatten() - + x = self.reshape_layer(x) x = self.classifier(x) return self.loss_fn(x, y) @@ -64,7 +73,7 @@ class AlexNet(Layer): class AlexNetPipe(AlexNet): def to_layers(self): feat = [self.features[i] for i in range(len(self.features))] - loss_fn = [lambda x: x.flatten(), self.classifier] + loss_fn = [self.reshape_layer, self.classifier] feat.extend(loss_fn) return feat @@ -74,7 +83,7 @@ class AlexNetPipeDesc(PipelineLayer): self.num_classes = num_classes decs = [ LayerDesc( - nn.Conv2D, 3, 64, kernel_size=11, stride=4, padding=5), + nn.Conv2D, 1, 64, kernel_size=11, stride=4, padding=5), LayerDesc(nn.ReLU), LayerDesc( nn.MaxPool2D, kernel_size=2, stride=2), @@ -94,7 +103,8 @@ class AlexNetPipeDesc(PipelineLayer): F.relu, LayerDesc( nn.MaxPool2D, kernel_size=2, stride=2), - lambda x: x.flatten(), + LayerDesc( + ReshapeHelp, shape=[-1, 256]), LayerDesc(nn.Linear, 256, self.num_classes), # classifier ] super(AlexNetPipeDesc, self).__init__( @@ -104,24 +114,24 @@ class AlexNetPipeDesc(PipelineLayer): class TestPipeLayerAPI(unittest.TestCase): def setUp(self): strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 + self.pipeline_parallel_size = 2 strategy.hybrid_configs = { "dp_degree": 1, "mp_degree": 1, - "pp_degree": self.model_parallel_size + "pp_degree": self.pipeline_parallel_size } fleet.init(is_collective=True, strategy=strategy) self.hcg = fleet.get_hybrid_communicate_group() def test_pipelayer_desc(self): - pipe_model = AlexNetPipeDesc(num_stages=self.model_parallel_size) + pipe_model = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size) np.testing.assert_array_equal(len(pipe_model.parameters()), 6) def test_pipelayer_sequential(self): init_net = AlexNetPipe() pipe_model = PipelineLayer( layers=init_net.to_layers(), - num_stages=self.model_parallel_size, + num_stages=self.pipeline_parallel_size, loss_fn=nn.CrossEntropyLoss()) stage_id = self.hcg.get_stage_id() init_parameters = init_net.parameters() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_model.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_model.py deleted file mode 100644 index 9b9283a1a9b6ea9e92246db974f501170fc4cb50..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_model.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import division -from __future__ import print_function - -import paddle -import numpy as np -import random -import paddle.distributed as dist -import paddle.fluid as fluid -import paddle.distributed.fleet as fleet -from paddle.io import DataLoader, Dataset -import unittest - - -def set_random_seed(seed, dp_id, rank_id): - """Set random seed for reproducability.""" - random.seed(seed) - np.random.seed(seed + dp_id) - paddle.seed(seed + rank_id) - - -HIDDEN_DIM = 32 -LAYERS = 8 - - -def sequential_model(): - model = paddle.nn.Sequential( - paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - paddle.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - paddle.nn.Linear(HIDDEN_DIM, 1), ) - return model - - -class TestDistPPTraning(unittest.TestCase): - def setUp(self): - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 1 - self.data_parallel_size = 1 - self.pipeline_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": self.data_parallel_size, - "mp_degree": self.model_parallel_size, - "pp_degree": self.pipeline_parallel_size, - } - strategy.pipeline_configs = {"accumulate_steps": 2} - paddle.distributed.init_parallel_env() - fleet.init(is_collective=True, strategy=strategy) - - def test_mp_model(self): - batch_input = paddle.randn(shape=(1, HIDDEN_DIM), dtype="float32") - pipe_model = sequential_model() - sgd = paddle.optimizer.SGD(learning_rate=0.0003, parameters=[]) - pipe_model = paddle.distributed.fleet.distributed_model(pipe_model) - - if pipe_model.stage_id == 0 or pipe_model.stage_id == 1: - pipe_input = batch_input.clone().detach() - pipe_input = paddle.cast(pipe_input, 'float32') - - def data_gen(): - gen = True - while gen: - yield [pipe_input, 0] - gen = False - - loader = paddle.io.DataLoader.from_generator(capacity=5) - loader.set_batch_generator(data_gen) - data_iter = iter(loader) - else: - data_iter = None - return True - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py index 5491b451368c825c10f1e957d85e30ccacdd1dc7..f3cd97ee1ec86916ecebb8ddf0895443c6e14567 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py @@ -17,8 +17,11 @@ from __future__ import print_function import unittest import time import paddle.fluid as fluid +import copy +import os +import subprocess -from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, start_local_trainers +from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc def get_cluster_from_args(selected_gpus): @@ -46,6 +49,55 @@ def get_gpus(selected_gpus): return selected_gpus +def start_local_trainers(cluster, + pod, + training_script, + training_script_args, + log_dir=None): + current_env = copy.copy(os.environ.copy()) + #paddle broadcast ncclUniqueId use socket, and + #proxy maybe make trainers unreachable, so delete them. + #if we set them to "", grpc will log error message "bad uri" + #so just delete them. + current_env.pop("http_proxy", None) + current_env.pop("https_proxy", None) + + procs = [] + for t in pod.trainers: + proc_env = { + "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]), + "PADDLE_TRAINER_ID": "%d" % t.rank, + "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint, + "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) + } + + current_env.update(proc_env) + + print("trainer proc env:{}".format(current_env)) + + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + cmd = "python -m coverage run --branch -p " + training_script + else: + cmd = "python -u " + training_script + + print("start trainer proc:{} env:{}".format(cmd, proc_env)) + + fn = None + + proc = subprocess.Popen(cmd.split(" "), env=current_env) + + tp = TrainerProc() + tp.proc = proc + tp.rank = t.rank + tp.log_fn = fn + tp.cmd = cmd + + procs.append(tp) + + return procs + + class TestMultipleGpus(unittest.TestCase): def run_mnist_2gpu(self, target_file_name): if not fluid.core.is_compiled_with_cuda( diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_layer.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py similarity index 89% rename from python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_layer.py rename to python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py index f3b89d694f70b96df70f4923b5af3433c7e2e26c..1d06e168208b279ccfba753452b1a82e5034975f 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_layer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py @@ -24,6 +24,9 @@ class TestHybridPipeParallel(TestMultipleGpus): def test_hybrid_parallel_pp_layer(self): self.run_mnist_2gpu('hybrid_parallel_pp_layer.py') + def test_hybrid_parallel_pp_tuple_inputs(self): + self.run_mnist_2gpu('hybrid_parallel_pp_embedding.py') + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py similarity index 100% rename from python/paddle/fluid/tests/unittests/test_parallel_dygraph_hybrid_parallel.py rename to python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py diff --git a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py index 7f8294ad0efe7536a27024fd30dbcdda15220efd..f62e160673f8d22ee895fe357d25c665859130c1 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py @@ -22,7 +22,7 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus class TestPipelineParallel(TestMultipleGpus): def test_pipeline_parallel(self): - self.run_mnist_2gpu('hybrid_parallel_pp_model.py') + self.run_mnist_2gpu('hybrid_parallel_pp_alexnet.py') if __name__ == "__main__":