Unverified commit bed9aaea, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Improve the codes of the completion and distributed context (#40671)

* [Auto Parallel] Replace the old planner by the new partition tuner

* [Auto Parallel] Improve the completion and distributed context

* [Auto Parallel] Fix some bugs of the compatible check of some dist ops

* [Auto Parallel] Fix some bugs
Parent afcf6bd0
...@@ -123,6 +123,19 @@ def merge_process_mesh_two(pm1, pm2):
    return merged_process_mesh
def _validate_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
process_mesh.topology):
return False
for i in range(len(process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
return True
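For reference, a minimal standalone sketch of what this validator accepts and rejects; the FakeMesh namedtuple below is a hypothetical stand-in that only mimics the topology attribute of the real ProcessMesh:

from collections import namedtuple

# Hypothetical stand-in for ProcessMesh: only the topology list is needed here.
FakeMesh = namedtuple("FakeMesh", ["topology"])

def validate_dims_mapping(dims_mapping, process_mesh):
    # Same rules as the helper above: each entry is -1 (replicated) or a valid
    # mesh-axis index, and no mesh axis may shard more than one tensor dim.
    if dims_mapping is None:
        return False
    for i in range(len(dims_mapping)):
        if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.topology):
            return False
    for i in range(len(process_mesh.topology)):
        if dims_mapping.count(i) > 1:
            return False
    return True

mesh = FakeMesh(topology=[2, 4])                  # a hypothetical 2x4 mesh
assert validate_dims_mapping([0, -1], mesh)       # dim 0 sharded on mesh axis 0
assert not validate_dims_mapping([0, 0], mesh)    # mesh axis 0 used twice
assert not validate_dims_mapping([2, -1], mesh)   # mesh-axis index out of range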
class Completer:
    def __init__(self, dist_context):
        assert dist_context is not None
...@@ -161,6 +174,9 @@ class Completer:
                dims_mapping_list.append(tensor_dims_mapping)
            compatible_dims_mapping = compute_compatible_dims_mapping(
                dims_mapping_list)
if not _validate_dims_mapping(compatible_dims_mapping,
tensor_dist_attr.process_mesh):
return False
            if (compatible_dims_mapping is not None) and \
                (compatible_dims_mapping != tensor_dims_mapping):
                tensor_dist_attr.dims_mapping = compatible_dims_mapping
...@@ -182,6 +198,9 @@ class Completer:
                dims_mapping_list.append(tensor_dims_mapping)
            compatible_dims_mapping = compute_compatible_dims_mapping(
                dims_mapping_list)
if not _validate_dims_mapping(compatible_dims_mapping,
tensor_dist_attr.process_mesh):
return False
            if (compatible_dims_mapping is not None) and \
                (compatible_dims_mapping != tensor_dims_mapping):
                tensor_dist_attr.dims_mapping = compatible_dims_mapping
...@@ -196,10 +215,12 @@ class Completer:
        op_desc = op_node.op()
        if op_desc.type() == "create_py_reader" \
            or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "while" \
or op_desc.type() == "read": or op_desc.type() == "read":
return False return False
dist_op = self._dist_context.get_dist_op_for_graph(op_node) dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
original_op_dist_attr = copy.deepcopy(op_dist_attr)
        if fwd:
            for tensor_node in op_node.inputs:
                if tensor_node.is_var() and tensor_node.var() is not None:
...@@ -223,18 +244,34 @@ class Completer:
                            tensor_desc.name(), compatible_dims_mapping)
                        changed = True
            # Find the most compatible implemenetations from the distributed operator
-           op_dist_impl = find_best_compatible_distributed_operator_impl(
-               dist_op, fwd=True)
-           if op_dist_impl is not None:
-               dim_changed = op_dist_impl.update_dims_mapping(dist_op)
-               if dim_changed:
-                   changed = True
-               if op_dist_impl.is_auto_compatible(dist_op):
-                   if op_dist_impl.type == "elementwise":
-                       op_dist_attr.impl_type = "default"
-                   else:
-                       op_dist_attr.impl_type = op_dist_impl.type
-                   op_dist_attr.impl_idx = op_dist_impl.idx
+           op_dist_impls = find_best_compatible_distributed_operator_impl(
+               dist_op, fwd=True)
+           if op_dist_impls is not None:
+               not_compatible = True
+               backup_op_dist_attr = copy.deepcopy(op_dist_attr)
+               backup_changed = changed
+               for op_dist_impl in op_dist_impls:
+                   dim_changed = op_dist_impl.update_dims_mapping(dist_op)
+                   if dim_changed:
+                       changed = True
+                   if op_dist_impl.is_auto_compatible(dist_op):
+                       if op_dist_impl.type == "elementwise":
+                           op_dist_attr.impl_type = "default"
+                       else:
+                           op_dist_attr.impl_type = op_dist_impl.type
+                       # op_dist_attr.impl_type = op_dist_impl.type
+                       op_dist_attr.impl_idx = op_dist_impl.idx
+                       not_compatible = False
+                       break
+                   else:
+                       dist_op.dist_attr = backup_op_dist_attr
+                       changed = backup_changed
+               if not_compatible:
+                   dist_op.dist_attr = original_op_dist_attr
+                   changed = False
+           else:
+               dist_op.dist_attr = original_op_dist_attr
+               changed = False
        else:
            for tensor_node in op_node.outputs:
                if tensor_node.is_var() and tensor_node.var() is not None:
...@@ -258,18 +295,35 @@ class Completer:
                            tensor_desc.name(), compatible_dims_mapping)
                        changed = True
            # Find the most compatible implemenetations from the distributed operator
-           op_dist_impl = find_best_compatible_distributed_operator_impl(
-               dist_op, fwd=False)
-           if op_dist_impl is not None:
-               dim_changed = op_dist_impl.update_dims_mapping(dist_op)
-               if dim_changed:
-                   changed = True
-               if op_dist_impl.is_auto_compatible(dist_op):
-                   if op_dist_impl.type == "elementwise":
-                       op_dist_attr.impl_type = "default"
-                   else:
-                       op_dist_attr.impl_type = op_dist_impl.type
-                   op_dist_attr.impl_idx = op_dist_impl.idx
+           op_dist_impls = find_best_compatible_distributed_operator_impl(
+               dist_op, fwd=False)
+           if op_dist_impls is not None:
+               not_compatible = True
+               backup_op_dist_attr = copy.deepcopy(op_dist_attr)
+               backup_changed = changed
+               for op_dist_impl in op_dist_impls:
+                   dim_changed = op_dist_impl.update_dims_mapping(dist_op)
+                   if dim_changed:
+                       changed = True
+                   if op_dist_impl.is_auto_compatible(dist_op):
+                       not_compatible = False
+                       if op_dist_impl.type == "elementwise":
+                           op_dist_attr.impl_type = "default"
+                       else:
+                           op_dist_attr.impl_type = op_dist_impl.type
+                       # op_dist_attr.impl_type = op_dist_impl.type
+                       op_dist_attr.impl_idx = op_dist_impl.idx
+                       not_compatible = False
+                       break
+                   else:
+                       dist_op.dist_attr = backup_op_dist_attr
+                       changed = backup_changed
+               if not_compatible:
+                   dist_op.dist_attr = original_op_dist_attr
+                   changed = False
+           else:
+               dist_op.dist_attr = original_op_dist_attr
+               changed = False
        return changed

    def _update_dims_mapping_between_graphs(self):
...@@ -279,17 +333,22 @@ class Completer:
                parent_node)
            child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
                child_node)
if parent_node_dist_attr.process_mesh != child_node_dist_attr.process_mesh:
continue
            parent_node_dims_mapping = parent_node_dist_attr.dims_mapping
            child_node_dims_mapping = child_node_dist_attr.dims_mapping
            compatible_dims_mapping = compute_compatible_dims_mapping(
                [parent_node_dims_mapping, child_node_dims_mapping])
if not _validate_dims_mapping(compatible_dims_mapping,
parent_node_dist_attr.process_mesh):
return False
            if (compatible_dims_mapping is not None) \
                and (compatible_dims_mapping != parent_node_dims_mapping):
                parent_node_dist_attr.dims_mapping = compatible_dims_mapping
                changed = True
            if (compatible_dims_mapping is not None) \
                and (compatible_dims_mapping != child_node_dims_mapping):
-               parent_node_dist_attr.dims_mapping = compatible_dims_mapping
+               child_node_dist_attr.dims_mapping = compatible_dims_mapping
                changed = True
        return changed
...@@ -351,7 +410,7 @@ class Completer:
                    if compatible_process_mesh is not None \
                        and tensor_dist_attr.process_mesh != compatible_process_mesh:
                        tensor_dist_attr.process_mesh = compatible_process_mesh
            # Set the process mesh of the op node's outputs
            for tensor_node in op_node.outputs:
                if tensor_node.is_var() and tensor_node.var() is not None:
                    tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
...@@ -389,7 +448,8 @@ class Completer:
            if _node_id(cur) in visited:
                continue
            # TODO: need more restrictions
-           for node in cur.inputs:
+           neighbors = cur.inputs + cur.outputs
+           for node in neighbors:
                if node.is_var() and node.var() is not None:
                    if node.var().type() != core.VarDesc.VarType.READER \
                        and len(node.var().shape()) == 1:
...@@ -421,10 +481,29 @@ class Completer:
            visited.add(_node_id(cur))
        return related_nodes
def _make_dims_mapping_replicate(dist_attr):
if isinstance(dist_attr, TensorDistributedAttribute):
for i, _ in enumerate(dist_attr.dims_mapping):
dist_attr.dims_mapping[i] = -1
if isinstance(dist_attr, OperatorDistributedAttribute):
for arg_name in dist_attr.inputs_dist_attrs.keys():
new_dims_mapping = []
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
for _ in dims_mapping:
new_dims_mapping.append(-1)
dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
for arg_name in dist_attr.outputs_dist_attrs.keys():
new_dims_mapping = []
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
for _ in dims_mapping:
new_dims_mapping.append(-1)
dist_attr.set_output_dims_mapping(arg_name,
new_dims_mapping)
        # Amend the process meshes related to while_op
        for while_op_node, while_op_node_idx in self._while_op_nodes.values():
            sub_graph_id = while_op_node.op()._block_attr_id("sub_block")
-           sub_graph = self._dist_context._serial_graph.get_sub_graph(
+           sub_graph = self._dist_context.serial_graph.get_sub_graph(
                sub_graph_id)
            sub_graph_nodes = list(sub_graph.all_nodes())
            while_dist_op = self._dist_context.get_dist_op_for_graph(
...@@ -440,6 +519,7 @@ class Completer:
                merged_process_mesh = merge_process_mesh_two(
                    merged_process_mesh, dist_attr.process_mesh)
            while_op_dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(while_op_dist_attr)
            # Step 2: set the related nodes of while_op to the process mesh of while_op
            # Step 2.1: Find related nodes of cond var the graph of while_op
...@@ -480,6 +560,7 @@ class Completer:
                tensor_dist_attr = self._dist_context.get_dist_attr_for_graph(
                    node)
                tensor_dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(tensor_dist_attr)
            # Step 3: set the process meshes of the inputs in while_op to the process meshes of the outside input nodes
            while_op_inputs_dist_attrs = while_op_dist_attr.inputs_dist_attrs
...@@ -519,6 +600,25 @@ class Completer:
                dist_attr = self._dist_context.get_dist_attr_for_graph(
                    array_node)
                dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(dist_attr)
def _update_process_mesh_between_graphs(self):
for parent_node, child_node in self._node_pairs_between_graphs:
parent_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
parent_node)
child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
child_node)
parent_node_dist_attr.process_mesh = child_node_dist_attr.process_mesh
compatible_process_mesh = compute_compatible_process_mesh([
parent_node_dist_attr.process_mesh,
child_node_dist_attr.process_mesh
])
if compatible_process_mesh is not None \
and parent_node_dist_attr.process_mesh != compatible_process_mesh:
parent_node_dist_attr.process_mesh = compatible_process_mesh
if compatible_process_mesh is not None \
and child_node_dist_attr.process_mesh != compatible_process_mesh:
child_node_dist_attr.process_mesh = compatible_process_mesh
    def _update_process_mesh(self):
        ordered_op_nodes = self._dist_context._serial_ordered_op_nodes
...@@ -569,7 +669,7 @@ class Completer:
            return None
        for idx, op_node in enumerate(ordered_op_nodes[
            idx_of_first_op_node_has_process_mesh + 1:]):
-           original_idx = idx_of_first_op_node_has_process_mesh + +idx + 1
+           original_idx = idx_of_first_op_node_has_process_mesh + idx + 1
            nearest_op_node = ordered_op_nodes[original_idx - 1]
            nearest_op_dist_attr = self._dist_context.get_dist_attr_for_graph(
                nearest_op_node)
...@@ -585,6 +685,9 @@ class Completer:
        # Step 3: adjust the process meshes for special ops
        self._update_process_mesh_for_specials()
# Step 4: adjust the process meshes between graphs
self._update_process_mesh_between_graphs()
    def _prepare(self):
        self._while_op_nodes = {}
        self._array_nodes = {}
...@@ -620,7 +723,7 @@ class Completer:
                        self._node_pairs_between_graphs.append(
                            (after_node, node))
-   def complete_forward_annotation(self, serial_main_program):
+   def complete_forward_annotation(self, serial_main_program=None):
        """ Complete annotation for the partial annotated serial_main_program.
        Arguments:
            serial_main_program: partial annotated serial_main_program.
...@@ -628,15 +731,12 @@ class Completer:
            serial_main_program: completed annotated serial_main_program.
        """
-       # Use the default distribted context for completeion if there is no one
-       self._dist_context.serial_program = serial_main_program
-       # Initialize distributed attributes for all var and op node in serial_main_program
-       self._dist_context.init_dist_attr_for_program()
-       # print_program_with_dist_attr(serial_main_program, self._dist_context)
-       # Initialize distributed attributes for all var and op node in graph
-       self._dist_context.init_dist_attr_for_graph()
+       if serial_main_program is None:
+           serial_main_program = self._dist_context.serial_main_program
+       else:
+           self._dist_context.serial_main_program = serial_main_program
+       self._dist_context.initialize()
        self._prepare()
...@@ -646,10 +746,9 @@ class Completer:
        # Copy the corresponding distributed attribute from graph to serial_main_program
        self._dist_context.copy_dist_attr_from_graph_to_program()
-       self._dist_context.clear_dist_info_for_graph()
        # NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient
-       self.complete_high_order_grad_annotation(serial_main_program)
+       self._complete_high_order_grad_annotation(serial_main_program)
        # Do the validation check and amend some completion
        self._dist_context.amend_dist_attr_for_program()
...@@ -658,7 +757,7 @@ class Completer:
        return serial_main_program

-   def complete_high_order_grad_annotation(self, serial_main_program):
+   def _complete_high_order_grad_annotation(self, serial_main_program):
        """
        NOTE:
            [HighOrderGrad] Complete the annotation of vars and ops only for high order gradient.
...@@ -818,6 +917,10 @@ class Completer:
    def complete_backward_annotation(self, serial_main_program):
        """Complete the annotation of vars and ops in the backward phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
        def _is_grad_var_name(name):
            if "@GRAD" in name:
...@@ -1036,8 +1139,12 @@ class Completer:
                self._dist_context.set_op_dist_attr_for_program(
                    grad_op, grad_op_dist_attr)
-   def complete_update_annotation(self, serial_main_program):
+   def complete_update_annotation(self, serial_main_program=None):
        """Complete the annotation of vars and ops in the update phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
        ops = list(serial_main_program.global_block().ops)
        vars = serial_main_program.global_block().vars
        learning_rate_completed = False
......
...@@ -52,7 +52,7 @@ def append_op_output_suffix(name):
class TensorDistributedAttribute:
    def __init__(self):
        # The process mesh of distributed operator attribute must is the same as
        # the process meshes of all input and output distributed attributed
        self._process_mesh = None
        self._dims_mapping = None
...@@ -132,12 +132,29 @@ class TensorDistributedAttribute:
                        key, dist_attr)
        self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
# def reset(self, skip_dist_attr_field_names):
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self._process_mesh = None
# if skip_dist_attr_field_names is not None \
# and "dims_mapping" not in skip_dist_attr_field_names:
# for i in enumerate(self._dims_mapping):
# self._dims_mapping[i] = -1
# self._is_annotated = {}
    def is_annotated(self, dist_attr_field_name):
        return self._is_annotated.get(dist_attr_field_name, False)
# def mark_annotated_all(self):
# for key in get_tensor_dist_attr_field_keys():
# self.mark_annotated(key)
    def mark_annotated(self, dist_attr_field_name):
        self._is_annotated[dist_attr_field_name] = True
# def unmark_annotated(self, dist_attr_field_name):
# self._is_annotated[dist_attr_field_name] = False
    def mark_annotated_as(self, dist_attr):
        if dist_attr is None:
            return
...@@ -195,7 +212,7 @@ class OperatorDistributedAttribute:
        if isinstance(process_mesh, list):
            process_mesh = ProcessMesh(process_mesh)
        self._process_mesh = copy.deepcopy(process_mesh)
        # In while op, the proess mesh is not shared by all inputs and outputs
        if self._op_type == "while":
            return None
        for dist_attr in self._inputs_dist_attrs.values():
...@@ -357,9 +374,25 @@ class OperatorDistributedAttribute:
"ProcessMeshes in DistributedOperator must be the same." "ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh self.process_mesh = shared_process_mesh
# def reset(self, skip_dist_attr_field_names):
# for tensor_dist_attr in self.inputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# for tensor_dist_attr in self.outputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self.process_mesh = None
# self.impl_type = "default"
# self.impl_idx = 0
# self._is_annotated = {}
    def is_annotated(self, attr_name):
        return self._is_annotated.get(attr_name, False)
# def mark_annotated_all(self):
# for key in get_op_dist_attr_field_keys():
# self.mark_annotated(key)
    def mark_annotated(self, attr_name):
        if attr_name == "process_mesh":
            # Make sure proscess_mesh be annotated consistently
......
...@@ -118,11 +118,10 @@ class Engine:
            losses = to_list(self._loss(*(outputs + labels)))
        default_ctx = get_default_distributed_context()
-       if not default_ctx.is_annotation or self._default_strategy:
+       if not default_ctx.has_annotation or self._default_strategy:
            inputs = [self._set_data_parallel(var) for var in inputs]
            labels = [self._set_data_parallel(var) for var in labels]
-       # print(serial_main_prog)
        self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
        self._fetch_vars[mode] = {
......
...@@ -18,16 +18,16 @@ from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_containers = {}

_g_elementwise_ops = [
-   "elementwise_add", "gelu", "dropout", "cast", "gather", "concat"
+   "elementwise", "gelu", "dropout", "cast", "gather", "concat"
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
def is_elementwise_op(op_type):
-   if op_type in _g_elementwise_ops:
-       return True
-   else:
-       return False
+   for eltwise_op in _g_elementwise_ops:
+       if eltwise_op in op_type:
+           return True
+   return False
class DistributedOperatorImplContainer:
...@@ -156,7 +156,9 @@ def register_distributed_operator_impl(op_type, dist_impl):
        assert False, "Must register distributed operator registry first."
-def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
+def find_best_compatible_distributed_operator_impl(dist_op,
+                                                   fwd=True,
+                                                   partial=True):
""" """
Here just return the first compatible implemention. Here just return the first compatible implemention.
This will be improved by cost model in the future. This will be improved by cost model in the future.
...@@ -168,39 +170,55 @@ def find_best_compatible_distributed_operator_impl(dist_op, fwd=True): ...@@ -168,39 +170,55 @@ def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
dist_op_default_impl_container = get_distributed_operator_impl_container( dist_op_default_impl_container = get_distributed_operator_impl_container(
"default") "default")
compatible_impls = [] compatible_impls = []
-   if fwd:
-       # First, find impls in the corresponding container
-       if dist_op_impl_container:
-           compatible_impls.extend(
-               dist_op_impl_container.get_input_compatible_impls(dist_op))
-       # Second, find impls in the elementwise container
-       if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
-           compatible_impls.extend(
-               dist_op_eltwise_impl_container.get_input_compatible_impls(
-                   dist_op))
-       # Third, find impls in the default container
-       if dist_op_default_impl_container:
-           compatible_impls.extend(
-               dist_op_default_impl_container.get_input_compatible_impls(
-                   dist_op))
+   if partial:
+       if fwd:
+           # First, find impls in the corresponding container
+           if dist_op_impl_container:
+               compatible_impls.extend(
+                   dist_op_impl_container.get_input_compatible_impls(dist_op))
+           # Second, find impls in the elementwise container
+           if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
+               compatible_impls.extend(
+                   dist_op_eltwise_impl_container.get_input_compatible_impls(
+                       dist_op))
+           # Third, find impls in the default container
+           if dist_op_default_impl_container:
+               compatible_impls.extend(
+                   dist_op_default_impl_container.get_input_compatible_impls(
+                       dist_op))
+       else:
+           # First, find impls in the corresponding container
+           if dist_op_impl_container:
+               compatible_impls.extend(
+                   dist_op_impl_container.get_output_compatible_impls(dist_op))
+           # Second, find impls in the elementwise container
+           if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
+               compatible_impls.extend(
+                   dist_op_eltwise_impl_container.get_output_compatible_impls(
+                       dist_op))
+           # Third, find impls in the default container
+           if dist_op_default_impl_container:
+               compatible_impls.extend(
+                   dist_op_default_impl_container.get_output_compatible_impls(
+                       dist_op))
    else:
        # First, find impls in the corresponding container
        if dist_op_impl_container:
            compatible_impls.extend(
-               dist_op_impl_container.get_output_compatible_impls(dist_op))
+               dist_op_impl_container.get_compatible_impls(dist_op))
        # Second, find impls in the elementwise container
        if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
            compatible_impls.extend(
-               dist_op_eltwise_impl_container.get_output_compatible_impls(
-                   dist_op))
+               dist_op_eltwise_impl_container.get_compatible_impls(dist_op))
        # Third, find impls in the default container
        if dist_op_default_impl_container:
            compatible_impls.extend(
-               dist_op_default_impl_container.get_output_compatible_impls(
-                   dist_op))
+               dist_op_default_impl_container.get_compatible_impls(dist_op))
    if compatible_impls:
        # For now, just return the first compatible impl
-       best_compatible_impl = compatible_impls[0]
+       # best_compatible_impl = compatible_impls[0]
+       best_compatible_impl = compatible_impls
    else:
        best_compatible_impl = None
    return best_compatible_impl
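Since the function now returns the whole list of compatible implementations rather than a single one, here is a hedged sketch of how a caller can consume it; the choose_impl wrapper below is illustrative only (not part of the patch) and assumes a DistributedOperator instance:

def choose_impl(dist_op, fwd=True):
    # Ask for all partially compatible impls, then keep the first one that is
    # fully auto-compatible after its dims-mapping update (simplified: the real
    # Completer also backs up and restores the dist_attr between attempts).
    op_dist_impls = find_best_compatible_distributed_operator_impl(
        dist_op, fwd=fwd, partial=True)
    if not op_dist_impls:
        return None
    for op_dist_impl in op_dist_impls:
        op_dist_impl.update_dims_mapping(dist_op)
        if op_dist_impl.is_auto_compatible(dist_op):
            return op_dist_impl
    return None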
......
...@@ -53,6 +53,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
batch_dim_mappings = []
        input_names = op_desc.input_names()
        xshape_arg_names = []
        if "XShape" in input_names:
...@@ -64,14 +65,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
-               # continue
-               # if len(dims_mapping) < 1:
-               #     continue
+               continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
...@@ -79,12 +80,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
if compute_compatible_dim_mapping(batch_dim_mappings) is None:
return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        output_names = op_desc.output_names()
batch_dim_mappings = []
        xshape_arg_names = []
        if "XShape" in output_names:
            xshape_arg_names = op_desc.output("XShape")
...@@ -95,14 +103,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                for mapping in dims_mapping:
                    if mapping != -1:
                        return False
-               # continue
-               # if len(dims_mapping) < 1:
-               #     continue
+               continue
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
                        if mapping != -1:
                            return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
            else:
                if dims_mapping[0] != -1:
                    return False
...@@ -110,6 +118,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                    for mapping in dims_mapping[2:]:
                        if mapping != -1:
                            return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
if compute_compatible_dim_mapping(batch_dim_mappings) is None:
return False
        return True

    def is_auto_compatible(self, dist_op):
...@@ -123,9 +137,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
            xshape_arg_names = op_desc.input("XShape")
        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
                continue
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
...@@ -150,9 +167,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
            xshape_arg_names = op_desc.output("XShape")
        for arg_name in op_desc.output_arg_names():
            serial_tensor = dist_op.get_serial_output(arg_name)
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
                continue
            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
            if arg_name not in xshape_arg_names:
                if len(dims_mapping) > 1:
                    for mapping in dims_mapping[1:]:
...@@ -229,7 +249,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
        compatible_dim_mapping = compute_compatible_dim_mapping(
            batch_dim_mappings)
-       assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
+       if compatible_dim_mapping is None:
+           return False
        for arg_name in op_desc.input_arg_names():
            serial_tensor = dist_op.get_serial_input(arg_name)
            if serial_tensor.is_parameter:
......
...@@ -52,21 +52,46 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
-       if is_elementwise_op(op_desc.type()):
-           return True
-       else:
-           return False
+       if not is_elementwise_op(op_desc.type()):
+           return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
if idx < len(dims_mapping):
dim_mappings.append(dims_mapping[-(idx + 1)])
if compute_compatible_dim_mapping(dim_mappings) is None:
return False
return True
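A self-contained sketch of the right-aligned compatibility walk implemented above; _compatible_dim below is a simplified stand-in for compute_compatible_dim_mapping from ..utils (an assumption for illustration, not the real helper):

def _compatible_dim(dim_mappings):
    # -1 means replicated; two different mesh axes on the same dim conflict.
    result = -1
    for d in dim_mappings:
        if d == -1:
            continue
        if result == -1:
            result = d
        elif result != d:
            return None
    return result

def trailing_dims_compatible(dims_mapping_list):
    # Walk the dims mappings from the last axis backwards, broadcasting-style,
    # mirroring the loop in is_input_compatible above.
    max_len = max(len(dm) for dm in dims_mapping_list)
    for idx in range(max_len):
        column = [dm[-(idx + 1)] for dm in dims_mapping_list if idx < len(dm)]
        if _compatible_dim(column) is None:
            return False
    return True

print(trailing_dims_compatible([[0, -1, 1], [-1, 1]]))   # True: shards line up
print(trailing_dims_compatible([[0, -1, 1], [-1, 0]]))   # False: last dim maps to axes 1 and 0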
    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
-       op_desc = dist_op.serial_op.desc
-       if is_elementwise_op(op_desc.type()):
-           return True
-       else:
-           return False
+       if not is_elementwise_op(op_desc.type()):
+           return False
+       op_dist_attr = dist_op.dist_attr
+       dims_mapping_list = []
+       output_arg_names = op_desc.output_arg_names()
+       for arg_name in output_arg_names:
+           dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+           dims_mapping_list.append(dims_mapping)
+       if compute_compatible_dims_mapping(dims_mapping_list) is None:
+           return False
+       return True
    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
if not is_elementwise_op(op_desc.type()):
return False
        op_dist_attr = dist_op.dist_attr
        dims_mapping_list = []
        input_arg_names = op_desc.input_arg_names()
...@@ -127,7 +152,8 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
        compatible_dims_mapping = compute_compatible_dims_mapping(
            dims_mapping_list)
-       assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
+       if compatible_dims_mapping is None:
+           return False
        for arg_name in input_arg_names:
            if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
......
...@@ -95,7 +95,8 @@ def _update_dims_mapping_for_matmul(dist_op):
            broadcast_x_dims_mapping, broadcast_y_dims_mapping,
            broadcast_out_dims_mapping
        ])
-       assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
+       if compatible_dims_mapping is None:
+           return False
        for i in range(x_dims_mapping_len - 2):
            new_idx = i + (out_dims_mapping_len - x_dims_mapping_len)
......
...@@ -117,7 +117,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
        compatible_dim_mapping = compute_compatible_dim_mapping(
            batch_dim_mappings)
-       assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
+       if compatible_dim_mapping is None:
+           return False
        for arg_name in op_desc.input_arg_names():
            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -17,6 +17,7 @@ from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
...@@ -47,6 +48,29 @@ class DistributedSliceImpl(DistributedOperatorImpl):
        return True

    def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
in_name = op_desc.input('Input')[0]
out_name = op_desc.output('Out')[0]
axes = op_desc.attr('axes')
decrease_axis = op_desc.attr('decrease_axis')
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ref_indices = []
for i in range(len(in_dims_mapping)):
if i not in decrease_axis:
ref_indices.append(i)
if ref_indices == []:
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
else:
for i in range(len(out_dims_mapping)):
ref_index = ref_indices[i]
if ref_index in axes and is_dim_shard(out_dims_mapping[i]):
return False
        return True

    def is_compatible(self, dist_op):
...@@ -95,17 +119,30 @@ class DistributedSliceImpl(DistributedOperatorImpl):
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        ref_dims_mapping = []
ref_indices = []
        for i in range(len(in_dims_mapping)):
            if i not in decrease_axis:
                ref_dims_mapping.append(in_dims_mapping[i])
ref_indices.append(i)
        if ref_dims_mapping == []:
            ref_dims_mapping = [-1]
-       assert len(ref_dims_mapping) == len(out_dims_mapping)
-       for i in range(len(out_dims_mapping)):
-           if out_dims_mapping[i] != ref_dims_mapping[i]:
-               out_dims_mapping[i] = ref_dims_mapping[i]
-               changed = True
+           assert len(ref_dims_mapping) == len(out_dims_mapping)
+           assert ref_dims_mapping[0] == out_dims_mapping[0]
+           changed = False
+       else:
+           assert len(ref_dims_mapping) == len(out_dims_mapping)
+           for i in range(len(out_dims_mapping)):
+               compatible_dim_mapping = compute_compatible_dim_mapping(
+                   [out_dims_mapping[i], ref_dims_mapping[i]])
+               if compatible_dim_mapping is None:
+                   continue
+               if ref_dims_mapping[i] != compatible_dim_mapping:
+                   in_dims_mapping[ref_indices[i]] = compatible_dim_mapping
+                   changed = True
+               if out_dims_mapping[i] != compatible_dim_mapping:
+                   out_dims_mapping[i] = compatible_dim_mapping
+                   changed = True
        return changed
......
...@@ -230,7 +230,7 @@ class AutoParallelizer:
        g_process_group_map = copy.deepcopy(_g_process_group_map)
        _g_process_group_map.clear()
        _g_process_group_map[0] = ProcessGroup(0, [])
-       for process_mesh in dist_context._process_meshes:
+       for process_mesh in self._dist_context._process_meshes:
            _g_process_group_map[0].add_ranks(process_mesh.processes)
        return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map
......
...@@ -138,7 +138,6 @@ class MetricRecords(object):
    def from_state(cls, state):
        records = cls(state["direction"])
        records.records = [MetricRecord.from_state(r) for r in state["records"]]
        print("here 1", records.records)
        return records
......
...@@ -159,11 +159,11 @@ def print_program_with_dist_attr(program, dist_context=None):
    from .dist_context import set_default_distributed_context
    if dist_context is None:
        dist_context = get_default_distributed_context()
-       print(program)
+       print(program, flush=True)
    else:
        original_default_context = get_default_distributed_context()
        set_default_distributed_context(dist_context)
-       print(program)
+       print(program, flush=True)
        set_default_distributed_context(original_default_context)
    lock.release()
......
...@@ -350,11 +350,12 @@ class RecomputePass(PassBase):
            for _, op_desc in reversed(list(enumerate(segment_descs))):
                rc_desc = main_block.desc._insert_op(idx)
                rc_desc.copy_from(op_desc)
+               rc_desc.set_original_id(rc_desc.id())
                rc_op = Operator(main_block, rc_desc)
                main_block.ops.insert(idx, rc_op)
                # set recomputed ops' dist attr
                fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
-                   rc_desc.original_id())
+                   op_desc.original_id())
                assert fwd_op_dist_attr is not None
                self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
                                      var_name_dict)
......
...@@ -3,18 +3,23 @@
if(WITH_DISTRIBUTE AND WITH_GPU)
    py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS})
    set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
    py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
    set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
    py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_gpt_planner ENVS ${dist_ENVS})
    set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240)
    py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS})
    set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
    py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
    py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
    set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
    py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS})
    set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
    py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
    py_test_modules(test_while_op_partition MODULES test_while_op_partition ENVS ${dist_ENVS})
    py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS})
    py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
    py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS})
......
...@@ -66,7 +66,6 @@ class TestDistReshape(unittest.TestCase):
        for rank in range(2):
            dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
            ops = dist_main_prog.global_block().ops
            print_program_with_dist_attr(dist_main_prog, dist_context)
            for idx, op in enumerate(ops):
                op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                assert op_dist_attr.impl_type == "reshape2"
......
...@@ -15,6 +15,7 @@
import unittest
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr

paddle.enable_static()
...@@ -85,14 +86,9 @@ class TestDistSlice(unittest.TestCase):
        for op in ops:
            axes = op.desc.attr('axes')
            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
-           if axes[0] == 0:
-               assert op_dist_attr.impl_type == "default"
-           else:
-               assert op_dist_attr.impl_type == "slice"
-               for out in op.output_arg_names:
-                   var_dims_mapping = op_dist_attr.get_output_dims_mapping(
-                       out)
-                   assert var_dims_mapping[0] == 0
+           assert op_dist_attr.impl_type == "slice"
+           for out in op.output_arg_names:
+               var_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
    def test_dist_slice_serial(self):
        dist_main_prog, dist_context = parallelizer(make_program_serial, 0)
......
...@@ -23,12 +23,13 @@ import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr

paddle.enable_static()
...@@ -283,139 +284,143 @@ def get_program():
def completion(train_program, start_program, dist_context):
    # blocks = train_program.blocks
    # # completion tensors
    # for block in blocks:
    #     for op in block.ops:
    #         if op.type == "layer_norm":
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 if tensor_dist_attr:
    #                     continue
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 tensor_dist_attr.dims_mapping = [-1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)
    #         elif op.type == "elementwise_sub":
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 tensor_dist_attr.dims_mapping = [-1, -1, -1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)
    #         elif op.type == "matmul_v2":
    #             col = False
    #             for in_name in op.input_arg_names:
    #                 if ".w_" not in in_name:
    #                     continue
    #                 if in_name not in block.vars:
    #                     in_var = blocks[0].vars[in_name]
    #                 else:
    #                     in_var = block.vars[in_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     in_var)
    #                 assert tensor_dist_attr is not None
    #                 if tensor_dist_attr.dims_mapping == [-1, 0]:
    #                     col = True
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 if tensor_dist_attr:
    #                     continue
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 if col:
    #                     tensor_dist_attr.dims_mapping = [-1, -1, 0]
    #                 else:
    #                     tensor_dist_attr.dims_mapping = [-1, -1, -1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)
    #         elif op.type == "while":
    #             out_name = op.desc.output("StepScopes")[0]
    #             out_var = block.vars[out_name]
    #             tensor_dist_attr = TensorDistributedAttribute()
    #             tensor_dist_attr.process_mesh = _g_process_mesh
    #             tensor_dist_attr.dims_mapping = [-1]
    #             dist_context.set_tensor_dist_attr_for_program(out_var,
    #                                                           tensor_dist_attr)
    # # completion ops
    # for block in blocks:
    #     for op in block.ops:
    #         op_dist_attr = OperatorDistributedAttribute()
    #         op_dist_attr.process_mesh = _g_process_mesh
    #         if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
    #             for in_name in op.input_arg_names:
    #                 op_dist_attr.set_input_dims_mapping(in_name, [])
    #             for out_name in op.output_arg_names:
    #                 op_dist_attr.set_output_dims_mapping(out_name, [])
    #         elif op.type == "read":
    #             for in_name in op.input_arg_names:
    #                 op_dist_attr.set_output_dims_mapping(in_name, [])
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
    #         elif op.type == "while":
    #             for in_name in op.input_arg_names:
    #                 in_var = block.vars[in_name]
    #                 in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     in_var)
    #                 op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
    #             for out_name in op.output_arg_names:
    #                 if out_name == op.desc.output("StepScopes")[0]:
    #                     op_dist_attr.set_output_dims_mapping(out_name, [])
    #                 else:
    #                     out_var = block.vars[out_name]
    #                     out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var) # out_var)
op_dist_attr.set_output_dist_attr(out_name, # op_dist_attr.set_output_dist_attr(out_name,
out_dist_attr) # out_dist_attr)
else: # else:
for in_name in op.input_arg_names: # for in_name in op.input_arg_names:
if in_name == "lod_tensor_blocking_queue_0": # if in_name == "lod_tensor_blocking_queue_0":
continue # continue
if in_name not in block.vars: # if in_name not in block.vars:
in_var = blocks[0].vars[in_name] # in_var = blocks[0].vars[in_name]
else: # else:
in_var = block.vars[in_name] # in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program( # in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var) # in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) # op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
if out_name not in block.vars: # if out_name not in block.vars:
out_var = blocks[0].vars[out_name] # out_var = blocks[0].vars[out_name]
else: # else:
out_var = block.vars[out_name] # out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program( # out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var) # out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) # op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
if op.type == "matmul_v2": # if op.type == "matmul_v2":
op_dist_attr.impl_type = "matmul_v2" # op_dist_attr.impl_type = "matmul_v2"
for in_name in op_dist_attr.inputs_dist_attrs.keys(): # for in_name in op_dist_attr.inputs_dist_attrs.keys():
in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name] # in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0: # if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
op_dist_attr.impl_idx = 0 # op_dist_attr.impl_idx = 0
else: # else:
op_dist_attr.impl_idx = 1 # op_dist_attr.impl_idx = 1
elif op.type == "fill_constant_batch_size_like": # elif op.type == "fill_constant_batch_size_like":
op_dist_attr.impl_type = "fill_constant_batch_size_like" # op_dist_attr.impl_type = "fill_constant_batch_size_like"
op_dist_attr.impl_idx = 0 # op_dist_attr.impl_idx = 0
else: # else:
op_dist_attr.impl_type = "default" # op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = 0 # op_dist_attr.impl_idx = 0
dist_context.set_op_dist_attr_for_program(op, op_dist_attr) # dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
make_data_unshard(train_program, start_program, dist_context) # make_data_unshard(train_program, start_program, dist_context)
completer = Completer(dist_context)
train_program = completer.complete_forward_annotation(train_program)
make_data_unshard(train_program, start_program, dist_context)
return train_program, start_program return train_program, start_program
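Note: the new body of completion() above delegates all per-tensor and per-operator annotation to the Completer introduced by this change, instead of setting distributed attributes by hand as the commented-out code did. A minimal sketch of that flow, pulled out of the test for clarity; the import paths below are assumptions based on the modules touched in this commit (completion.py, dist_context.py, utils.py) and are not part of the diff itself:

    import paddle
    # Assumed import paths; only Completer(dist_context) and
    # complete_forward_annotation(train_program) appear verbatim in the diff.
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.dist_context import DistributedContext
    from paddle.distributed.auto_parallel.utils import make_data_unshard

    paddle.enable_static()

    def complete_program(train_program, start_program):
        # Build a fresh distributed context and let the Completer infer the
        # distributed attributes of every tensor and operator in the program.
        dist_context = DistributedContext()
        completer = Completer(dist_context)
        train_program = completer.complete_forward_annotation(train_program)
        # Keep the data-loading variables unsharded, as the updated test does.
        make_data_unshard(train_program, start_program, dist_context)
        return train_program, start_program, dist_context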
......
...@@ -134,7 +134,6 @@ class TestMLPAutoParallelizer(unittest.TestCase): ...@@ -134,7 +134,6 @@ class TestMLPAutoParallelizer(unittest.TestCase):
for op in block.ops: for op in block.ops:
for attr_name in op.attr_names: for attr_name in op.attr_names:
self.assertTrue(suffix not in attr_name) self.assertTrue(suffix not in attr_name)
# print_program_with_dist_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program) self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program) self.assertIsNotNone(distributed_main_program)
......
...@@ -332,7 +332,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -332,7 +332,6 @@ class TestMLPReshard(unittest.TestCase):
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads) dist_context, dist_params_grads)
resharder.reshard() resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......