未验证 提交 010aba33 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Add miscellaneous improvements (#43108)

* [Auto Parallel] Add the parallel tuner

* [Auto Parallel] Improve the parallel tuner and fix some bugs

* update cost model

* update import Resharder by dist op

* update cost model

* fix comp cost bug

* update cost model

* [Auto Parallel] Amend the dist attr for #processes=1

* update cost model and tuner

* update cost model and tuner

* update cost model and tuner

* update cluster

* update reshard

* [Auto Parallel] Add the estimation from the cost model

* [Auto Parallel] Reimplement the backup and restore functions

* [Auto Parallel] Fix the bugs of the parallel tuner

* [Auto Parallel] Update the engine api and dist context

* [Auto Parallel] Work around the high order grad problem

* [Auto Parallel] Add some miscellaneous improvements

* [Auto Parallel] Add a unittest for DistributedContext
Co-authored-by: Ncaozhou <caozhou@radi.ac.cn>
上级 5f2c251c
......@@ -20,7 +20,7 @@ from paddle.fluid import core
from paddle.fluid import framework
from .utils import print_program_with_dist_attr
from .operators import find_best_compatible_distributed_operator_impl
from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
......@@ -238,13 +238,17 @@ class Completer:
tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping])
if not _validate_dims_mapping(
compatible_dims_mapping,
op_dist_attr.process_mesh):
continue
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_input_dims_mapping(
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl(
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True)
if op_dist_impls is not None:
not_compatible = True
......@@ -254,7 +258,8 @@ class Completer:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.is_auto_compatible(dist_op) \
and dist_op.validate_dist_attr():
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
......@@ -289,13 +294,17 @@ class Completer:
tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping])
if not _validate_dims_mapping(
compatible_dims_mapping,
op_dist_attr.process_mesh):
continue
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_output_dims_mapping(
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl(
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=False)
if op_dist_impls is not None:
not_compatible = True
......@@ -305,8 +314,8 @@ class Completer:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
not_compatible = False
if op_dist_impl.is_auto_compatible(dist_op) \
and dist_op.validate_dist_attr():
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
......@@ -352,6 +361,23 @@ class Completer:
changed = True
return changed
def _update_dims_mapping_for_special(self):
    """Propagate each op's output dims_mapping onto the tensors it produces.

    Walks the ordered op nodes and, for every produced (non-READER)
    tensor whose process_mesh matches the op's, overwrites the tensor's
    dims_mapping with the op's output dims_mapping for that tensor.
    """
    # Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it
    op_nodes = self._dist_context._serial_ordered_op_nodes
    for op_node in op_nodes:
        op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
        for tensor_node in op_node.outputs:
            if tensor_node.is_var() and tensor_node.var() is not None:
                # READER variables carry no data layout; skip them.
                if tensor_node.var().type() == core.VarDesc.VarType.READER:
                    continue
                tensor_desc = tensor_node.var()
                tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                    tensor_node)
                # Only overwrite when op and tensor live on the same mesh;
                # otherwise the tensor keeps its own mapping.
                if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
                    op_dims_mapping = op_dist_attr.get_output_dims_mapping(
                        tensor_desc.name())
                    tensor_dist_attr.dims_mapping = op_dims_mapping
def _update_dims_mapping(self):
# Complete dims_mapping for each node
reach_fix_point = False
......@@ -378,6 +404,7 @@ class Completer:
reach_fix_point = False
else:
reach_fix_point = True
self._update_dims_mapping_for_special()
def _update_process_mesh_by_nearest(self, op_node, nearest_op_node):
op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
......@@ -685,7 +712,7 @@ class Completer:
# Step 3: adjust the process meshes for special ops
self._update_process_mesh_for_specials()
# Step 4: adjust the process meshes between graphs
# Step 4: adjust the process meshes between graphs
self._update_process_mesh_between_graphs()
def _prepare(self):
......@@ -727,14 +754,14 @@ class Completer:
""" Complete annotation for the partial annotated serial_main_program.
Arguments:
serial_main_program: partial annotated serial_main_program.
Returns:
Returns:
serial_main_program: completed annotated serial_main_program.
"""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
self._dist_context._serial_main_program = serial_main_program
self._dist_context.initialize()
......@@ -757,13 +784,18 @@ class Completer:
return serial_main_program
def _complete_high_order_grad_annotation(self, serial_main_program):
def _complete_high_order_grad_annotation(self, serial_main_program=None):
"""
NOTE:
[HighOrderGrad] Complete the annotation of vars and ops only for high order gradient.
This function is temporary to support high order gradient, and will be removed in the future.
"""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context._serial_main_program = serial_main_program
def _is_grad_var_name(name):
if "@GRAD" in name:
return True
......@@ -917,12 +949,13 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
def complete_backward_annotation(self, serial_main_program):
def complete_backward_annotation(self, serial_main_program=None):
"""Complete the annotation of vars and ops in the backward phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
self._dist_context._serial_main_program = serial_main_program
def _is_grad_var_name(name):
if "@GRAD" in name:
......@@ -1032,6 +1065,9 @@ class Completer:
grad_op_dist_attr.process_mesh = ref_mesh
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
continue
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
......@@ -1078,6 +1114,8 @@ class Completer:
grad_op_dist_attr.set_output_dims_mapping(output_name,
ref_dims_mapping)
grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
......@@ -1111,6 +1149,8 @@ class Completer:
var_name, ref_fwd_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping)
grad_op_dist_attr.impl_type = "default"
grad_op_dist_attr.impl_idx = 0
elif grad_op.type == 'fill_zeros_like':
ref_var_name = grad_op.input_arg_names[0]
......@@ -1142,12 +1182,13 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
def complete_update_annotation(self, serial_main_program=None):
def complete_update_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
# Notice: serial_main_program is actually a dist_main_program of current rank,
# and must be passed into this function.
# TODO: We should fix this behavior.
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
learning_rate_completed = False
......@@ -1304,7 +1345,7 @@ class Completer:
dist_op.dist_attr.process_mesh = world_ranks
# Find the most compatible implementations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl(
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True)
if op_dist_impls is not None:
backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
......
......@@ -132,15 +132,17 @@ class TensorDistributedAttribute:
key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
# def reset(self, skip_dist_attr_field_names):
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self._process_mesh = None
# if skip_dist_attr_field_names is not None \
# and "dims_mapping" not in skip_dist_attr_field_names:
# for i in enumerate(self._dims_mapping):
# self._dims_mapping[i] = -1
# self._is_annotated = {}
def reset(self, skip_dist_attr_field_names=None):
if skip_dist_attr_field_names is None or \
(skip_dist_attr_field_names is not None \
and "process_mesh" not in skip_dist_attr_field_names):
self._process_mesh = None
if skip_dist_attr_field_names is None or \
(skip_dist_attr_field_names is not None \
and "dims_mapping" not in skip_dist_attr_field_names):
for i, _ in enumerate(self._dims_mapping):
self._dims_mapping[i] = -1
self._is_annotated = {}
def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False)
......@@ -272,6 +274,9 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object
# def del_input_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None)
......@@ -280,6 +285,9 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object
# def del_output_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
......@@ -374,17 +382,18 @@ class OperatorDistributedAttribute:
"ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh
# def reset(self, skip_dist_attr_field_names):
# for tensor_dist_attr in self.inputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# for tensor_dist_attr in self.outputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self.process_mesh = None
# self.impl_type = "default"
# self.impl_idx = 0
# self._is_annotated = {}
def reset(self, skip_dist_attr_field_names=None):
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.reset(skip_dist_attr_field_names)
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.reset(skip_dist_attr_field_names)
if skip_dist_attr_field_names is None or \
(skip_dist_attr_field_names is not None \
and "process_mesh" not in skip_dist_attr_field_names):
self._process_mesh = None
self.impl_type = "default"
self.impl_idx = 0
self._is_annotated = {}
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
......
......@@ -57,33 +57,30 @@ class DistributedContext:
serial_startup_prog=None,
serial_optimizer=None,
serial_loss=None,
feed_vars=None,
fetch_vars=None,
feed_vars={},
fetch_vars={},
cluster=None,
strategy=None):
# Data members related to original programs (unchanged)
self._original_serial_main_program = serial_main_prog
self._original_serial_startup_program = serial_startup_prog
self._original_serial_optimizer = serial_optimizer
self._original_serial_loss = serial_loss
self._original_serial_feed_vars = feed_vars
self._original_serial_fetch_vars = fetch_vars
self._original_serial_optimizer = serial_optimizer
if self._original_serial_main_program is None:
self._original_serial_main_program = paddle.fluid.default_main_program(
)
if self._original_serial_startup_program is None:
self._original_serial_startup_program = paddle.fluid.default_startup_program(
)
# Data members related to programs (changed)
self._serial_main_program = None
self._serial_startup_program = None
self._serial_loss = serial_loss
self._serial_optimizer = serial_optimizer
self._serial_feed_vars = feed_vars
self._serial_fetch_vars = fetch_vars
self._serial_loss = None
self._serial_optimizer = None
self._serial_feed_vars = {}
self._serial_fetch_vars = {}
# Data members related to the program
self._dist_tensors_for_program = {}
self._dist_ops_for_program = {}
self._block_state = BlockState()
# Data members related to the graph
self._serial_graph = None
......@@ -96,24 +93,30 @@ class DistributedContext:
# Distributed programs
self._dist_main_programs = {}
self._dist_startup_programs = {}
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
# Distributed Strategy
self._cluster = cluster
self._strategy = strategy
# Pass Context
self._pass_context = PassContext()
# Distributed Operator Context
self._dist_op_context = DistributedOperatorContext()
self._block_state = BlockState()
# Other data members
self._process_meshes = []
self._serial_ordered_tensor_nodes = []
self._serial_ordered_op_nodes = []
self._serial_ordered_nodes = []
# self._tensor_id_to_tensor_node_ids = {}
self._is_initialized = False
self._need_copy_dist_attr_to_graph = False
self._backup_pass_context_stack = []
self._backup_block_state_stack = []
self._backup_dist_tensors_for_program_stack = []
self._backup_dist_ops_for_program_stack = []
self._backup_serial_main_program_stack = []
self._backup_serial_startup_program_stack = []
# flag whether scale gradient with dp size
self._gradient_scale = True
......@@ -122,13 +125,6 @@ class DistributedContext:
def serial_main_program(self):
return self._serial_main_program
@serial_main_program.setter
def serial_main_program(self, program):
# if self._serial_main_program:
# print("WARNING: The program attached to this distributed context will be replaced by the new one.")
self._original_serial_main_program = program
self._serial_main_program = program
@property
def serial_startup_program(self):
return self._serial_startup_program
......@@ -149,6 +145,18 @@ class DistributedContext:
def serial_fetch_vars(self):
return self._serial_fetch_vars
@property
def dist_main_programs(self):
    """The distributed main programs (presumably keyed by rank — confirm with callers)."""
    return self._dist_main_programs

@property
def dist_startup_programs(self):
    """The distributed startup programs (presumably keyed by rank — confirm with callers)."""
    return self._dist_startup_programs

@property
def cluster(self):
    """The cluster object this context was constructed with (may be None)."""
    return self._cluster

@property
def strategy(self):
    """The distributed strategy this context was constructed with (may be None)."""
    return self._strategy
......@@ -177,14 +185,6 @@ class DistributedContext:
def block_state(self):
return self._block_state
@property
def dist_main_programs(self):
return self._dist_main_programs
@property
def dist_startup_programs(self):
return self._dist_startup_programs
@property
def has_annotation(self):
return len(self._dist_tensors_for_program) or len(
......@@ -198,21 +198,168 @@ class DistributedContext:
def gradient_scale(self, gs):
self._gradient_scale = gs
def initialize(self):
if not self._is_initialized:
def _backup_serial_info(self, mode):
    """Push clones/copies of the current serial state onto the backup stacks.

    Backs up the serial main/startup programs (via Program.clone), the
    pass context and the block state.

    Args:
        mode: currently unused; reserved for future backup modes.
    """
    self._backup_serial_main_program_stack.append(
        self._serial_main_program.clone())
    self._backup_serial_startup_program_stack.append(
        self._serial_startup_program.clone())
    # Pass context and block state are plain Python objects; deepcopy
    # is enough to decouple them from future mutation.
    self._backup_pass_context_stack.append(
        copy.deepcopy(self._pass_context))
    self._backup_block_state_stack.append(copy.deepcopy(self._block_state))
def _backup_dist_info(self, mode):
    """Push deep copies of the per-program dist tensors/ops onto the backup stacks.

    Args:
        mode: currently unused; reserved for future backup modes.
    """
    self._backup_dist_tensors_for_program_stack.append(
        copy.deepcopy(self._dist_tensors_for_program))
    self._backup_dist_ops_for_program_stack.append(
        copy.deepcopy(self._dist_ops_for_program))
def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None):
    """Back up the serial and/or distributed state onto the backup stacks.

    Args:
        serial (bool): back up serial programs/pass context/block state.
        serial_mode: forwarded to _backup_serial_info (currently unused).
        dist (bool): back up dist tensors/ops for the program.
        dist_mode: forwarded to _backup_dist_info (currently unused).
    """
    # Use this function carefully
    if serial:
        self._backup_serial_info(serial_mode)
    if dist:
        self._backup_dist_info(dist_mode)
def _restore_serial_info(self, mode="to_backup"):
    """Restore the serial programs, loss, optimizer and feed/fetch vars.

    Modes:
        "to_backup":   pop the most recently backed-up serial programs.
        "to_original": re-clone the original (untouched) serial programs.
    The loss/feed/fetch variables are re-resolved against the restored
    main program; the pass context and block state are always popped
    from their backup stacks.
    """
    if mode == "to_backup":
        self._serial_main_program = self._backup_serial_main_program_stack.pop(
        )
        self._serial_startup_program = self._backup_serial_startup_program_stack.pop(
        )
    elif mode == "to_original":
        assert self._original_serial_main_program is not None
        assert self._original_serial_startup_program is not None
        self._serial_main_program = self._original_serial_main_program.clone(
        )
        self._serial_startup_program = self._original_serial_startup_program.clone(
        )
    # self._serial_main_program = self._original_serial_main_program
    # self._serial_startup_program = self._original_serial_startup_program
    # NOTE(review): the two `if self._original_serial_loss:` blocks below
    # look like an old and a new version merged together by the diff
    # view; the later assignments win at runtime — confirm against the
    # upstream file before relying on the first block.
    if self._original_serial_loss:
        self._serial_loss = self._serial_main_program.global_block(
        ).vars[self._original_serial_loss[0].name]
    self._serial_optimizer = self._original_serial_optimizer
    if self._original_serial_loss:
        if isinstance(self._original_serial_loss, list):
            # A list-wrapped loss must contain exactly one variable.
            assert len(self._original_serial_loss) == 1
            loss = self._original_serial_loss[0]
            block_idx = loss.block.idx
            var_name = loss.name
            # Look the variable up in the restored program (recursively
            # through parent blocks) so it belongs to the new program.
            var = self._serial_main_program.blocks[
                block_idx]._var_recursive(var_name)
            self._serial_loss = var
        else:
            self._serial_loss = self._original_serial_loss
            self._serial_optimizer = self._original_serial_optimizer
            block_idx = self._original_serial_loss.block.idx
            var_name = self._original_serial_loss.name
            var = self._serial_main_program.blocks[
                block_idx]._var_recursive(var_name)
            self._serial_loss = var
    # Re-resolve every feed variable against the restored main program.
    for key, var_list in self._original_serial_feed_vars.items():
        new_var_list = []
        for var in var_list:
            block_idx = var.block.idx
            var_name = var.name
            var = self._serial_main_program.blocks[
                block_idx]._var_recursive(var_name)
            new_var_list.append(var)
        self._serial_feed_vars[key] = new_var_list
    # Re-resolve every fetch variable against the restored main program.
    for key, var_list in self._original_serial_fetch_vars.items():
        new_var_list = []
        for var in var_list:
            block_idx = var.block.idx
            var_name = var.name
            var = self._serial_main_program.blocks[
                block_idx]._var_recursive(var_name)
            new_var_list.append(var)
        self._serial_fetch_vars[key] = new_var_list
    self._pass_context = self._backup_pass_context_stack.pop()
    self._block_state = self._backup_block_state_stack.pop()
def _restore_dist_info(self, mode="to_backup"):
    """Restore the distributed attributes of the program.

    Modes:
        "to_backup":   pop the most recently backed-up dist tensors/ops.
        "to_original": deep-copy the dist attrs captured at initialize().
        "to_default":  reset known dist attrs in place and drop any
                       tensors/ops added after initialize().
        anything else: drop all dist tensors/ops.
    Derived products (dist programs, op context, process meshes) are
    always invalidated, and the program->graph attr copy is re-armed.
    """
    if mode == "to_backup":
        self._dist_tensors_for_program = self._backup_dist_tensors_for_program_stack.pop(
        )
        self._dist_ops_for_program = self._backup_dist_ops_for_program_stack.pop(
        )
    elif mode == "to_original":
        assert self._original_dist_tensors_for_program
        assert self._original_dist_ops_for_program
        self._dist_tensors_for_program = copy.deepcopy(
            self._original_dist_tensors_for_program)
        self._dist_ops_for_program = copy.deepcopy(
            self._original_dist_ops_for_program)
    elif mode == "to_default":
        # Keep (but reset) tensors known at initialize(); drop the rest.
        new_tensors_ids = []
        for tensor_id, dist_tensor in self._dist_tensors_for_program.items(
        ):
            if tensor_id in self._tensors_ids:
                dist_tensor.dist_attr.reset()
            else:
                new_tensors_ids.append(tensor_id)
        for tensor_id in new_tensors_ids:
            self._dist_tensors_for_program.pop(tensor_id)
        # Same policy for ops.
        new_ops_ids = []
        for op_id, dist_op in self._dist_ops_for_program.items():
            if op_id in self._ops_ids:
                dist_op.dist_attr.reset()
            else:
                new_ops_ids.append(op_id)
        for op_id in new_ops_ids:
            self._dist_ops_for_program.pop(op_id)
    else:
        # Unknown mode: remove every dist tensor and dist op.
        new_tensors_ids = []
        for tensor_id, dist_tensor in self._dist_tensors_for_program.items(
        ):
            new_tensors_ids.append(tensor_id)
        for tensor_id in new_tensors_ids:
            self._dist_tensors_for_program.pop(tensor_id)
        new_ops_ids = []
        for op_id, dist_op in self._dist_ops_for_program.items():
            new_ops_ids.append(op_id)
        for op_id in new_ops_ids:
            self._dist_ops_for_program.pop(op_id)
    # Invalidate everything derived from the restored dist attrs.
    self._dist_main_programs = {}
    self._dist_startup_programs = {}
    self._dist_op_context = DistributedOperatorContext()
    # Program-side attrs changed; the graph copy is now stale.
    self._need_copy_dist_attr_to_graph = True
    self._process_meshes = []
def _restore(self,
             serial=True,
             serial_mode="to_backup",
             dist=True,
             dist_mode="to_backup"):
    """Restore serial and/or distributed state.

    Args:
        serial (bool): restore serial programs/loss/feed/fetch vars.
        serial_mode (str): see _restore_serial_info.
        dist (bool): restore the dist tensors/ops.
        dist_mode (str): see _restore_dist_info.
    """
    # Use this function carefully
    if serial:
        self._restore_serial_info(serial_mode)
    if dist:
        self._restore_dist_info(dist_mode)
def initialize(self):
if not self._is_initialized:
if not self._serial_main_program:
self._serial_main_program = self._original_serial_main_program
if not self._serial_startup_program:
self._serial_startup_program = self._original_serial_startup_program
if not self._serial_loss:
if isinstance(self._original_serial_loss, list):
assert len(self._original_serial_loss) == 1
self._serial_loss = self._original_serial_loss[0]
else:
self._serial_loss = self._original_serial_loss
if not self._serial_optimizer:
self._serial_optimizer = self._original_serial_optimizer
if not self._serial_feed_vars:
self._serial_feed_vars = self._original_serial_feed_vars
if not self._serial_fetch_vars:
self._serial_fetch_vars = self._original_serial_fetch_vars
self._init_dist_attr_for_program()
# Backup the original distributed information for later restore
self._original_dist_tensors_for_program = copy.deepcopy(
self._dist_tensors_for_program)
self._original_dist_ops_for_program = copy.deepcopy(
self._dist_ops_for_program)
self._tensors_ids = list(self._dist_tensors_for_program.keys())
self._ops_ids = list(self._dist_ops_for_program.keys())
set_flags({"FLAGS_convert_all_blocks": True})
......@@ -220,41 +367,9 @@ class DistributedContext:
core.Graph(self._serial_main_program.desc))
self._init_dist_attr_for_graph()
self._is_initialized = True
# def reset(self,
# skip_dist_tensors=None,
# skip_dist_ops=None,
# skip_tensor_dist_attr_fields=None,
# skip_op_dist_attr_fields=None):
# self._serial_main_program = self._original_serial_main_program.clone()
# self._serial_startup_program = self._original_serial_startup_program.clone()
# new_tensors_ids = []
# for tensor_id, dist_tensor in self._dist_tensors_for_program.items():
# if tensor_id in self._tensors_ids:
# dist_tensor.dist_attr.reset(skip_tensor_dist_attr_fields)
# else:
# new_tensors_ids.append(tensor_id)
# for tensor_id in new_tensors_ids:
# self._dist_tensors_for_program.pop(tensor_id)
# new_ops_ids = []
# for op_id, dist_op in self._dist_ops_for_program.items():
# if op_id in self._ops_ids:
# dist_op.dist_attr.reset(skip_op_dist_attr_fields)
# else:
# new_ops_ids.append(op_id)
# for op_id in new_ops_ids:
# self._dist_ops_for_program.pop(op_id)
# self.copy_dist_attr_from_program_to_graph()
# self._dist_main_programs = {}
# self._dist_startup_programs = {}
# self._pass_context = PassContext()
# self._dist_op_context = DistributedOperatorContext()
# self._process_meshes = []
self._need_copy_dist_attr_to_graph = False
if self._need_copy_dist_attr_to_graph:
self.copy_dist_attr_from_program_to_graph()
def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
......@@ -423,6 +538,10 @@ class DistributedContext:
if current_dist_op is None:
dist_op = DistributedOperator(op)
self.add_dist_op_for_program(dist_op)
self._original_dist_tensors_for_program = copy.deepcopy(
self._dist_tensors_for_program)
self._original_dist_ops_for_program = copy.deepcopy(
self._dist_ops_for_program)
def _order_nodes_by_program_order(self):
def _contains(nodes, target_node):
......@@ -592,7 +711,7 @@ class DistributedContext:
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
# TODO: the completion algorithm will skip orphan tensors,
# TODO: the completion algorithm will skip orphan tensors,
# here we just set their process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes:
serial_tensor_id = orphan_node.var().id()
......@@ -618,16 +737,21 @@ class DistributedContext:
tensor_shape = serial_tensor.shape
dims_mapping = dist_attr.dims_mapping
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_processes = dist_attr.process_mesh.processes
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
if dims_mapping[i] != -1 and len(process_mesh_processes) == 1:
dims_mapping[i] = -1
for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_processes = dist_attr.process_mesh.processes
for arg_name in serial_op.input_arg_names:
if dist_op.get_serial_input(arg_name) is None:
tensor_shape = []
......@@ -639,13 +763,15 @@ class DistributedContext:
else:
tensor_shape = dist_op.get_serial_input(arg_name).shape
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
if dims_mapping[i] != -1 and len(
process_mesh_processes) == 1:
dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
......@@ -654,13 +780,18 @@ class DistributedContext:
else:
tensor_shape = dist_op.get_serial_output(arg_name).shape
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
if dims_mapping[i] != -1 and len(
process_mesh_processes) == 1:
dims_mapping[i] = -1
if len(process_mesh_processes) == 1:
dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0
def validate_dist_attr_for_program(self):
if not self._is_initialized:
......@@ -674,16 +805,20 @@ class DistributedContext:
dist_tensor.serial_tensor.name)
if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.dist_attr)
assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name,
dist_tensor.desc.id(),
dist_tensor.desc.original_id(), dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
assert dist_op is not None, \
"Operator {} does not have a distributed attribute.".format(
dist_op.serial_op.type)
if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format(
dist_op.serial_op.type, dist_op.dist_attr)
assert False, "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format(
dist_op.serial_op.type,
dist_op.serial_op.desc.id(),
dist_op.serial_op.desc.original_id(), dist_op.dist_attr)
return True
def __deepcopy__(self, memo):
......
......@@ -41,7 +41,7 @@ class DistributedTensor:
rank=None,
shard_sizes=None):
if not (isinstance(sizes, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x > 0, sizes))):
all(map(lambda x: isinstance(x, int) and x >= 0, sizes))):
raise ValueError(
"The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}".
format(sizes))
......@@ -79,8 +79,11 @@ class DistributedTensor:
local_sizes = []
# for even sharding, the local sizes of every rank are equal
for idx, item in enumerate(global_sizes):
if dims_mapping[idx] == -1:
# This is a trick to avoid dims_mapping is []
val = dims_mapping[idx] if idx < len(dims_mapping) else -1
if val == -1:
local_sizes.append(item)
else:
local_sizes.append(item // topology[dims_mapping[idx]])
......
......@@ -31,10 +31,11 @@ from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext
from .cluster import Cluster
# from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator
......@@ -57,7 +58,11 @@ class Engine:
self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec)
self.cluster = cluster
# if self.cluster is None:
# self.cluster = get_default_cluster()
self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
self._executor = None
self._cur_rank = paddle.distributed.get_rank()
......@@ -69,11 +74,11 @@ class Engine:
self._orig_main_prog = fluid.default_main_program()
self._orig_startup_prog = fluid.default_startup_program()
self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {}
self._serial_main_progs = {}
self._serial_startup_progs = {}
self._dist_main_progs = defaultdict(dict) # dist main programs
self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._dist_contexts = {}
self._feed_vars = {}
self._fetch_vars = {}
......@@ -104,11 +109,17 @@ class Engine:
parallelizer.parallel(self._cur_rank)
else:
parallelizer.parallel_all()
# Get the distributed main programs and startup programs
# Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[
mode].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[
mode].serial_startup_program
self._dist_main_progs[mode] = self._dist_contexts[
mode].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[
mode].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
# Init comm and startup program
self._initialize(mode)
......@@ -135,20 +146,23 @@ class Engine:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
# self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": inputs, "labels": labels}
self._fetch_vars[mode] = {
# self._fetch_vars[mode] = {
# "outputs": flatten(outputs),
# "loss": losses,
# "metrics": metrics
# }
fetch_vars = {
"outputs": flatten(outputs),
"loss": losses,
"metrics": metrics
}
self._serial_main_progs[mode] = serial_main_prog
self._serial_startup_progs[mode] = serial_startup_prog
self._dist_contexts[mode] = DistributedContext(
self._serial_main_progs[mode], self._serial_startup_progs[mode],
self._optimizer, losses, self._feed_vars[mode],
self._fetch_vars[mode], self.strategy)
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _initialize(self, mode):
......
......@@ -16,7 +16,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl
from .common import find_compatible_distributed_operator_impls
from . import dist_embedding
from . import dist_matmul
from . import dist_reshape
......
......@@ -157,9 +157,7 @@ def register_distributed_operator_impl(op_type, dist_impl):
assert False, "Must register distributed operator registry first."
def find_best_compatible_distributed_operator_impl(dist_op,
fwd=True,
partial=True):
def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
"""
Here just return the first compatible implemention.
This will be improved by cost model in the future.
......
......@@ -187,7 +187,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if serial_tensor.is_parameter:
if serial_tensor is not None and serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
......@@ -217,7 +217,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if serial_tensor.is_parameter:
if serial_tensor is not None and serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
......
......@@ -22,7 +22,6 @@ from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program
from .dist_default import DistributedDefaultImpl0
from ..reshard import Resharder
from ..process_group import new_process_group
from ..utils import is_dim_shard, is_dim_replicate, _get_corresponding_rank
from ..utils import compute_compatible_dim_mapping, set_dist_op_desc_original_id, _get_comm_group
......@@ -324,6 +323,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)]
from ..reshard import Resharder
partition_idx = Resharder.compute_partition_index(
rank_id, new_X_grad.shape, dims_mapping, process_mesh_shape,
process_mesh_group)
......
......@@ -35,7 +35,7 @@ class Parallelizer:
self._mode = mode
self._completer = completer
self._dist_context = dist_context
self._dist_context.initialize()
assert self._dist_context._is_initialized
self._pass_context = self._dist_context.pass_context
self._strategy = self._dist_context.strategy
......@@ -43,7 +43,9 @@ class Parallelizer:
world_process_group = get_world_process_group()
all_ranks = world_process_group.ranks
for rank in all_ranks:
# self._dist_context._backup(serial=True, dist=True)
self.parallel(rank)
# self._dist_context._restore(serial=True, dist=True)
def parallel(self, rank):
serial_main_program = self._dist_context.serial_main_program
......@@ -58,6 +60,7 @@ class Parallelizer:
self._apply_pre_optimization(serial_main_program,
serial_startup_program, serial_loss,
serial_optimizer, params_grads)
# Do logical partition
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
......@@ -85,7 +88,6 @@ class Parallelizer:
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1)
resharder.reshard()
# Clone program for test
if self._mode != 'train':
dist_main_prog = dist_main_prog.clone(for_test=True)
......
......@@ -16,6 +16,8 @@ from .completion import Completer
from .dist_context import get_default_distributed_context
from .utils import print_program_with_dist_attr
# from .tuner.parallel_tuner import ParallelTuner
class Planner:
def __init__(self, mode, dist_context):
......@@ -24,19 +26,28 @@ class Planner:
# NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need
# dependency of backward-forward ops in forward completion.
# TODO: The id mapping will be lost if we clone the original program.
default_ctx = get_default_distributed_context()
self._dist_context._dist_op_context = default_ctx.dist_op_context
self._dist_context.initialize()
self._completer = Completer(self._dist_context)
self._strategy = dist_context.strategy
# if self._strategy.auto_search:
# self._parallel_tuner = ParallelTuner(
# self._dist_context, mode=self._mode)
@property
def completer(self):
return self._completer
def plan(self):
self._completer.complete_forward_annotation()
# if self._strategy.auto_search:
# self._parallel_tuner.tune()
# else:
# self._completer.complete_forward_annotation()
# parse forward sub block
self._dist_context.block_state.parse_forward_blocks(
self._dist_context.serial_main_program)
# TODO: add the auto searcher
......@@ -324,10 +324,13 @@ def _get_corresponding_rank(dist_context, target_mesh, rank):
mesh.processes.index(rank))
break
assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
rank)
return target_mesh.processes[_coordinate2linear_idx(mesh.topology,
coordinate)]
# assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
# rank)
if coordinate is not None:
return target_mesh.processes[_coordinate2linear_idx(mesh.topology,
coordinate)]
else:
return target_mesh.processes[0]
def _get_unshard_dist_shape(var, dist_attr):
......
......@@ -31,4 +31,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS})
py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import json
import paddle
import numpy as np
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
# The whole test builds and inspects static-graph programs.
paddle.enable_static()
# Model/dataset dimensions used throughout the test.
batch_size = 4
hidden_size = 1024
sequence_len = 512
# Two process meshes of two ranks each; linear0's weight is placed on the
# first mesh and linear1's weight on the second (see MLPLayer.forward).
_g_process_mesh = [[0, 1], [2, 3]]
def get_random_inputs_and_labels(input_shape, label_shape):
    """Return a random float32 (input, label) pair with the given shapes.

    Values are drawn uniformly from [0, 1) via ``np.random.random``.
    """
    rand = np.random.random
    return (rand(size=input_shape).astype('float32'),
            rand(size=label_shape).astype('float32'))
def batch_generator_creator():
    """Return a reader callable yielding ``batch_size`` random batches.

    Each yielded item is an (input, label) tuple shaped for the static
    dataloader defined in ``get_program``.
    """

    def __reader__():
        input_shape = [batch_size, sequence_len, hidden_size]
        label_shape = [batch_size, sequence_len, 1]
        for _ in range(batch_size):
            # get_random_inputs_and_labels already returns the
            # (input, label) tuple the dataloader expects.
            yield get_random_inputs_and_labels(input_shape, label_shape)

    return __reader__
class MLPLayer(nn.Layer):
    """LayerNorm -> Linear -> GeLU -> Linear block with sharded weights.

    The first linear weight is annotated on ``_g_process_mesh[0]`` (column
    parallel, dims_mapping [-1, 0]) and the second on ``_g_process_mesh[1]``
    (row parallel, dims_mapping [0, -1]).  ``dropout_ratio`` is accepted for
    signature compatibility but not used by this block.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        # One shared Normal initializer feeds both weight attributes,
        # mirroring the original construction order.
        weight_init = nn.initializer.Normal(mean=0.0, std=initializer_range)
        self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
        self.linear0 = nn.Linear(
            hidden_size,
            intermediate_size,
            weight_attr=paddle.ParamAttr(initializer=weight_init),
            bias_attr=None)
        self.linear1 = nn.Linear(
            intermediate_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=weight_init),
            bias_attr=None)

    def forward(self, input):
        # Building the attribute dicts has no side effects, so hoisting them
        # ahead of the ops leaves the traced program unchanged.
        w0_dist_attr = {
            "process_mesh": _g_process_mesh[0],
            "dims_mapping": [-1, 0]
        }
        w1_dist_attr = {
            "process_mesh": _g_process_mesh[1],
            "dims_mapping": [0, -1]
        }
        out = self.norm(input)
        auto.shard_tensor(self.linear0.weight, dist_attr=w0_dist_attr)
        out = F.gelu(self.linear0(out), approximate=True)
        auto.shard_tensor(self.linear1.weight, dist_attr=w1_dist_attr)
        return self.linear1(out)
def get_program():
    """Build a static-graph 3-layer MLP training job annotated for auto parallel.

    Returns:
        (train_program, start_program, dataloader, loss, optimizer,
        feed_vars, fetch_vars) — the exact argument set that
        ``DistributedContext`` consumes in the test below.
    """
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    # fleet.init(is_collective=True, strategy=dist_strategy)
    train_program = static.Program()
    start_program = static.Program()
    with static.program_guard(train_program, start_program):
        # input
        input = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32')
        label = static.data(
            name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
        data_holder = [input, label]
        # dataloader
        dataloader = paddle.io.DataLoader.from_generator(
            feed_list=data_holder, capacity=4 * batch_size, iterable=False)
        dataloader.set_batch_generator(
            batch_generator_creator(), places=paddle.static.cuda_places())
        # data dist_attr: both data tensors are sharded along the batch
        # dimension (dims_mapping [0, -1, -1]) on the first process mesh.
        auto.shard_tensor(
            input,
            dist_attr={
                "process_mesh": _g_process_mesh[0],
                "dims_mapping": [0, -1, -1]
            })
        auto.shard_tensor(
            label,
            dist_attr={
                "process_mesh": _g_process_mesh[0],
                "dims_mapping": [0, -1, -1]
            })
        # Three stacked MLP blocks; each block annotates its own weights
        # inside MLPLayer.forward.
        mlp_start = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02)
        pred = mlp_start(input)
        mlp_mid = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02)
        pred = mlp_mid(pred)
        mlp_end = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02)
        pred = mlp_end(pred)
        # Mean squared error loss; the optimizer is created but not applied
        # here — DistributedContext takes it as a serial optimizer.
        error_cost = paddle.nn.functional.square_error_cost(pred, label)
        loss = paddle.mean(error_cost)
        optimizer = paddle.optimizer.Adam(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None)
        feed_vars = {"inputs": [input], "labels": [label]}
        fetch_vars = {"loss": [loss]}
    return train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars
class TestDistributedContext(unittest.TestCase):
    """Exercise DistributedContext._backup/_restore in every supported mode."""

    def test_backup_restore(self):
        train_program, start_program, dataloader, loss, optimizer, \
            feed_vars, fetch_vars = get_program()
        dist_context = DistributedContext(train_program, start_program,
                                          optimizer, loss, feed_vars,
                                          fetch_vars)
        dist_context.initialize()
        # Each round backs everything up and then restores with a different
        # mode combination, in the same order as the original test.
        restore_modes = [
            dict(
                serial=True,
                serial_mode="to_backup",
                dist=True,
                dist_mode="to_backup"),
            dict(
                serial=True,
                serial_mode="to_original",
                dist=True,
                dist_mode="to_original"),
            dict(serial=True, dist=True, dist_mode="to_default"),
            dict(serial=True, dist=True, dist_mode="to_nothing"),
        ]
        for kwargs in restore_modes:
            dist_context._backup(serial=True, dist=True)
            dist_context._restore(**kwargs)
# Run the test suite when executed directly (e.g. via py_test_modules).
if __name__ == "__main__":
    unittest.main()
......@@ -94,7 +94,8 @@ class TestDistSlice(unittest.TestCase):
ops = dist_main_prog.global_block().ops
for op in ops:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "slice"
# We amend this impl_type after completion
assert op_dist_attr.impl_type == "default"
for out in op.output_arg_names:
var_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
ref_dims_mapping = [-1 for i in range(len(var_dims_mapping))]
......
......@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
from paddle.distributed.auto_parallel.operators import find_compatible_distributed_operator_impls
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......
......@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
from paddle.distributed.auto_parallel.operators import find_compatible_distributed_operator_impls
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册