diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py index 5645235cb71f62d3da54566a8ec04f5314764ac8..44f720ade7f80c9771869a91dcaa0a0c6131d396 100644 --- a/python/paddle/distributed/auto_parallel/dist_loader.py +++ b/python/paddle/distributed/auto_parallel/dist_loader.py @@ -26,14 +26,7 @@ from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collat class DistributedDataLoader(metaclass=abc.ABCMeta): - def __init__(self, - dataset, - batch_size=1, - epochs=1, - data_parallel_world_size=None, - data_parallel_rank=None, - drop_last=False, - split_data=True): + def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False): if isinstance(dataset, IterableDataset): self.dataset_kind = _DatasetKind.ITER else: @@ -42,19 +35,11 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): self.dataset = dataset self.epochs = epochs self.drop_lost = drop_last - self.data_parallel_world_size = data_parallel_world_size - self.data_parallel_rank = data_parallel_rank - self.split_data = split_data if batch_size is None: self.batch_size = None self.batch_sampler = None else: - if data_parallel_world_size is not None: - for dp_world_size in data_parallel_world_size: - if dp_world_size is not None: - assert batch_size % dp_world_size == 0, \ - "batch_size must be divisible by dp_world_size value {}".format(str(dp_world_size)) self.batch_size = batch_size if isinstance(dataset, IterableDataset): self.batch_sampler = _InfiniteIterableSampler( @@ -97,18 +82,22 @@ class NonIterableGeneratorLoader(DistributedDataLoader): epochs=1, steps_per_epoch=None, collate_fn=None, - data_parallel_world_size=None, - data_parallel_rank=None, + data_parallel_world_size=[], + data_parallel_rank=[], drop_last=False, split_data=True): self.feed_list = feed_list self.places = places self.steps_per_epoch = steps_per_epoch + assert len(data_parallel_world_size) == len(feed_list) + assert len(data_parallel_rank) == len(feed_list) + self.dp_world_sizes = data_parallel_world_size + self.dp_ranks = data_parallel_rank + self.split_data = split_data + super(NonIterableGeneratorLoader, - self).__init__(dataset, batch_size, epochs, - data_parallel_world_size, data_parallel_rank, - drop_last, split_data) + self).__init__(dataset, batch_size, epochs, drop_last) if self.auto_collate_batch: self.collate_fn = collate_fn or default_collate_fn @@ -154,13 +143,12 @@ class NonIterableGeneratorLoader(DistributedDataLoader): def _create_inner_dataloader(self): - def sample_data_generator(): + def data_generator(): while True: try: indices = next(self.sampler_iter) batch = self.dataset_fetcher.fetch(indices) if batch is None: break - except StopIteration: self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_kind, self.dataset, @@ -169,53 +157,23 @@ class NonIterableGeneratorLoader(DistributedDataLoader): break partial_data = [] - for i, d in enumerate(batch[:len(self.feed_list)]): + for i, d in enumerate(batch): array = np.array(d) if not self.split_data: partial_data.append(array) - elif self.dp_world_sizes[i] is not None: - partial_data.append( - np.split(array, - self.dp_world_sizes[i])[self.dp_ranks[i]]) - else: - partial_data.append(array) - yield partial_data + continue - def batch_data_generator(): - while True: - try: - indices = next(self.sampler_iter) + batch_size = array.shape[0] + assert batch_size % self.dp_world_sizes[i] == 0, \ + "batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i])) + partial_data.append( + np.split(array, + self.dp_world_sizes[i])[self.dp_ranks[i]]) - batch = self.dataset_fetcher.fetch(indices) - if batch is None: break - except StopIteration: - break - - partial_data = [] - for i, d in enumerate(batch[:len(self.feed_list)]): - array = np.array(d) - if not self.split_data: - partial_data.append(array) - elif self.dp_world_sizes[i] is not None: - partial_data.append( - np.split(array, - self.dp_world_sizes[i])[self.dp_ranks[i]]) - else: - partial_data.append(array) yield partial_data - self.dp_world_sizes = [ - 1 for _ in range(len(self.feed_list)) - ] if self.data_parallel_world_size is None else self.data_parallel_world_size - self.dp_ranks = [ - 0 for _ in range(len(self.feed_list)) - ] if self.data_parallel_rank is None else self.data_parallel_rank - dataloader = paddle.fluid.io.DataLoader.from_generator( feed_list=self.feed_list, capacity=70, iterable=False) - if self.batch_size is not None: - dataloader.set_batch_generator(sample_data_generator, self.places) - else: - dataloader.set_batch_generator(batch_data_generator, self.places) + dataloader.set_batch_generator(data_generator, self.places) return dataloader diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 8d1a1488ac790ffbe5e204022398e2fe835001f0..a383780f7740298c37c74c5d4897e9122f7a0d5a 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -36,7 +36,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.distributed import fleet from paddle.distributed.passes import new_pass, PassContext -from .hepler import ProgramHelper +from .helper import ProgramHelper from ..collective import _get_global_env from .cluster import Cluster, get_default_cluster from .planner_v2 import Planner @@ -118,8 +118,7 @@ class Engine: "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ " or `paddle.fluid.optimizer.Optimizer`." ) - self._optimizer = optimizer - self._all_ranks = all_ranks + self._optimizer = self._validate_opt(optimizer) if loss and not isinstance(loss, paddle.nn.Layer) and not callable(loss): @@ -136,6 +135,7 @@ class Engine: self._metrics = to_list(metrics) self._gradient_scale = gradient_scale self._planned_mode = None + self._all_ranks = all_ranks self._prepare_single_mode("train") def _prepare_single_mode(self, mode): @@ -161,21 +161,23 @@ class Engine: self._dygraph_mode = True self._logger.info("Building model with 'to_static' method.") - program_helper = ProgramHelper(self.model, self._loss, - self._metrics, self.inputs_spec, - self.labels_spec) + inputs_spec = self.inputs_spec + labels_spec = self.labels_spec if self.labels_spec else [] + self.program_helper = ProgramHelper(self.model, self._loss, + self._metrics, inputs_spec, + labels_spec) # build forward main program - program_helper.build_program(mode) + self.program_helper.build_program(mode) - self.concrete_program = program_helper.concrete_program - serial_main_prog = program_helper.main_program - serial_startup_prog = program_helper.startup_program + self.concrete_program = self.program_helper.concrete_program + serial_main_prog = self.program_helper.main_program + serial_startup_prog = self.program_helper.startup_program - inputs = program_helper.input_vars - outputs = program_helper.output_vars - labels = program_helper.label_vars - losses = program_helper.loss_vars - metrics = program_helper.metric_vars + inputs = self.program_helper.input_vars + outputs = self.program_helper.output_vars + labels = self.program_helper.label_vars + losses = self.program_helper.loss_vars + metrics = self.program_helper.metric_vars paddle.enable_static() else: @@ -334,40 +336,17 @@ class Engine: continue process_group.instantiate() - self._place = _get_device() - if isinstance(self._place, fluid.CUDAPlace): - self._place = fluid.CUDAPlace(ParallelEnv().dev_id) + place = _get_device() + if isinstance(place, fluid.CUDAPlace): + place = fluid.CUDAPlace(ParallelEnv().dev_id) if self._dygraph_mode: - paddle.disable_static() - main_program = self._dist_main_progs[mode][self._cur_rank] - for param in self.concrete_program.parameters: - # create var in scope and share parameters to scope - if param.name not in main_program.global_block().vars: - continue - # get param_var's dist_attr - var = main_program.global_block().vars[param.name] - var_dist_attr = self._dist_contexts[ - mode].get_tensor_dist_attr_for_program(var) - dist_attr = { - "dims_mapping": var_dist_attr.dims_mapping, - "process_shape": var_dist_attr.process_mesh.topology, - "process_group": var_dist_attr.process_mesh.processes - } - # slice param_value with dist_attr - # share sliced_param_value with param_tensor in global_scope - from .converter import Converter - param_tensor = global_scope().var(param.name).get_tensor() - sliced_param = Converter.slice_with_dist_attr( - param.numpy(), dist_attr) - shared_tensor = paddle.to_tensor(sliced_param, - place=self._place) - param_tensor._share_data_with( - shared_tensor.value().get_tensor()) - paddle.enable_static() + dist_context = self._dist_contexts[mode] + dist_main_program = self._dist_main_progs[mode][self._cur_rank] + self.program_helper.init(dist_main_program, place, dist_context) if self._executor is None: - self._executor = paddle.static.Executor(self._place) + self._executor = paddle.static.Executor(place) uninitialized = [] dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] for var in dist_startup_prog.list_vars(): @@ -411,7 +390,7 @@ class Engine: data = np.array(param_t) param_t.set(np.float16(data), place) - cast_parameters_to_fp16(self._place, prune_startup_prog) + cast_parameters_to_fp16(place, prune_startup_prog) def fit(self, train_data, @@ -577,15 +556,20 @@ class Engine: dist_context = self._dist_contexts[self.mode] dist_main_block = dist_main_prog.global_block() - # NOTE: Get feed_list from dist_program, then insert dataloader op - # with sharded var shape. Because predict_program does not contain - # labels var, so we will filter dataset's value with length of feed_list. + # NOTE: Get feed_list, then insert dataloader op with sharded var shape. + # Cause predict_program does not contain labels var, + # then we will add labels var from serial_program to dist_program, + # that maintains the length of feed_list equal to the length of dataset's values. inputs_var = self._feed_vars[self.mode]["inputs"] labels_var = self._feed_vars[self.mode]["labels"] feed_list = [] for var in inputs_var + labels_var: if var.name in dist_main_block.vars: feed_list.append(dist_main_block.vars[var.name]) + else: + copy_var = dist_main_block._clone_variable(var, var.persistable) + copy_var.desc.set_original_id(var.desc.original_id()) + feed_list.append(copy_var) # remove the first three ops if multi run fit/evaluate/predict op_size = len(dist_main_block.ops) @@ -688,7 +672,7 @@ class Engine: batch_size_axis, rank_id) return len(group_ranks), group_ranks.index(rank_id) - return None, None + return 1, 0 def _set_recompute_ckpts(self): # NOTE hack to enable recompute in engine api for GPT-3 @@ -717,6 +701,11 @@ class Engine: } self._logger.info(logs) + def _validate_opt(self, optimizer): + optimizer._parameter_list = None + optimizer._param_groups = None + return optimizer + def save(self, path, training=True, mode=None): if not mode: mode = self.mode diff --git a/python/paddle/distributed/auto_parallel/hepler.py b/python/paddle/distributed/auto_parallel/helper.py similarity index 85% rename from python/paddle/distributed/auto_parallel/hepler.py rename to python/paddle/distributed/auto_parallel/helper.py index 077b769116060c3973931c9e1f91d621994e3791..7a17ba65414cec96c4f5b2c33e7834b77bfbcbd7 100644 --- a/python/paddle/distributed/auto_parallel/hepler.py +++ b/python/paddle/distributed/auto_parallel/helper.py @@ -15,14 +15,18 @@ import logging from collections import defaultdict +import paddle + from paddle.nn import Layer from paddle.jit import to_static, not_to_static from paddle.distributed.utils import get_logger from paddle.fluid.framework import Operator, Parameter, _non_static_mode from paddle.fluid.framework import program_guard +from paddle.fluid.executor import global_scope from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from .utils import to_list +from .converter import Converter class ProxyLayer(Layer): @@ -89,13 +93,14 @@ class ProxyLayer(Layer): # step 4. calculate metrics if needed self._metric_vars[mode] = self.call_metrics(new_inputs) - def _predict(self, inputs): + def _predict(self, inputs, labels): """ Predict process of inner_layer with forward logic. """ # step 1. save feed variables of Program mode = 'predict' self._input_vars[mode] = inputs + self._label_vars[mode] = labels # step 2. call inner_layer.forward self._output_vars[mode] = self.inner_layer(*inputs) @@ -165,6 +170,10 @@ class ProxyLayer(Layer): def metric_vars(self): return self._metric_vars[self.mode] + @property + def startup_program(self): + return self.inner_layer._startup_program() + class BuildInfo: @@ -199,6 +208,7 @@ class ProgramHelper(object): self.build_info = BuildInfo() self._logger = get_logger(logging.INFO) + self.lazy_init = False def reset(self): """ @@ -221,8 +231,7 @@ class ProgramHelper(object): return self._logger.info("start to build program for mode = %s." % mode) - input_spec = [self.inputs_spec, self.labels_spec - ] if mode != 'predict' else [self.inputs_spec] + input_spec = [self.inputs_spec, self.labels_spec] static_func = to_static(self.static_func(), input_spec=input_spec) func_name = '_' + mode @@ -238,6 +247,9 @@ class ProgramHelper(object): """ Create and Sync parameters into startup program. """ + if len(self.startup_program.global_block().ops) > 1: + self.lazy_init = True + return for param in self.concrete_program.parameters: Parameter(name=param.name, desc=param, @@ -294,6 +306,28 @@ class ProgramHelper(object): func_name = '_' + self.proxy_layer.mode return getattr(self.proxy_layer, func_name) + def init(self, main_program, place, dist_context): + if self.lazy_init: + return + for param in self.concrete_program.parameters: + # create var in scope and share parameters to scope + if param.name not in main_program.global_block().vars: + continue + # get param_var's dist_attr + var = main_program.global_block().vars[param.name] + var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) + dist_attr = { + "dims_mapping": var_dist_attr.dims_mapping, + "process_shape": var_dist_attr.process_mesh.topology, + "process_group": var_dist_attr.process_mesh.processes + } + # slice param_value with dist_attr + # share sliced_param_value with param_tensor in global_scope + param_tensor = global_scope().var(param.name).get_tensor() + sliced_param = Converter.slice_with_dist_attr( + param.numpy(), dist_attr) + param_tensor.set(sliced_param, place) + @property def concrete_program(self): return self.static_func().concrete_program @@ -304,7 +338,12 @@ class ProgramHelper(object): @property def startup_program(self): - return self.concrete_program.startup_program + try: + return self.proxy_layer.startup_program + except Exception as err: + if isinstance(err, AssertionError): + return self.concrete_program.startup_program + raise err @property def input_vars(self): diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 51eede57638ff8d3c4dc662619ee51f804a861da..d4bb81e6b222c28a129950e0384fde9287486c9d 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -145,12 +145,7 @@ class Parallelizer: params_grads): # NOTE: `apply_gradients` will add an Accumulator for a parameter only once, # but optimizer will be called repeatedly in re-launch, so optimizer need to be copied. - if self._dist_context._dygraph_mode: - paddle.disable_static() - optimizer = copy.deepcopy(optimizer) - paddle.enable_static() - else: - optimizer = copy.deepcopy(optimizer) + optimizer = copy.deepcopy(optimizer) self._dist_context._lr_optimizer = optimizer with program_guard(main_program, startup_program): with unique_name.guard("opt_"): @@ -222,6 +217,7 @@ class Parallelizer: config = {} config["dist_context"] = self._dist_context config["global_rank"] = rank + config["use_sharding"] = self._strategy.sharding dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) dp_pass.apply([main_program], [startup_program], self._pass_context) diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 6da39b063efa77f0efba1950108a15fb4432f49f..52d5c607bbc57e5ae65e3a4593585c0294d2b74f 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -270,15 +270,16 @@ class Inserter: dtype=tensor_type, type=tensor.type, lod_level=tensor.lod_level) - block._insert_op(idx, - type='cast', - inputs={'X': [tensor]}, - outputs={'Out': [out]}, - attrs={ - 'in_dtype': tensor.dtype, - 'out_dtype': out.dtype, - 'op_role': op_role - }) + cast_op = block._insert_op(idx, + type='cast', + inputs={'X': [tensor]}, + outputs={'Out': [out]}, + attrs={ + 'in_dtype': tensor.dtype, + 'out_dtype': out.dtype, + 'op_role': op_role + }) + cast_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod @@ -287,16 +288,17 @@ class Inserter: op_type = 'send_v2' # use pair comm group process_group = new_process_group([src, dst]) - block._insert_op(idx, - type=op_type, - inputs={'X': [tensor]}, - attrs={ - 'ring_id': process_group.id, - 'peer': process_group.ranks.index(dst), - 'use_calc_stream': True, - 'op_role': op_role, - 'dynamic_shape': True - }) + send_op = block._insert_op(idx, + type=op_type, + inputs={'X': [tensor]}, + attrs={ + 'ring_id': process_group.id, + 'peer': process_group.ranks.index(dst), + 'use_calc_stream': True, + 'op_role': op_role, + 'dynamic_shape': True + }) + send_op._set_attr('op_namescope', "/auto_parallel/reshard") @staticmethod def insert_recv_op(block, idx, tensor, src, dst, op_role): @@ -304,19 +306,20 @@ class Inserter: op_type = 'recv_v2' # use pair group process_group = new_process_group([src, dst]) - block._insert_op(idx, - type=op_type, - inputs={'X': [tensor]}, - outputs={'Out': [tensor]}, - attrs={ - 'ring_id': process_group.id, - 'peer': process_group.ranks.index(src), - 'out_shape': tensor.shape, - 'dtype': tensor.dtype, - 'use_calc_stream': True, - 'op_role': op_role, - 'dynamic_shape': True - }) + recv_op = block._insert_op(idx, + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [tensor]}, + attrs={ + 'ring_id': process_group.id, + 'peer': process_group.ranks.index(src), + 'out_shape': tensor.shape, + 'dtype': tensor.dtype, + 'use_calc_stream': True, + 'op_role': op_role, + 'dynamic_shape': True + }) + recv_op._set_attr('op_namescope', "/auto_parallel/reshard") @staticmethod def insert_reset_lod_op(block, idx, X, Y, op_role): @@ -330,14 +333,15 @@ class Inserter: dtype=X.dtype, lod_level=X.lod_level) - block._insert_op(idx, - type="lod_reset", - inputs={ - 'X': X, - 'Y': Y - }, - outputs={'Out': reset_lod_out}, - attrs={'op_role': op_role}) + reset_op = block._insert_op(idx, + type="lod_reset", + inputs={ + 'X': X, + 'Y': Y + }, + outputs={'Out': reset_lod_out}, + attrs={'op_role': op_role}) + reset_op._set_attr('op_namescope', "/auto_parallel/reshard") return reset_lod_out @staticmethod @@ -359,11 +363,12 @@ class Inserter: type=tensors[0].type, persistable=False, stop_gradient=False) - block._insert_op(idx, - type='concat', - inputs=inputs, - outputs={'Out': [out]}, - attrs=attrs) + concat_op = block._insert_op(idx, + type='concat', + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs) + concat_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod @@ -391,11 +396,12 @@ class Inserter: inputs = {'X': [tensor]} outputs = {"Out": [out]} attrs = {"in_place": False} - block._insert_op(idx, - type="assign", - inputs=inputs, - outputs=outputs, - attrs=attrs) + slice_op = block._insert_op(idx, + type="assign", + inputs=inputs, + outputs=outputs, + attrs=attrs) + slice_op._set_attr('op_namescope', "/auto_parallel/reshard") return out # use split once @@ -427,11 +433,12 @@ class Inserter: for i in range(num_or_sections) ] out = outs[cur_idx] - op = block._insert_op(idx, - type="split", - inputs=inputs, - outputs={'Out': outs}, - attrs=attrs) + split_op = block._insert_op(idx, + type="split", + inputs=inputs, + outputs={'Out': outs}, + attrs=attrs) + split_op._set_attr('op_namescope', "/auto_parallel/reshard") return out # use slice @@ -449,12 +456,12 @@ class Inserter: dtype=tensor.dtype, type=tensor.type, lod_level=tensor.lod_level) - block._insert_op(idx, - type="slice", - inputs=inputs, - outputs={'Out': [out]}, - attrs=attrs) - + slice_op = block._insert_op(idx, + type="slice", + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs) + slice_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod @@ -482,11 +489,12 @@ class Inserter: persistable=False, stop_gradient=False) for i in range(num_or_sections) ] - block._insert_op(idx, - type="split", - inputs=inputs, - outputs={'Out': outs}, - attrs=attrs) + split_op = block._insert_op(idx, + type="split", + inputs=inputs, + outputs={'Out': outs}, + attrs=attrs) + split_op._set_attr('op_namescope', "/auto_parallel/reshard") return outs @staticmethod @@ -514,12 +522,13 @@ class Inserter: attrs=attrs, shape=[0], op_type='fill_constant') - block._insert_op(idx, - type='fill_constant', - inputs=inputs, - outputs={'Out': [out]}, - attrs=attrs) + fillconstant_op = block._insert_op(idx, + type='fill_constant', + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs) out.stop_gradient = True + fillconstant_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod @@ -537,22 +546,25 @@ class Inserter: fill_constant_out.stop_gradient = True # insert c_allreduce_sum op - block._insert_op(idx + 1, - type="c_allreduce_sum", - inputs={'X': [fill_constant_out]}, - outputs={'Out': [fill_constant_out]}, - attrs={ - 'ring_id': 0, - 'use_calc_stream': True, - 'op_role': op_role - }) - + allreduce_op = block._insert_op( + idx + 1, + type="c_allreduce_sum", + inputs={'X': [fill_constant_out]}, + outputs={'Out': [fill_constant_out]}, + attrs={ + 'ring_id': 0, + 'use_calc_stream': True, + 'op_role': op_role + }) + allreduce_op._set_attr('op_namescope', "/auto_parallel/reshard") # insert c_sync_calc_stream op - block._insert_op(idx + 2, - type="c_sync_calc_stream", - inputs={'X': [fill_constant_out]}, - outputs={'Out': [fill_constant_out]}, - attrs={'op_role': op_role}) + sync_calc_op = block._insert_op( + idx + 2, + type="c_sync_calc_stream", + inputs={'X': [fill_constant_out]}, + outputs={'Out': [fill_constant_out]}, + attrs={'op_role': op_role}) + sync_calc_op._set_attr('op_namescope', "/auto_parallel/reshard") idx_offset = 3 # insert c_allgather op @@ -569,16 +581,17 @@ class Inserter: type=tensor.type, persistable=False, stop_gradient=False) - block._insert_op(idx + idx_offset, - type=op_type, - inputs={'X': [tensor]}, - outputs={'Out': [allgather_out]}, - attrs={ - 'ring_id': group.id, - 'use_calc_stream': True, - 'nranks': group.nranks, - 'op_role': op_role - }) + allgather_op = block._insert_op(idx + idx_offset, + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [allgather_out]}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'nranks': group.nranks, + 'op_role': op_role + }) + allgather_op._set_attr('op_namescope', "/auto_parallel/reshard") idx_offset += 1 # insert split op diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index d97209f7fe5c5124bf7501a9d554d63445bb7258..3f3448b5008e6323d51086bbfbaa465215b3dda7 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -26,6 +26,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _k from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute +from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op world_process_group = get_world_process_group() @@ -222,21 +223,33 @@ class AMPState(object): loss_op = get_loss_op(self._block) loss_op_index = find_op_index(self._block.desc, loss_op.desc) + appended_grad_times = 0 idx = loss_op_index + 1 while idx < len(ops): num_cast_ops = 0 grad_op = ops[idx] + + # NOTE: the map in `grad_var_to_var` may be changed when the var is casted, + # which will affect the dist_op to insert allreduce_sum op. + op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op) + if is_backward_op(grad_op) and (is_forward_op(ops[idx - 1]) + or is_loss_op(ops[idx - 1])): + if not op_dist_attr.is_recompute: + appended_grad_times += 1 + grad_op_orig_id = grad_op.desc.original_id() dist_op_context = dist_context.dist_op_context if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id: if self._is_fp16_op(grad_op_orig_id) == False: # fp32 num_cast_ops = self._insert_cast_op_backward( grad_op, idx, core.VarDesc.VarType.FP16, - core.VarDesc.VarType.FP32, dist_context) + core.VarDesc.VarType.FP32, dist_context, + appended_grad_times) elif self._is_fp16_op(grad_op_orig_id) == True: # fp16 num_cast_ops = self._insert_cast_op_backward( grad_op, idx, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, dist_context) + core.VarDesc.VarType.FP16, dist_context, + appended_grad_times) elif grad_op.type == "sum": in_var_name = grad_op.desc.input_arg_names()[0] src_dtype = self._block.var(in_var_name).dtype @@ -258,7 +271,7 @@ class AMPState(object): _update_backward_cast_ops(params_grads, dist_context) def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype, - dist_context): + dist_context, appended_grad_times): """ only for backward cast """ def _keep_fp32_input(op, in_name): @@ -328,7 +341,10 @@ class AMPState(object): grad_op) fwd_cast_name = self._var_name_dict[fwd_op_id][ out_var_name_prefix] - cast_name = fwd_cast_name + "@GRAD" + suffix = "" + if "@RENAME" in out_var_name: + suffix = out_var_name[out_var_name.find("@RENAME"):] + cast_name = fwd_cast_name + "@GRAD" + suffix cast_var = self._block.vars.get(cast_name) if cast_var is None or cast_var.dtype != dst_dtype: grad_op.desc._rename_output(out_var_name, cast_name) @@ -347,6 +363,8 @@ class AMPState(object): stop_gradient=out_var.stop_gradient) set_var_dist_attr(dist_context, cast_var, ref_mapping, ref_mesh) + dist_op_context.grad_var_to_var[ + appended_grad_times][cast_name] = fwd_cast_name cast_op = self._block._insert_op( idx + 1, diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 3c5403c8254b9a5123ee97c537168efe95daa774..586ad235fd15a3395e6e27f03b3fd5a5e40a2c8e 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -45,6 +45,7 @@ class DataParallelOptimizationPass(PassBase): # NOTE not use depence on loss and param_grads self.set_attr("dist_context", None) self.set_attr("global_rank", -1) + self.set_attr("use_sharding", False) # {grad1: group1, grad2: group1, grad3: group2} # record the order for fuse grad data memory self._grad_name_to_group_map = OrderedDict() @@ -71,6 +72,7 @@ class DataParallelOptimizationPass(PassBase): self.dist_context = self.get_attr("dist_context") self.global_rank = int(self.get_attr("global_rank")) + self.use_sharding = self.get_attr("use_sharding") with paddle.static.program_guard(main_program, startup_program): self._analyze_program() @@ -224,7 +226,8 @@ class DataParallelOptimizationPass(PassBase): num_dp_comm_stream = len(set(self._group_to_grad_name_map.keys())) if num_dp_comm_stream > __max_stream_num_allow__: return False - + if self.use_sharding: + return False return True def _comms_overlap_calc(self): diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 80edec82fd7de231b50a16395fe1d16737ab43ee..a9c83a98c19fcb3283011603d7fc65101e79022a 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -151,13 +151,14 @@ class RecomputeState(ProgramStats): # modify dropout op's desc self._ops.insert(op_idx, seed_op) cur_op.desc.set_input("Seed", [var_unique_name]) - cur_op.desc.remove_attr("fix_seed") - cur_op.desc.remove_attr("seed") + cur_op._remove_attr("fix_seed") + cur_op._remove_attr("seed") cur_op_dist_attr.set_input_dist_attr(seed_var.name, seed_var_dist_attr) - self._block._sync_with_cpp() op_idx += 2 + self._block._sync_with_cpp() + def _find_op_index(block, cur_op): for idx in range(block.desc.op_size()): @@ -339,12 +340,13 @@ class RecomputePass(PassBase): grad_op = ops[i] # remove some attrs of dropout_grad op's desc if grad_op.type == "dropout_grad": - grad_op.desc.remove_attr("fix_seed") - grad_op.desc.remove_attr("seed") - main_block._sync_with_cpp() + grad_op._remove_attr("fix_seed") + grad_op._remove_attr("seed") # rename grad op's var_name which is not in 'vars_in_memory' for key in var_name_dict: + if key not in grad_op.input_arg_names + grad_op.output_arg_names: + continue self.reset_op_dist_attr(grad_op, var_name_dict) _rename_arg_([grad_op.desc], key, var_name_dict[key]) @@ -358,11 +360,11 @@ class RecomputePass(PassBase): idx -= 1 segment_descs = ckpt_ops_dict[fwd_op_id][1] for _, op_desc in reversed(list(enumerate(segment_descs))): - rc_desc = main_block.desc._insert_op(idx) + rc_op = main_block._insert_op_without_sync(idx, + type='nop') + rc_desc = rc_op.desc rc_desc.copy_from(op_desc) rc_desc.set_original_id(rc_desc.id()) - rc_op = Operator(main_block, rc_desc) - main_block.ops.insert(idx, rc_op) # set recomputed ops' dist attr fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id( op_desc.original_id()) @@ -371,7 +373,6 @@ class RecomputePass(PassBase): var_name_dict) ckpt_ops_dict[fwd_op_id][0] = False - main_block._sync_with_cpp() main_program._sync_with_cpp() diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index e414a235b5956535ba084918e877bda65c87088b..fa07915bf60dd5c782727c351decffcb205581fb 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import reduce -from collections import OrderedDict +from collections import OrderedDict, defaultdict import numpy as np import paddle @@ -27,10 +27,7 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() -_skip_ops = [ - 'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split', - 'assign', "send_v2" -] +_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read'] # update here to support new optimizers _supported_optimizer_type = [ "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum", @@ -38,6 +35,11 @@ _supported_optimizer_type = [ ] +def _is_reshard_op(op): + return op.desc.has_attr("op_namescope") and \ + "/auto_parallel/reshard" in op.desc.attr('op_namescope') + + # NOTE we add the "auto_parallel" prefix to the pass in order to # indicate that this pass should obey some constrains by auto_parallel # for example all ops and vars should has dist attr before and after pass @@ -100,6 +102,10 @@ class ShardingPass(PassBase): for op in main_block.ops: if not _is_forward_op(op) or op.type in _skip_ops: continue + # NOTE: there aren't dist_attr in the ops which reshard insert, + # and should be skip in sharding. + if _is_reshard_op(op): + continue group = _inference_data_parallel_group_for_operator( self.global_rank, op, self._dist_context) if group is not None: @@ -187,8 +193,28 @@ class ShardingPass(PassBase): if self._is_parameter_in_local_shard(param_name): reversed_x.append(input_name) - op.desc.set_input('X', reversed_x) - op.desc.set_output('Out', reversed_x) + + # NOTE: When `reversed_x` is [], check_finite_and_unscale will be replaced by `fill_constant` op. + # The output of check_finite_and_unscale is be set False + if reversed_x: + op.desc.set_input('X', reversed_x) + op.desc.set_output('Out', reversed_x) + else: + if op.type == "check_finite_and_unscale": + out_name = op.output_arg_names[0] + out_var = main_block.vars[out_name] + main_block._remove_op(idx, sync=False) + main_block._insert_op_without_sync( + idx, + type="fill_constant", + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": out_var.dtype, + "value": 0, + }) + else: + main_block._remove_op(idx, sync=False) main_block._sync_with_cpp() @@ -359,6 +385,17 @@ class ShardingPass(PassBase): else: op._set_attr("ring_id", self.outer_dp_group.id) + # NOTE: + # var@GRAD = sum(var@GRAD@RENAME@0, var@GRAD@RENAME@1) + # If the var is not in local rank and it is output of many ops, or the var is renamed in another words, + # the sum op should be removed. + if _is_param_grad_sum_op(op, main_block): + out_name = op.output_arg_names[0] + base_name = _get_base_name_from_grad_name(out_name) + sharding_info = self.varname_to_sharding_info[base_name] + if not sharding_info.is_in_local_shard(base_name): + main_block._remove_op(idx, sync=False) + main_block._sync_with_cpp() def _shard_parameter(self, main_block, startup_block): @@ -606,6 +643,22 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids): return block.var(base_name).is_parameter +def _is_param_grad_sum_op(op, block): + + if not is_backward_op(op): + return False + if op.type != "sum": + return False + + output_name = op.output_arg_names[0] + base_name = _get_base_name_from_grad_name(output_name) + + if not block.has_var(base_name): + return False + + return block.var(base_name).is_parameter + + def _is_forward_op(op): return op.attr("op_role") == 0 diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py index 6a4c8a9986cef26a474a25839c24f42cf186c5d8..76a4772290db9dfb5e1b899c1a06d10cee32d3f1 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py @@ -33,7 +33,7 @@ import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.engine import Engine paddle.enable_static() -batch_size = 1 +batch_size = 2 batch_num = 10 hidden_size = 1024 sequence_len = 512 @@ -133,10 +133,7 @@ def train(fetch): # train train_dataset = MyDataset(batch_num * batch_size) - engine.fit(train_dataset, - batch_size=batch_size, - steps_per_epoch=batch_num * batch_size, - fetches=fetches) + engine.fit(train_dataset, batch_size=batch_size, fetches=fetches) # eval eval_dataset = MyDataset(batch_size) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index 0e5c6b387f987e07890e93f892312b0ea1820f41..2884a03a023e541e40d04dfdfb9d3f377ca0f8a9 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -67,15 +67,14 @@ def create_data_holder(batch_size): def generate_model(strategy): modeling.init_global() + modeling._global_process_mesh = list( + range(paddle.distributed.get_world_size())) if strategy == "serial": modeling._global_parallel_strategy = "serial" - modeling._global_process_mesh = [0] elif strategy == "mp": modeling._global_parallel_strategy = "mp" - modeling._global_process_mesh = [0, 1] elif strategy == "dp": modeling._global_parallel_strategy = "dp" - modeling._global_process_mesh = [0, 1] else: raise ValueError("Only support serial, mp2 and dp2.") diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py index e7d73921eb34f843eabfcdb27a89f895abba6091..35301e448959873c1c1c3cc3f59f6698560ec251 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py @@ -27,7 +27,6 @@ from paddle.io import Dataset from paddle.static import InputSpec from paddle.fluid.framework import _non_static_mode from paddle.distributed.auto_parallel.engine import Engine -from paddle.distributed.auto_parallel.hepler import ProgramHelper from test_to_static import MLPLayer, MyDataset diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py index a3ab87160da6825c3d85bf8023a98e284c907a32..86832f485c162a1cbb189e8cfdcbd64cb527e183 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py @@ -23,11 +23,12 @@ import paddle.nn.functional as F import paddle.distributed.auto_parallel as auto import paddle.distributed.fleet as fleet +from paddle import LazyGuard from paddle.io import Dataset from paddle.static import InputSpec from paddle.fluid.framework import _non_static_mode from paddle.distributed.auto_parallel.engine import Engine -from paddle.distributed.auto_parallel.hepler import ProgramHelper +from paddle.distributed.auto_parallel.helper import ProgramHelper batch_size = 4 batch_num = 30 @@ -158,5 +159,29 @@ class TestToStatic(unittest.TestCase): engine.predict(dataset, batch_size=batch_size) +class TestLazyInit(unittest.TestCase): + + def test_lazy_init(self): + + with LazyGuard(): + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + loss = paddle.nn.CrossEntropyLoss() + + metrics = paddle.metric.Accuracy() + loss = paddle.nn.CrossEntropyLoss() + inputs = InputSpec([batch_size, hidden_size], 'float32', 'x') + labels = InputSpec([batch_size], 'int64', 'label') + + program_helper = ProgramHelper(mlp, loss, [metrics], [inputs], [labels]) + program_helper.build_program(mode='train') + ops = program_helper.startup_program.block(0).ops + vars = program_helper.startup_program.block(0).vars + assert len(vars.keys()) == len(ops) + program_helper.reset() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py index 87c746ab5d3b506ba865904d15bf04ac0310f85d..8aef4d1086066a7a23548e00bbef3a8168e322e3 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py @@ -914,12 +914,6 @@ class GPTForPretraining(nn.Layer): initializer_range=0.02, ): super(GPTForPretraining, self).__init__() - self.output_embeddings = nn.Embedding( - vocab_size, - hidden_size, - weight_attr=paddle.ParamAttr(name="output_embeddings", - initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range))) self.gpt = gpt def forward(self, @@ -938,9 +932,45 @@ class GPTForPretraining(nn.Layer): encoder_outputs, cached_kvs = outputs[:2] else: encoder_outputs = outputs - logits = paddle.matmul(encoder_outputs, - self.output_embeddings.weight, - transpose_y=True) + + x = encoder_outputs + w = self.gpt.embeddings.word_embeddings.weight + + mesh = _global_process_mesh + x_dims_mapping = [-1 for i in range(len(x.shape))] + w_dims_mapping = [-1 for i in range(len(w.shape))] + if _global_parallel_strategy == "pp": + mesh = PP_MESH_LIST[-1] + elif _global_parallel_strategy == "dp": + x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] + elif _global_parallel_strategy == "mp": + w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)] + elif _global_parallel_strategy == "dp_mp": + x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] + w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)] + elif _global_parallel_strategy == "dp_pp": + mesh = DPPP_MESH_LIST[-1] + x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] + elif _global_parallel_strategy == "mp_pp": + mesh = MPPP_MESH_LIST[-1] + w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)] + elif _global_parallel_strategy == "dp_mp_pp": + mesh = DPMPPP_MESH_LIST[-1] + x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] + w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)] + + matmul = auto.shard_op(paddle.matmul, + dist_attr={ + 'process_mesh': mesh, + x: { + "dims_mapping": x_dims_mapping + }, + w: { + "dims_mapping": w_dims_mapping + } + }) + logits = matmul(x, w, transpose_y=True) + if use_cache: return logits, cached_kvs else: @@ -958,6 +988,26 @@ class GPTPretrainingCriterion(nn.Layer): self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") def forward(self, prediction_scores, masked_lm_labels, loss_mask): + + mesh = _global_process_mesh + dims_mapping = [-1 for i in range(len(loss_mask.shape))] + if _global_parallel_strategy == "dp": + dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + elif _global_parallel_strategy == "dp_mp": + dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + elif _global_parallel_strategy == "dp_pp": + mesh = DPPP_MESH_LIST[-1] + dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + elif _global_parallel_strategy == "dp_mp_pp": + mesh = DPMPPP_MESH_LIST[-1] + dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + + auto.shard_tensor(loss_mask, + dist_attr={ + "process_mesh": mesh, + "dims_mapping": dims_mapping + }) + masked_lm_loss = self.loss_func(prediction_scores, masked_lm_labels.unsqueeze(2)) loss_mask = loss_mask.reshape([-1]) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py index ec879e77611cd491a7585d8b673d145577303b79..7708767609e66c7cdd03e62f32ac32474a60f35a 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py @@ -178,7 +178,6 @@ class AutoPallelPassTestBase(DistPassTestBase): preds = model(tokens, position_ids, attention_mask) criterion = GPTPretrainingCriterion() loss = criterion(preds, labels, loss_mask) - clip = paddle.nn.ClipGradByNorm(clip_norm=1.0) if kwargs.get('optimizer', None) == "LarsMomentum": optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer( @@ -189,7 +188,7 @@ class AutoPallelPassTestBase(DistPassTestBase): beta1=0.9, beta2=0.999, epsilon=1e-08, - grad_clip=clip) + grad_clip=None) optimizer = fleet.distributed_optimizer(optimizer) startup_program = paddle.static.default_startup_program() _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(