From 9db507f1485eddc32f8593173dc214a8db437cb1 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Mon, 7 Nov 2022 10:08:43 +0800
Subject: [PATCH] [AutoParallel] update naive data parallel completion (#47578)

* expand op does not use naive data parallel

* fix unittest
---
 .../distributed/auto_parallel/completion.py   | 141 ++++--------------
 .../distributed/auto_parallel/engine.py       |  47 +++---
 .../distributed/auto_parallel/planner_v2.py   |   4 +-
 .../paddle/distributed/auto_parallel/utils.py |  31 ++++
 4 files changed, 94 insertions(+), 129 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index ba879e89f8..97023a43cc 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 import copy
-import time
+import logging
 
 from paddle.fluid import core
 
+from .utils import is_naive_data_parallel, get_logger
 from .utils import is_gradient_clip_op, __not_shape_var_type__
 from .operators import find_compatible_distributed_operator_impls
 from .dist_context import _node_id
@@ -142,6 +143,7 @@ class Completer:
         assert dist_context is not None
         self._dist_context = dist_context
         self._has_prepared = False
+        self._logger = get_logger(logging.INFO, "Completer")
 
     def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
         changed = False
@@ -974,138 +976,60 @@ class Completer:
         else:
             self._dist_context._serial_main_program = serial_main_program
 
-        start_time = time.time()
-        # print("start time", start_time, flush=True)
-        if not self._dist_context.data_parallel:
+        if not is_naive_data_parallel(self._dist_context):
             self._dist_context.initialize(with_graph=True)
-
-            # self._dist_context.validate_dist_attr_for_program()
-
             self._prepare()
-
             self._update_process_mesh()
-
             self._update_dims_mapping()
-
             # Copy the corresponding distributed attribute from graph to serial_main_program
             self._dist_context.copy_dist_attr_from_graph_to_program()
         else:
+            self._logger.info("Default data parallel will be set.")
             self._dist_context.initialize(with_graph=False)
-
             # A fast and special completion for data parallel
             self._update_dist_attr_for_dp()
 
-            # print_program_with_dist_attr(self._dist_context.serial_main_program,
-            #                              self._dist_context)
-
         # NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient
         self._complete_high_order_grad_annotation(serial_main_program)
-
         # Do the validation check and amend some completion
         self._dist_context.amend_dist_attr_for_program()
-
         self._dist_context.validate_dist_attr_for_program()
-
-        end_time = time.time()
-        # print("end time", end_time, flush=True)
-        # print("elapsed time", end_time - start_time, flush=True)
-
         return serial_main_program
 
     def _update_dist_attr_for_dp(self):
         # TODO: we must ensure the world process group contains all ranks
         ranks = get_world_process_group().ranks
         process_mesh = ProcessMesh(ranks)
-        for (
-            dist_tensor
-        ) in self._dist_context._dist_tensors_for_program.values():
-            serial_tensor = dist_tensor.serial_tensor
-            tensor_dist_attr = dist_tensor.dist_attr
-            tensor_dist_attr.process_mesh = process_mesh
-
-        for dist_op in self._dist_context._dist_ops_for_program.values():
+
+        dist_tensors = self._dist_context._dist_tensors_for_program
+        for dist_tensor in dist_tensors.values():
+            dist_tensor.dist_attr.process_mesh = process_mesh
+
+        dist_ops = self._dist_context._dist_ops_for_program
+        for dist_op in dist_ops.values():
             serial_op = dist_op.serial_op
-            op_desc = serial_op.desc
             op_dist_attr = dist_op.dist_attr
             op_dist_attr.process_mesh = process_mesh
             original_op_dist_attr = copy.deepcopy(op_dist_attr)
-            input_xshape_arg_names = []
-            if "XShape" in op_desc.input_names():
-                input_xshape_arg_names = op_desc.input("XShape")
+
             for arg_name in serial_op.input_arg_names:
                 serial_tensor = dist_op.get_serial_input(arg_name)
                 if not serial_tensor.is_parameter:
-                    if arg_name not in input_xshape_arg_names:
-                        old_dims_mapping = op_dist_attr.get_input_dims_mapping(
-                            arg_name
+                    dist_tensor = (
+                        self._dist_context.get_dist_tensor_for_program(
+                            serial_tensor
                         )
-                        if len(old_dims_mapping) > 0:
-                            new_dims_mapping = [0] + [
-                                -1 for _ in range(len(old_dims_mapping) - 1)
-                            ]
-                            op_dist_attr.set_input_dims_mapping(
-                                arg_name, new_dims_mapping
-                            )
-                    else:
-                        old_dims_mapping = op_dist_attr.get_input_dims_mapping(
-                            arg_name
-                        )
-                        if len(old_dims_mapping) > 1:
-                            new_dims_mapping = [-1, 0] + [
-                                -1 for _ in range(len(old_dims_mapping) - 2)
-                            ]
-                            op_dist_attr.set_input_dims_mapping(
-                                arg_name, new_dims_mapping
-                            )
-                    # Set tensor's dims_mapping by the op's
-                    tensor_dist_attr = (
-                        self._dist_context.get_tensor_dist_attr_for_program(
-                            serial_tensor
                         )
-                    )
-                    tensor_dist_attr.dims_mapping = (
-                        op_dist_attr.get_input_dims_mapping(arg_name)
-                    )
-            output_xshape_arg_names = []
-            if "XShape" in op_desc.output_names():
-                output_xshape_arg_names = op_desc.output("XShape")
-            for arg_name in serial_op.output_arg_names:
-                serial_tensor = dist_op.get_serial_output(arg_name)
-                if not serial_tensor.is_parameter:
-                    if arg_name not in output_xshape_arg_names:
-                        old_dims_mapping = op_dist_attr.get_output_dims_mapping(
-                            arg_name
-                        )
-                        if len(old_dims_mapping) > 0:
-                            new_dims_mapping = [0] + [
-                                -1 for _ in range(len(old_dims_mapping) - 1)
-                            ]
-                            op_dist_attr.set_output_dims_mapping(
-                                arg_name, new_dims_mapping
-                            )
-                    else:
-                        old_dims_mapping = op_dist_attr.get_output_dims_mapping(
-                            arg_name
-                        )
-                        if len(old_dims_mapping) > 1:
-                            new_dims_mapping = [-1, 0] + [
-                                -1 for _ in range(len(old_dims_mapping) - 2)
-                            ]
-                            op_dist_attr.set_output_dims_mapping(
-                                arg_name, new_dims_mapping
-                            )
-                    # Set tensor's dims_mapping by the op's
-                    tensor_dist_attr = (
-                        self._dist_context.get_tensor_dist_attr_for_program(
-                            serial_tensor
+                    op_dist_attr = dist_op.dist_attr
+                    op_dist_attr.process_mesh = (
+                        dist_tensor.dist_attr.process_mesh
+                    )
+                    op_dist_attr.set_input_dims_mapping(
+                        arg_name, dist_tensor.dist_attr.dims_mapping
                         )
-                    )
-                    tensor_dist_attr.dims_mapping = (
-                        op_dist_attr.get_output_dims_mapping(arg_name)
-                    )
 
             op_dist_impls = find_compatible_distributed_operator_impls(
-                dist_op, partial=False
+                dist_op, fwd=True
             )
             if op_dist_impls is not None:
                 not_compatible = True
@@ -1127,6 +1051,16 @@ class Completer:
             else:
                 dist_op.dist_attr = original_op_dist_attr
 
+            for arg_name in serial_op.output_arg_names:
+                op_dist_attr = dist_op.dist_attr
+                serial_tensor = dist_op.get_serial_output(arg_name)
+                dist_tensor = self._dist_context.get_dist_tensor_for_program(
+                    serial_tensor
+                )
+                dist_tensor.dist_attr.dims_mapping = (
+                    op_dist_attr.get_output_dims_mapping(arg_name)
+                )
+
     def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
         if serial_main_program is None:
             serial_main_program = self._dist_context.serial_main_program
@@ -1942,19 +1876,10 @@ class Completer:
         else:
             self._dist_context._serial_main_program = serial_main_program
 
-        import time
-
-        start_time = time.time()
         self._dist_context._is_initialized = True
-
-        start_time = time.time()
         self._dist_context._init_dist_attr_for_program()
-
-        start_time = time.time()
         self._init_global_mesh_for_program()
-
         # Do the validation check and amend some completion
-        start_time = time.time()
         self._dist_context.amend_dist_attr_for_program()
 
         self._dist_context.validate_dist_attr_for_program()
diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index 28e9fc69d7..b8870106ac 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -22,6 +22,7 @@ from collections import defaultdict
 
 import paddle
 import paddle.utils as utils
+import paddle.distributed.auto_parallel.utils as auto_utils
 from paddle import fluid, static
 from paddle.metric import Metric
 
@@ -47,12 +48,10 @@ from .dist_loader import (
     DistributedDataLoaderFromGenerator,
     DistributedDataLoader,
 )
+from .strategy import Strategy
 from .process_group import new_process_group, get_all_process_groups
 from .dist_context import DistributedContext, get_default_distributed_context
-from .strategy import Strategy
 from .interface import CollectionNames, get_collection
-from .utils import to_list, get_dist_attr, get_lr, validate_opt
-from .utils import initialize_pg_in_full_mode, get_input_split_info
 from .cost.estimate_cost import get_cost_from_engine
 from ..utils.log_utils import get_logger
 
@@ -159,18 +158,18 @@ class Engine:
                 "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                 " or `paddle.fluid.optimizer.Optimizer`."
             )
-        self._optimizer = validate_opt(optimizer)
+        self._optimizer = auto_utils.validate_opt(optimizer)
         self._orig_optimizer = copy.deepcopy(self._optimizer)
 
         metrics = metrics or []
-        for metric in to_list(metrics):
+        for metric in auto_utils.to_list(metrics):
             if metric and not isinstance(metric, Metric):
                 raise TypeError(
                     "{} is not sub class of Metric".format(
                         metric.__class__.__name__
                     )
                 )
-        self._metrics = to_list(metrics)
+        self._metrics = auto_utils.to_list(metrics)
 
         if cluster and not isinstance(cluster, Cluster):
             raise TypeError(
@@ -253,8 +252,8 @@ class Engine:
                     type(data).__name__
                 )
             )
-        inputs = to_list(inputs)
-        labels = to_list(labels)
+        inputs = auto_utils.to_list(inputs)
+        labels = auto_utils.to_list(labels)
 
         num_shards = self._strategy.dataset.num_shards
 
@@ -481,7 +480,7 @@ class Engine:
                 if metric_out:
                     metric.update(*metric_out)
                     results = metric.accumulate()
-                    for i, res in enumerate(to_list(results)):
+                    for i, res in enumerate(auto_utils.to_list(results)):
                         logs[metric.name()[i]] = res
                 group_idx += 1
         # logging outputs
@@ -562,7 +561,7 @@ class Engine:
                 s._create_feed_layer() for s in self._labels_spec
             ]
 
-            outputs = to_list(self._model(*self._inputs))
+            outputs = auto_utils.to_list(self._model(*self._inputs))
 
             if mode != "predict" and self._loss:
                 assert isinstance(
@@ -570,14 +569,14 @@ class Engine:
                 ) or callable(
                     self._loss
                 ), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function."
-                self._losses = to_list(
+                self._losses = auto_utils.to_list(
                     self._loss(*(outputs + self._labels))
                 )
 
             if mode != "predict" and (outputs or self._labels):
                 for metric in self._metrics:
                     metrics.append(
-                        to_list(
+                        auto_utils.to_list(
                             metric.compute(*(outputs + self._labels))
                         )
                     )
@@ -585,7 +584,7 @@
             assert isinstance(
                 self._loss, Variable
             ), "the type of `loss` of the Engine arguments should be Variable."
-            self._losses = to_list(self._loss)
+            self._losses = auto_utils.to_list(self._loss)
 
         default_ctx = get_default_distributed_context()
         if not default_ctx.has_annotation:
@@ -593,6 +592,12 @@
             # needs all ranks by default.
             new_process_group(list(range(self._nranks)))
             default_ctx.data_parallel = True
+            self._inputs = [
+                auto_utils.set_data_parallel(var) for var in self._inputs
+            ]
+            self._labels = [
+                auto_utils.set_data_parallel(var) for var in self._labels
+            ]
 
         feed_vars = {"inputs": self._inputs, "labels": self._labels}
 
@@ -684,7 +689,7 @@
         self._dp_world_sizes = []
         self._dp_ranks = []
         for feed_var in feed_list:
-            dp_world_size, dp_rank = get_input_split_info(
+            dp_world_size, dp_rank = auto_utils.get_input_split_info(
                 self._cur_rank, feed_var, self._dist_contexts[mode]
             )
             self._dp_world_sizes.append(dp_world_size)
@@ -749,7 +754,9 @@
         cur_rank = self._cur_rank
         # NOTE: After the implementation of the unified dynamic and static communication group initialization mode in the future, the initialization logic of full mode will be removed because port occupation error may occur.
         if self._strategy.auto_mode == "full":
-            initialize_pg_in_full_mode(all_process_groups, cur_rank)
+            auto_utils.initialize_pg_in_full_mode(
+                all_process_groups, cur_rank
+            )
         else:
             for process_group in all_process_groups:
                 if cur_rank not in process_group.ranks:
@@ -927,7 +934,7 @@
                 )
             except core.EOFException:
                 break
-            lr = get_lr(self._optimizer)
+            lr = auto_utils.get_lr(self._optimizer)
             logs = self._prepare_logger(
                 outs,
                 epoch,
@@ -1474,7 +1481,7 @@
         self._optimization_tuning(self._mode, tune_data, batch_size)
 
     def _validate_spec(self, specs):
-        specs = to_list(specs)
+        specs = auto_utils.to_list(specs)
         self._k_steps = self._strategy.gradient_merge.k_steps
         if specs is not None:
             for i, spec in enumerate(specs):
@@ -1500,7 +1507,7 @@
         return specs or []
 
     def _validate_vars(self, vars):
-        vars = to_list(vars)
+        vars = auto_utils.to_list(vars)
        if vars is not None:
             for i, var in enumerate(vars):
                 if not isinstance(var, Variable):
@@ -1547,7 +1554,7 @@
     def _metrics_name(self):
         metrics_name = ['loss'] if self._loss else []
         for m in self._metrics:
-            metrics_name.extend(to_list(m.name()))
+            metrics_name.extend(auto_utils.to_list(m.name()))
         return metrics_name
 
     def _switch_mode(self, mode):
@@ -1568,7 +1575,7 @@
     def _set_state_dict(self, mode, strict, state_dict, dist_attr):
         program = self._dist_main_progs[mode][self._cur_rank]
         dist_context = self._dist_contexts[mode]
-        cur_dist_attr = get_dist_attr(program, dist_context)
+        cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
         converter = Converter(state_dict, dist_attr, cur_dist_attr)
         state_dict = converter.convert(strict=strict)
         program.set_state_dict(state_dict)
diff --git a/python/paddle/distributed/auto_parallel/planner_v2.py b/python/paddle/distributed/auto_parallel/planner_v2.py
index 0f9792911d..5c9b7233b8 100755
--- a/python/paddle/distributed/auto_parallel/planner_v2.py
+++ b/python/paddle/distributed/auto_parallel/planner_v2.py
@@ -15,6 +15,7 @@
 from .completion import Completer
 from .dist_context import get_default_distributed_context
 from .tuner.parallel_tuner import ParallelTuner
+from .utils import is_naive_data_parallel
 
 
 class Planner:
@@ -26,7 +27,8 @@
         # dependency of backward-forward ops in forward completion.
         default_ctx = get_default_distributed_context()
         self._dist_context._dist_op_context = default_ctx.dist_op_context
-        if not default_ctx.data_parallel:
+        self._dist_context.data_parallel = default_ctx.data_parallel
+        if not is_naive_data_parallel(self._dist_context):
             # Use SSA graph for complex parallism
             self._dist_context.initialize(with_graph=True)
         else:
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index c22ad2b831..d9357db371 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -37,6 +37,8 @@ __not_shape_var_type__ = [
     core.VarDesc.VarType.STEP_SCOPES,
 ]
 
+__not_naive_data_parallel_op__ = ["expand_v2"]
+
 
 def get_logger(log_level, name="auto_parallel"):
     logger = logging.getLogger(name)
@@ -1909,6 +1911,35 @@ def validate_opt(optimizer):
     return optimizer
 
 
+def set_data_parallel(x):
+    from .process_group import get_world_process_group
+    from .interface import shard_tensor, ProcessMesh
+
+    world_ranks = get_world_process_group().ranks
+    process_mesh = ProcessMesh(world_ranks, ['dp'])
+    shard_spec = ['dp' if len(world_ranks) > 1 else None] + [
+        None for _ in range(len(x.shape) - 1)
+    ]
+
+    return shard_tensor(x, process_mesh, shard_spec)
+
+
+def is_naive_data_parallel(dist_context):
+    # Naive data parallel only completes dist_attr once from front to back.
+    if not dist_context.data_parallel:
+        return False
+
+    ops_type = [
+        op.type
+        for op in dist_context._original_serial_main_program.global_block().ops
+    ]
+    if (
+        not set(ops_type) & set(__not_naive_data_parallel_op__)
+    ) and dist_context.data_parallel:
+        return True
+    return False
+
+
 def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
     py_process_mesh = py_dist_attr.process_mesh
     if py_process_mesh is not None:
--
GitLab
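
The behaviour added by the two new helpers in utils.py boils down to a simple rule: a program is treated as naive data parallel only when data parallelism is requested and none of its ops appear in __not_naive_data_parallel_op__, and set_data_parallel then shards each input or label along dimension 0 of a one-dimensional 'dp' mesh while replicating the remaining dimensions. The sketch below mirrors that rule in plain Python so it can be run without Paddle; the op names and the 3-D tensor in the example are made up, and the real helpers operate on a DistributedContext and return a sharded Paddle tensor rather than a plain list.

    # Standalone sketch of the rules added in utils.py; not part of the patch.
    _NOT_NAIVE_DP_OPS = {"expand_v2"}  # mirrors __not_naive_data_parallel_op__


    def is_naive_dp(op_types, data_parallel=True):
        # Naive data parallel applies only when no excluded op is present.
        return data_parallel and not (set(op_types) & _NOT_NAIVE_DP_OPS)


    def dp_shard_spec(tensor_rank, world_size):
        # Shard dim 0 along 'dp' when there is more than one rank and
        # replicate every other dim -- the shard_spec set_data_parallel builds.
        first = "dp" if world_size > 1 else None
        return [first] + [None] * (tensor_rank - 1)


    if __name__ == "__main__":
        print(is_naive_dp(["matmul_v2", "relu"]))       # True
        print(is_naive_dp(["expand_v2", "matmul_v2"]))  # False
        print(dp_shard_spec(tensor_rank=3, world_size=8))  # ['dp', None, None]

With a single rank the spec degenerates to [None, None, None], i.e. fully replicated, which matches the len(world_ranks) > 1 guard in the patched set_data_parallel.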