Unverified commit 9db507f1, authored by Z zhaoyingli, committed by GitHub

[AutoParallel] update naive data parallel completion (#47578)

* expand op does not use naive data parallel

* fix unittest
Parent b0c38568
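As context for the diff below: the completer's fast data-parallel completion is now skipped whenever the program contains an op listed in `__not_naive_data_parallel_op__` (currently only `expand_v2`), falling back to the full graph-based completion instead. The following is a minimal pure-Python sketch of that gating decision; the helper name `uses_fast_dp_completion` is illustrative, not a Paddle API, and it mirrors `is_naive_data_parallel` added in utils.py:

_NOT_NAIVE_DP_OPS = {"expand_v2"}  # mirrors __not_naive_data_parallel_op__

def uses_fast_dp_completion(op_types, annotated_data_parallel):
    # The fast path applies only when the user annotation is plain data
    # parallel and none of the excluded ops appear in the main program.
    if not annotated_data_parallel:
        return False
    return not (set(op_types) & _NOT_NAIVE_DP_OPS)

assert uses_fast_dp_completion(["matmul_v2", "relu"], True)
assert not uses_fast_dp_completion(["expand_v2", "relu"], True)
assert not uses_fast_dp_completion(["matmul_v2"], False)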
......@@ -13,10 +13,11 @@
# limitations under the License.
import copy
import time
import logging
from paddle.fluid import core
from .utils import is_naive_data_parallel, get_logger
from .utils import is_gradient_clip_op, __not_shape_var_type__
from .operators import find_compatible_distributed_operator_impls
from .dist_context import _node_id
......@@ -142,6 +143,7 @@ class Completer:
assert dist_context is not None
self._dist_context = dist_context
self._has_prepared = False
self._logger = get_logger(logging.INFO, "Completer")
def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
......@@ -974,138 +976,60 @@ class Completer:
else:
self._dist_context._serial_main_program = serial_main_program
start_time = time.time()
# print("start time", start_time, flush=True)
if not self._dist_context.data_parallel:
if not is_naive_data_parallel(self._dist_context):
self._dist_context.initialize(with_graph=True)
# self._dist_context.validate_dist_attr_for_program()
self._prepare()
self._update_process_mesh()
self._update_dims_mapping()
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
else:
self._logger.info("Default data parallel will be set.")
self._dist_context.initialize(with_graph=False)
# A fast and special completion for data parallel
self._update_dist_attr_for_dp()
# print_program_with_dist_attr(self._dist_context.serial_main_program,
# self._dist_context)
# NOTE: [HighOrderGrad] Update the distributed attributes of vars and ops for high-order gradients
self._complete_high_order_grad_annotation(serial_main_program)
# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()
self._dist_context.validate_dist_attr_for_program()
end_time = time.time()
# print("end time", end_time, flush=True)
# print("elapsed time", end_time - start_time, flush=True)
return serial_main_program
def _update_dist_attr_for_dp(self):
# TODO: we must ensure the world process group contains all ranks
ranks = get_world_process_group().ranks
process_mesh = ProcessMesh(ranks)
for (
dist_tensor
) in self._dist_context._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor
tensor_dist_attr = dist_tensor.dist_attr
tensor_dist_attr.process_mesh = process_mesh
for dist_op in self._dist_context._dist_ops_for_program.values():
dist_tensors = self._dist_context._dist_tensors_for_program
for dist_tensor in dist_tensors.values():
dist_tensor.dist_attr.process_mesh = process_mesh
dist_ops = self._dist_context._dist_ops_for_program
for dist_op in dist_ops.values():
serial_op = dist_op.serial_op
op_desc = serial_op.desc
op_dist_attr = dist_op.dist_attr
op_dist_attr.process_mesh = process_mesh
original_op_dist_attr = copy.deepcopy(op_dist_attr)
input_xshape_arg_names = []
if "XShape" in op_desc.input_names():
input_xshape_arg_names = op_desc.input("XShape")
for arg_name in serial_op.input_arg_names:
serial_tensor = dist_op.get_serial_input(arg_name)
if not serial_tensor.is_parameter:
if arg_name not in input_xshape_arg_names:
old_dims_mapping = op_dist_attr.get_input_dims_mapping(
arg_name
dist_tensor = (
self._dist_context.get_dist_tensor_for_program(
serial_tensor
)
if len(old_dims_mapping) > 0:
new_dims_mapping = [0] + [
-1 for _ in range(len(old_dims_mapping) - 1)
]
op_dist_attr.set_input_dims_mapping(
arg_name, new_dims_mapping
)
else:
old_dims_mapping = op_dist_attr.get_input_dims_mapping(
arg_name
)
if len(old_dims_mapping) > 1:
new_dims_mapping = [-1, 0] + [
-1 for _ in range(len(old_dims_mapping) - 2)
]
op_dist_attr.set_input_dims_mapping(
arg_name, new_dims_mapping
)
# Set tensor's dims_mapping by the op's
tensor_dist_attr = (
self._dist_context.get_tensor_dist_attr_for_program(
serial_tensor
)
)
tensor_dist_attr.dims_mapping = (
op_dist_attr.get_input_dims_mapping(arg_name)
)
output_xshape_arg_names = []
if "XShape" in op_desc.output_names():
output_xshape_arg_names = op_desc.output("XShape")
for arg_name in serial_op.output_arg_names:
serial_tensor = dist_op.get_serial_output(arg_name)
if not serial_tensor.is_parameter:
if arg_name not in output_xshape_arg_names:
old_dims_mapping = op_dist_attr.get_output_dims_mapping(
arg_name
)
if len(old_dims_mapping) > 0:
new_dims_mapping = [0] + [
-1 for _ in range(len(old_dims_mapping) - 1)
]
op_dist_attr.set_output_dims_mapping(
arg_name, new_dims_mapping
)
else:
old_dims_mapping = op_dist_attr.get_output_dims_mapping(
arg_name
)
if len(old_dims_mapping) > 1:
new_dims_mapping = [-1, 0] + [
-1 for _ in range(len(old_dims_mapping) - 2)
]
op_dist_attr.set_output_dims_mapping(
arg_name, new_dims_mapping
)
# Set tensor's dims_mapping by the op's
tensor_dist_attr = (
self._dist_context.get_tensor_dist_attr_for_program(
serial_tensor
op_dist_attr = dist_op.dist_attr
op_dist_attr.process_mesh = (
dist_tensor.dist_attr.process_mesh
)
op_dist_attr.set_input_dims_mapping(
arg_name, dist_tensor.dist_attr.dims_mapping
)
)
tensor_dist_attr.dims_mapping = (
op_dist_attr.get_output_dims_mapping(arg_name)
)
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, partial=False
dist_op, fwd=True
)
if op_dist_impls is not None:
not_compatible = True
......@@ -1127,6 +1051,16 @@ class Completer:
else:
dist_op.dist_attr = original_op_dist_attr
for arg_name in serial_op.output_arg_names:
op_dist_attr = dist_op.dist_attr
serial_tensor = dist_op.get_serial_output(arg_name)
dist_tensor = self._dist_context.get_dist_tensor_for_program(
serial_tensor
)
dist_tensor.dist_attr.dims_mapping = (
op_dist_attr.get_output_dims_mapping(arg_name)
)
def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
......@@ -1942,19 +1876,10 @@ class Completer:
else:
self._dist_context._serial_main_program = serial_main_program
import time
start_time = time.time()
self._dist_context._is_initialized = True
start_time = time.time()
self._dist_context._init_dist_attr_for_program()
start_time = time.time()
self._init_global_mesh_for_program()
# Do the validation check and amend some completion
start_time = time.time()
self._dist_context.amend_dist_attr_for_program()
self._dist_context.validate_dist_attr_for_program()
......
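The fast completion above relies on the naive data-parallel convention for dims_mapping: the batch (first) axis of a non-parameter tensor is sharded along mesh dimension 0 and every other axis is replicated (-1). A small sketch of that convention, assuming mesh dimension 0 is the data-parallel axis (the helper name is illustrative, not part of the diff):

def naive_dp_dims_mapping(ndim):
    # Shard the leading (batch) axis along mesh dim 0; replicate the rest.
    if ndim == 0:
        return []
    return [0] + [-1] * (ndim - 1)

assert naive_dp_dims_mapping(3) == [0, -1, -1]   # e.g. a [batch, seq, hidden] tensor
assert naive_dp_dims_mapping(1) == [0]
assert naive_dp_dims_mapping(0) == []            # scalars stay replicated

In the updated `_update_dist_attr_for_dp`, this mapping is no longer hard-coded per argument: it is read from each non-parameter input tensor's dist_attr (set by `set_data_parallel` in the Engine) and then propagated to the op's output tensors.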
......@@ -22,6 +22,7 @@ from collections import defaultdict
import paddle
import paddle.utils as utils
import paddle.distributed.auto_parallel.utils as auto_utils
from paddle import fluid, static
from paddle.metric import Metric
......@@ -47,12 +48,10 @@ from .dist_loader import (
DistributedDataLoaderFromGenerator,
DistributedDataLoader,
)
from .strategy import Strategy
from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy
from .interface import CollectionNames, get_collection
from .utils import to_list, get_dist_attr, get_lr, validate_opt
from .utils import initialize_pg_in_full_mode, get_input_split_info
from .cost.estimate_cost import get_cost_from_engine
from ..utils.log_utils import get_logger
......@@ -159,18 +158,18 @@ class Engine:
"'optimizer' must be object of class `paddle.optimizer.Optimizer`"
" or `paddle.fluid.optimizer.Optimizer`."
)
self._optimizer = validate_opt(optimizer)
self._optimizer = auto_utils.validate_opt(optimizer)
self._orig_optimizer = copy.deepcopy(self._optimizer)
metrics = metrics or []
for metric in to_list(metrics):
for metric in auto_utils.to_list(metrics):
if metric and not isinstance(metric, Metric):
raise TypeError(
"{} is not sub class of Metric".format(
metric.__class__.__name__
)
)
self._metrics = to_list(metrics)
self._metrics = auto_utils.to_list(metrics)
if cluster and not isinstance(cluster, Cluster):
raise TypeError(
......@@ -253,8 +252,8 @@ class Engine:
type(data).__name__
)
)
inputs = to_list(inputs)
labels = to_list(labels)
inputs = auto_utils.to_list(inputs)
labels = auto_utils.to_list(labels)
num_shards = self._strategy.dataset.num_shards
......@@ -481,7 +480,7 @@ class Engine:
if metric_out:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
for i, res in enumerate(auto_utils.to_list(results)):
logs[metric.name()[i]] = res
group_idx += 1
# logging outputs
......@@ -562,7 +561,7 @@ class Engine:
s._create_feed_layer() for s in self._labels_spec
]
outputs = to_list(self._model(*self._inputs))
outputs = auto_utils.to_list(self._model(*self._inputs))
if mode != "predict" and self._loss:
assert isinstance(
......@@ -570,14 +569,14 @@ class Engine:
) or callable(
self._loss
), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function."
self._losses = to_list(
self._losses = auto_utils.to_list(
self._loss(*(outputs + self._labels))
)
if mode != "predict" and (outputs or self._labels):
for metric in self._metrics:
metrics.append(
to_list(
auto_utils.to_list(
metric.compute(*(outputs + self._labels))
)
)
......@@ -585,7 +584,7 @@ class Engine:
assert isinstance(
self._loss, Variable
), "the type of `loss` of the Engine arguments should be Variable."
self._losses = to_list(self._loss)
self._losses = auto_utils.to_list(self._loss)
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation:
......@@ -593,6 +592,12 @@ class Engine:
# needs all ranks by default.
new_process_group(list(range(self._nranks)))
default_ctx.data_parallel = True
self._inputs = [
auto_utils.set_data_parallel(var) for var in self._inputs
]
self._labels = [
auto_utils.set_data_parallel(var) for var in self._labels
]
feed_vars = {"inputs": self._inputs, "labels": self._labels}
......@@ -684,7 +689,7 @@ class Engine:
self._dp_world_sizes = []
self._dp_ranks = []
for feed_var in feed_list:
dp_world_size, dp_rank = get_input_split_info(
dp_world_size, dp_rank = auto_utils.get_input_split_info(
self._cur_rank, feed_var, self._dist_contexts[mode]
)
self._dp_world_sizes.append(dp_world_size)
......@@ -749,7 +754,9 @@ class Engine:
cur_rank = self._cur_rank
# NOTE: Once a unified communication-group initialization mode for dynamic and static graphs is implemented, the full-mode initialization logic will be removed, because it may cause port-occupation errors.
if self._strategy.auto_mode == "full":
initialize_pg_in_full_mode(all_process_groups, cur_rank)
auto_utils.initialize_pg_in_full_mode(
all_process_groups, cur_rank
)
else:
for process_group in all_process_groups:
if cur_rank not in process_group.ranks:
......@@ -927,7 +934,7 @@ class Engine:
)
except core.EOFException:
break
lr = get_lr(self._optimizer)
lr = auto_utils.get_lr(self._optimizer)
logs = self._prepare_logger(
outs,
epoch,
......@@ -1474,7 +1481,7 @@ class Engine:
self._optimization_tuning(self._mode, tune_data, batch_size)
def _validate_spec(self, specs):
specs = to_list(specs)
specs = auto_utils.to_list(specs)
self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None:
for i, spec in enumerate(specs):
......@@ -1500,7 +1507,7 @@ class Engine:
return specs or []
def _validate_vars(self, vars):
vars = to_list(vars)
vars = auto_utils.to_list(vars)
if vars is not None:
for i, var in enumerate(vars):
if not isinstance(var, Variable):
......@@ -1547,7 +1554,7 @@ class Engine:
def _metrics_name(self):
metrics_name = ['loss'] if self._loss else []
for m in self._metrics:
metrics_name.extend(to_list(m.name()))
metrics_name.extend(auto_utils.to_list(m.name()))
return metrics_name
def _switch_mode(self, mode):
......@@ -1568,7 +1575,7 @@ class Engine:
def _set_state_dict(self, mode, strict, state_dict, dist_attr):
program = self._dist_main_progs[mode][self._cur_rank]
dist_context = self._dist_contexts[mode]
cur_dist_attr = get_dist_attr(program, dist_context)
cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict)
program.set_state_dict(state_dict)
......
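The Engine change above annotates inputs and labels with `auto_utils.set_data_parallel` when the user provided no annotation. Its shard spec shards only the leading axis along a 'dp' mesh dimension, and only when more than one rank participates. A minimal sketch of building such a spec, assuming the ProcessMesh carries a single 'dp' dimension (the helper name is illustrative):

def build_dp_shard_spec(shape, world_size):
    # Shard the leading (batch) axis along 'dp' when there is more than
    # one rank; otherwise keep every axis replicated (None).
    leading = 'dp' if world_size > 1 else None
    return [leading] + [None] * (len(shape) - 1)

assert build_dp_shard_spec([64, 128, 768], world_size=8) == ['dp', None, None]
assert build_dp_shard_spec([64, 10], world_size=1) == [None, None]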
......@@ -15,6 +15,7 @@
from .completion import Completer
from .dist_context import get_default_distributed_context
from .tuner.parallel_tuner import ParallelTuner
from .utils import is_naive_data_parallel
class Planner:
......@@ -26,7 +27,8 @@ class Planner:
# dependency of backward-forward ops in forward completion.
default_ctx = get_default_distributed_context()
self._dist_context._dist_op_context = default_ctx.dist_op_context
if not default_ctx.data_parallel:
self._dist_context.data_parallel = default_ctx.data_parallel
if not is_naive_data_parallel(self._dist_context):
# Use SSA graph for complex parallelism
self._dist_context.initialize(with_graph=True)
else:
......
......@@ -37,6 +37,8 @@ __not_shape_var_type__ = [
core.VarDesc.VarType.STEP_SCOPES,
]
__not_naive_data_parallel_op__ = ["expand_v2"]
def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name)
......@@ -1909,6 +1911,35 @@ def validate_opt(optimizer):
return optimizer
def set_data_parallel(x):
from .process_group import get_world_process_group
from .interface import shard_tensor, ProcessMesh
world_ranks = get_world_process_group().ranks
process_mesh = ProcessMesh(world_ranks, ['dp'])
shard_spec = ['dp' if len(world_ranks) > 1 else None] + [
None for _ in range(len(x.shape) - 1)
]
return shard_tensor(x, process_mesh, shard_spec)
def is_naive_data_parallel(dist_context):
# Naive data parallel completes the dist_attr only once, from front to back.
if not dist_context.data_parallel:
return False
ops_type = [
op.type
for op in dist_context._original_serial_main_program.global_block().ops
]
if (
not set(ops_type) & set(__not_naive_data_parallel_op__)
) and dist_context.data_parallel:
return True
return False
def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
py_process_mesh = py_dist_attr.process_mesh
if py_process_mesh is not None:
......
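Completion now also logs through a named logger created by `get_logger` (see the completion.py hunk above, where `self._logger = get_logger(logging.INFO, "Completer")` is added). A minimal sketch of such a helper, assuming it only needs to attach a stream handler once per name; the real `get_logger` formatting details are not shown in this diff, and the helper name below is an illustrative stand-in:

import logging

def make_named_logger(log_level, name="auto_parallel"):
    # Reuse the named logger and attach a single stream handler so
    # repeated calls do not duplicate output.
    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
        )
        logger.addHandler(handler)
    return logger

completer_logger = make_named_logger(logging.INFO, "Completer")
completer_logger.info("Default data parallel will be set.")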