Unverified commit bed9aaea, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Improve the codes of the completion and distributed context (#40671)

* [Auto Parallel] Replace the old planner by the new partition tuner

* [Auto Parallel] Improve the completion and distributed context

* [Auto Parallel] Fix some bugs of the compatible check of some dist ops

* [Auto Parallel] Fix some bugs
Parent: afcf6bd0
...@@ -123,6 +123,19 @@ def merge_process_mesh_two(pm1, pm2):
return merged_process_mesh
def _validate_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
process_mesh.topology):
return False
for i in range(len(process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
return True
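For reference (not part of the diff), the new helper rejects a dims_mapping when any entry falls outside the valid range for the mesh or when the same mesh axis is used for more than one tensor dimension. A minimal sketch, using a hypothetical stand-in object with just the `topology` attribute instead of a real ProcessMesh:

class _FakeMesh:
    # Stand-in for ProcessMesh; only the `topology` attribute is read here.
    def __init__(self, topology):
        self.topology = topology

mesh = _FakeMesh([2, 4])  # two mesh axes: 0 and 1
print(_validate_dims_mapping([0, -1, 1], mesh))  # True: each mesh axis used at most once
print(_validate_dims_mapping([0, 0], mesh))      # False: mesh axis 0 used twice
print(_validate_dims_mapping([2, -1], mesh))     # False: 2 is not a valid mesh axis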
class Completer:
def __init__(self, dist_context):
assert dist_context is not None
...@@ -161,6 +174,9 @@ class Completer:
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if not _validate_dims_mapping(compatible_dims_mapping,
tensor_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.dims_mapping = compatible_dims_mapping
...@@ -182,6 +198,9 @@ class Completer:
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if not _validate_dims_mapping(compatible_dims_mapping,
tensor_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.dims_mapping = compatible_dims_mapping
...@@ -196,10 +215,12 @@ class Completer:
op_desc = op_node.op()
if op_desc.type() == "create_py_reader" \
or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "while" \
or op_desc.type() == "read": or op_desc.type() == "read":
return False return False
dist_op = self._dist_context.get_dist_op_for_graph(op_node) dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
original_op_dist_attr = copy.deepcopy(op_dist_attr)
if fwd:
for tensor_node in op_node.inputs:
if tensor_node.is_var() and tensor_node.var() is not None:
...@@ -223,9 +244,13 @@ class Completer:
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
...@@ -234,7 +259,19 @@ class Completer:
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
# op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
...@@ -258,18 +295,35 @@ class Completer:
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl(
dist_op, fwd=False)
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
not_compatible = False
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
# op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False
return changed
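The selection logic above now walks the list of candidate implementations, taking a backup of the operator's dist_attr (and of the `changed` flag) before each attempt and rolling back whenever the candidate is not auto-compatible; only if no candidate fits is the original attribute restored and `changed` reset. A minimal sketch of that backup/try/restore pattern, using hypothetical stand-in objects rather than the real dist_op and impl classes:

import copy

def pick_compatible_impl(dist_op, impls):
    # Remember the original state so everything can be rolled back.
    original_attr = copy.deepcopy(dist_op.dist_attr)
    for impl in impls:
        backup_attr = copy.deepcopy(dist_op.dist_attr)
        impl.update_dims_mapping(dist_op)        # may mutate dist_op.dist_attr
        if impl.is_auto_compatible(dist_op):     # first compatible impl wins
            dist_op.dist_attr.impl_type = impl.type
            dist_op.dist_attr.impl_idx = impl.idx
            return True
        dist_op.dist_attr = backup_attr          # undo this attempt, try the next
    dist_op.dist_attr = original_attr            # nothing fit: restore everything
    return False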
def _update_dims_mapping_between_graphs(self):
...@@ -279,17 +333,22 @@ class Completer:
parent_node)
child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
child_node)
if parent_node_dist_attr.process_mesh != child_node_dist_attr.process_mesh:
continue
parent_node_dims_mapping = parent_node_dist_attr.dims_mapping
child_node_dims_mapping = child_node_dist_attr.dims_mapping
compatible_dims_mapping = compute_compatible_dims_mapping(
[parent_node_dims_mapping, child_node_dims_mapping])
if not _validate_dims_mapping(compatible_dims_mapping,
parent_node_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) \
and (compatible_dims_mapping != parent_node_dims_mapping):
parent_node_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
if (compatible_dims_mapping is not None) \
and (compatible_dims_mapping != child_node_dims_mapping):
child_node_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
return changed
...@@ -389,7 +448,8 @@ class Completer:
if _node_id(cur) in visited:
continue
# TODO: need more restrictions
neighbors = cur.inputs + cur.outputs
for node in neighbors:
if node.is_var() and node.var() is not None:
if node.var().type() != core.VarDesc.VarType.READER \
and len(node.var().shape()) == 1:
...@@ -421,10 +481,29 @@ class Completer:
visited.add(_node_id(cur))
return related_nodes
def _make_dims_mapping_replicate(dist_attr):
if isinstance(dist_attr, TensorDistributedAttribute):
for i, _ in enumerate(dist_attr.dims_mapping):
dist_attr.dims_mapping[i] = -1
if isinstance(dist_attr, OperatorDistributedAttribute):
for arg_name in dist_attr.inputs_dist_attrs.keys():
new_dims_mapping = []
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
for _ in dims_mapping:
new_dims_mapping.append(-1)
dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
for arg_name in dist_attr.outputs_dist_attrs.keys():
new_dims_mapping = []
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
for _ in dims_mapping:
new_dims_mapping.append(-1)
dist_attr.set_output_dims_mapping(arg_name,
new_dims_mapping)
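For context (not part of the diff), `_make_dims_mapping_replicate` rewrites every dims_mapping entry to -1, i.e. it marks each tensor dimension as replicated across the process mesh while leaving the process mesh itself untouched. A tiny sketch of the effect on a hypothetical attribute object:

class _FakeTensorDistAttr:
    # Hypothetical stand-in; only dims_mapping matters for this illustration.
    def __init__(self, dims_mapping):
        self.dims_mapping = dims_mapping

attr = _FakeTensorDistAttr([0, -1, 1])
for i, _ in enumerate(attr.dims_mapping):  # same loop body as the helper above
    attr.dims_mapping[i] = -1
print(attr.dims_mapping)  # [-1, -1, -1]: fully replicated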
# Amend the process meshes related to while_op
for while_op_node, while_op_node_idx in self._while_op_nodes.values():
sub_graph_id = while_op_node.op()._block_attr_id("sub_block")
sub_graph = self._dist_context.serial_graph.get_sub_graph(
sub_graph_id)
sub_graph_nodes = list(sub_graph.all_nodes())
while_dist_op = self._dist_context.get_dist_op_for_graph(
...@@ -440,6 +519,7 @@ class Completer:
merged_process_mesh = merge_process_mesh_two(
merged_process_mesh, dist_attr.process_mesh)
while_op_dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(while_op_dist_attr)
# Step 2: set the related nodes of while_op to the process mesh of while_op
# Step 2.1: Find related nodes of cond var the graph of while_op
...@@ -480,6 +560,7 @@ class Completer:
tensor_dist_attr = self._dist_context.get_dist_attr_for_graph(
node)
tensor_dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(tensor_dist_attr)
# Step 3: set the process meshes of the inputs in while_op to the process meshes of the outside input nodes
while_op_inputs_dist_attrs = while_op_dist_attr.inputs_dist_attrs
...@@ -519,6 +600,25 @@ class Completer:
dist_attr = self._dist_context.get_dist_attr_for_graph(
array_node)
dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(dist_attr)
def _update_process_mesh_between_graphs(self):
for parent_node, child_node in self._node_pairs_between_graphs:
parent_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
parent_node)
child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
child_node)
parent_node_dist_attr.process_mesh = child_node_dist_attr.process_mesh
compatible_process_mesh = compute_compatible_process_mesh([
parent_node_dist_attr.process_mesh,
child_node_dist_attr.process_mesh
])
if compatible_process_mesh is not None \
and parent_node_dist_attr.process_mesh != compatible_process_mesh:
parent_node_dist_attr.process_mesh = compatible_process_mesh
if compatible_process_mesh is not None \
and child_node_dist_attr.process_mesh != compatible_process_mesh:
child_node_dist_attr.process_mesh = compatible_process_mesh
def _update_process_mesh(self):
ordered_op_nodes = self._dist_context._serial_ordered_op_nodes
...@@ -569,7 +669,7 @@ class Completer:
return None
for idx, op_node in enumerate(ordered_op_nodes[
idx_of_first_op_node_has_process_mesh + 1:]):
original_idx = idx_of_first_op_node_has_process_mesh + idx + 1
nearest_op_node = ordered_op_nodes[original_idx - 1]
nearest_op_dist_attr = self._dist_context.get_dist_attr_for_graph(
nearest_op_node)
...@@ -585,6 +685,9 @@ class Completer:
# Step 3: adjust the process meshes for special ops
self._update_process_mesh_for_specials()
# Step 4: adjust the process meshes between graphs
self._update_process_mesh_between_graphs()
def _prepare(self):
self._while_op_nodes = {}
self._array_nodes = {}
...@@ -620,7 +723,7 @@ class Completer:
self._node_pairs_between_graphs.append(
(after_node, node))
def complete_forward_annotation(self, serial_main_program=None):
""" Complete annotation for the partial annotated serial_main_program.
Arguments:
serial_main_program: partial annotated serial_main_program.
...@@ -628,15 +731,12 @@ class Completer:
serial_main_program: completed annotated serial_main_program.
"""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
self._dist_context.init_dist_attr_for_program()
# print_program_with_dist_attr(serial_main_program, self._dist_context)
self._dist_context.initialize()
self._dist_context.init_dist_attr_for_graph()
self._prepare()
...@@ -646,10 +746,9 @@ class Completer:
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
self._dist_context.clear_dist_info_for_graph()
# NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient
self._complete_high_order_grad_annotation(serial_main_program)
# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()
...@@ -658,7 +757,7 @@ class Completer:
return serial_main_program
def _complete_high_order_grad_annotation(self, serial_main_program):
"""
NOTE:
[HighOrderGrad] Complete the annotation of vars and ops only for high order gradient.
...@@ -818,6 +917,10 @@ class Completer:
def complete_backward_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the backward phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
def _is_grad_var_name(name):
if "@GRAD" in name:
...@@ -1036,8 +1139,12 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
def complete_update_annotation(self, serial_main_program=None):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
learning_rate_completed = False
......
...@@ -132,12 +132,29 @@ class TensorDistributedAttribute:
key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
# def reset(self, skip_dist_attr_field_names):
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self._process_mesh = None
# if skip_dist_attr_field_names is not None \
# and "dims_mapping" not in skip_dist_attr_field_names:
# for i in enumerate(self._dims_mapping):
# self._dims_mapping[i] = -1
# self._is_annotated = {}
def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False)
# def mark_annotated_all(self):
# for key in get_tensor_dist_attr_field_keys():
# self.mark_annotated(key)
def mark_annotated(self, dist_attr_field_name):
self._is_annotated[dist_attr_field_name] = True
# def unmark_annotated(self, dist_attr_field_name):
# self._is_annotated[dist_attr_field_name] = False
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
...@@ -357,9 +374,25 @@ class OperatorDistributedAttribute:
"ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh
# def reset(self, skip_dist_attr_field_names):
# for tensor_dist_attr in self.inputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# for tensor_dist_attr in self.outputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self.process_mesh = None
# self.impl_type = "default"
# self.impl_idx = 0
# self._is_annotated = {}
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
# def mark_annotated_all(self):
# for key in get_op_dist_attr_field_keys():
# self.mark_annotated(key)
def mark_annotated(self, attr_name):
if attr_name == "process_mesh":
# Make sure proscess_mesh be annotated consistently
......
...@@ -14,9 +14,11 @@
import copy
from collections import defaultdict
import paddle.fluid
from paddle.fluid import framework
from paddle.fluid.framework import get_flags, set_flags
from paddle.fluid import core
from paddle.distributed.passes import PassContext
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_tensor import DistributedTensor
...@@ -54,26 +56,41 @@ class DistributedContext:
serial_main_prog=None,
serial_startup_prog=None,
dist_main_progs=None,
dist_startup_progs=None,
serial_loss=None,
serial_optimizer=None,
strategy=None):
# Data members related to original programs (unchanged)
self._original_serial_main_program = serial_main_prog
self._original_serial_startup_program = serial_startup_prog
self._original_serial_loss = serial_loss
self._original_serial_optimizer = serial_optimizer
if self._original_serial_main_program is None:
self._original_serial_main_program = paddle.fluid.default_main_program(
)
if self._original_serial_startup_program is None:
self._original_serial_startup_program = paddle.fluid.default_startup_program(
)
# Data members related to programs (changed)
self._serial_main_program = None
self._serial_startup_program = None
self._serial_loss = None
self._serial_optimizer = None
# Data members related to the program
self._dist_tensors_for_program = {}
self._dist_ops_for_program = {}
self._block_state = BlockState()
# Graph related data members
# Data members related to the graph
self._serial_graph = None
self._dist_tensors_for_graph = {}
self._dist_ops_for_graph = {}
self._node_id_to_tensor_id = {}
self._node_id_to_op_id = {}
# Other data members
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
self._serial_ordered_nodes = []
self._tensor_id_to_tensor_node_ids = {}
# Data members related to the distributed programs
# Distributed programs
self._dist_main_programs = dist_main_progs
if not self._dist_main_programs:
...@@ -82,20 +99,71 @@ class DistributedContext:
if not self._dist_startup_programs:
self._dist_startup_programs = {}
# Distributed Strategy
self._strategy = strategy
# Pass Context
self._pass_context = PassContext()
# Distributed Operator Context
self._dist_op_context = DistributedOperatorContext()
# Other data members
self._process_meshes = []
self._serial_ordered_tensor_nodes = []
self._serial_ordered_op_nodes = []
self._serial_ordered_nodes = []
# self._tensor_id_to_tensor_node_ids = {}
self._is_initialized = False
@property
def serial_main_program(self):
return self._serial_main_program
@serial_main_program.setter
def serial_main_program(self, program):
# if self._serial_main_program:
# print("WARNING: The program attached to this distributed context will be replaced by the new one.")
self._original_serial_main_program = program
self._serial_main_program = program
@property
def serial_startup_program(self):
return self._serial_startup_program
# @serial_startup_program.setter
# def serial_startup_program(self, serial_startup_program):
# self._serial_startup_program = serial_startup_program
@property
def serial_loss(self):
return self._serial_loss
# @serial_loss.setter
# def serial_loss(self, serial_loss):
# self._serial_loss = serial_loss
@property
def serial_optimizer(self):
return self._serial_optimizer
# @serial_optimizer.setter
# def serial_optimizer(self, serial_optimizer):
# self._serial_optimizer = serial_optimizer
@property
def strategy(self):
return self._strategy
# @strategy.setter
# def strategy(self, strategy):
# self._strategy = strategy
@property
def serial_graph(self):
return self._serial_graph
@serial_program.setter
def serial_program(self, program):
# assert self._serial_program is None, \
# "This distributed context has already been realted to a serial program"
self._serial_program = program
@property
def serial_ordered_nodes(self):
return self._serial_ordered_nodes
...@@ -104,6 +172,10 @@ class DistributedContext:
def process_meshes(self):
return self._process_meshes
@property
def pass_context(self):
return self._pass_context
@property
def dist_op_context(self):
return self._dist_op_context
...@@ -121,10 +193,64 @@ class DistributedContext:
return self._dist_startup_programs
@property
def has_annotation(self):
return len(self._dist_tensors_for_program) or len(
self._dist_ops_for_program)
def initialize(self):
if not self._is_initialized:
self._serial_main_program = self._original_serial_main_program.clone(
)
self._serial_startup_program = self._original_serial_startup_program.clone(
)
self._serial_main_program = self._original_serial_main_program
self._serial_startup_program = self._original_serial_startup_program
self._serial_loss = self._original_serial_loss
self._serial_optimizer = self._original_serial_optimizer
self._init_dist_attr_for_program()
self._tensors_ids = list(self._dist_tensors_for_program.keys())
self._ops_ids = list(self._dist_ops_for_program.keys())
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_main_program.desc))
self._init_dist_attr_for_graph()
self._is_initialized = True
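A minimal usage sketch inferred from the constructor signature and initialize() above (module paths and the call order are assumptions, not taken from the patch itself): the context now keeps the original programs, clones them, and builds the IrGraph internally, so a completer only needs the context.

import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.completion import Completer

paddle.enable_static()
main_prog = paddle.static.default_main_program()
startup_prog = paddle.static.default_startup_program()

# The context clones the original programs and converts them to an IrGraph.
dist_context = DistributedContext(
    serial_main_prog=main_prog, serial_startup_prog=startup_prog)
dist_context.initialize()  # guarded by _is_initialized, so calling it twice is harmless

# complete_forward_annotation() falls back to the context's own program when
# no program is passed in.
completer = Completer(dist_context)
completed_prog = completer.complete_forward_annotation()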
# def reset(self,
# skip_dist_tensors=None,
# skip_dist_ops=None,
# skip_tensor_dist_attr_fields=None,
# skip_op_dist_attr_fields=None):
# self._serial_main_program = self._original_serial_main_program.clone()
# self._serial_startup_program = self._original_serial_startup_program.clone()
# new_tensors_ids = []
# for tensor_id, dist_tensor in self._dist_tensors_for_program.items():
# if tensor_id in self._tensors_ids:
# dist_tensor.dist_attr.reset(skip_tensor_dist_attr_fields)
# else:
# new_tensors_ids.append(tensor_id)
# for tensor_id in new_tensors_ids:
# self._dist_tensors_for_program.pop(tensor_id)
# new_ops_ids = []
# for op_id, dist_op in self._dist_ops_for_program.items():
# if op_id in self._ops_ids:
# dist_op.dist_attr.reset(skip_op_dist_attr_fields)
# else:
# new_ops_ids.append(op_id)
# for op_id in new_ops_ids:
# self._dist_ops_for_program.pop(op_id)
# self.copy_dist_attr_from_program_to_graph()
# self._dist_main_programs = {}
# self._dist_startup_programs = {}
# self._pass_context = PassContext()
# self._dist_op_context = DistributedOperatorContext()
# self._process_meshes = []
def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
'The type of dim_mapping must be ProcessMesh.'
...@@ -133,12 +259,12 @@ class DistributedContext:
def add_dist_tensor_for_program(self, dist_tensor):
inner_serial_tensor = dist_tensor.serial_tensor
inner_serial_tensor_id = inner_serial_tensor.desc.original_id()
self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor
def add_dist_op_for_program(self, dist_op):
inner_serial_op = dist_op.serial_op
inner_serial_op_id = inner_serial_op.desc.original_id()
self._dist_ops_for_program[inner_serial_op_id] = dist_op
def get_dist_tensor_for_program(self, serial_tensor):
...@@ -215,18 +341,6 @@ class DistributedContext:
else:
return None
# def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
# assert serial_tensor_node.is_var() and \
# serial_tensor_node.var() is not None
# serial_tensor_id = serial_tensor_node.node.original_desc_id()
# dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
# assert dist_tensor is not None, \
# "The distributed tensor of the program has not been added to this context."
# serial_tensor_node_id = serial_tensor_node.id()
# new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
# dist_attr)
# self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
...@@ -259,17 +373,6 @@ class DistributedContext:
else:
return None
# def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
# assert serial_op_node.is_op() and \
# serial_op_node.op() is not None
# serial_op_id = serial_op_node.node.original_desc_id()
# dist_op = self._dist_ops_for_program.get(serial_op_id, None)
# assert dist_op is not None, \
# "The distributed operator of the program has not been added to this context."
# serial_op_node_id = serial_op_node.id()
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = _node_id(serial_node)
...@@ -288,15 +391,14 @@ class DistributedContext:
return None
return None
def _init_dist_attr_for_program(self, no_default=False):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
if self._is_initialized_for_program:
return
# Copy the dist tensors and dist ops annotated by users from the default context
if not no_default:
default_ctx = get_default_distributed_context()
self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
else:
default_ctx = self
for block in self._serial_main_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
default_dist_tensor = default_ctx.get_dist_tensor_for_program(
...@@ -316,9 +418,8 @@ class DistributedContext:
if current_dist_op is None:
dist_op = DistributedOperator(op)
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True
def _order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if _node_id(node) == _node_id(target_node):
...@@ -328,7 +429,6 @@ class DistributedContext:
serial_ordered_tensor_nodes = []
serial_ordered_op_nodes = []
all_nodes = []
# for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for node in graph.all_nodes():
all_nodes.append(node)
...@@ -346,33 +446,35 @@ class DistributedContext:
new_serial_ordered_tensor_nodes = []
new_serial_ordered_op_nodes = []
new_serial_ordered_nodes = []
for op_node in serial_ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(new_serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
new_serial_ordered_nodes.extend(tensor_nodes)
new_serial_ordered_nodes.append(op_node)
new_serial_ordered_op_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(new_serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
new_serial_ordered_nodes.extend(tensor_nodes)
new_serial_ordered_tensor_nodes.sort(
key=lambda node: node.node.original_desc_id())
new_serial_ordered_op_nodes.sort(
key=lambda node: node.node.original_desc_id())
self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
self._serial_ordered_nodes = new_serial_ordered_nodes
assert len(self._serial_ordered_nodes) == len(
self._serial_ordered_tensor_nodes) + len(
self._serial_ordered_op_nodes)
...@@ -385,16 +487,9 @@ class DistributedContext:
"WARNING: there are some orphan tensors or ops which are not used in the execution."
)
def _init_dist_attr_for_graph(self):
# Convert program to graph and initialize the distributed attributes
self._order_nodes_by_program_order()
if self._is_initialized_for_graph:
return
# Convert program to graph
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
self.order_nodes_by_program_order()
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
...@@ -428,7 +523,6 @@ class DistributedContext:
new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
self._is_initialized_for_graph = True
def clear_dist_info_for_program(self):
self._dist_tensors_for_program.clear()
...@@ -438,8 +532,40 @@ class DistributedContext:
self._dist_tensors_for_graph.clear()
self._dist_ops_for_graph.clear()
def copy_dist_attr_from_program_to_graph(self):
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
tensor_id = node.node.original_desc_id()
for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
):
if tensor_id == cur_tensor_id \
or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
dist_tensor = cur_dist_tensor
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
serial_tensor_node_id = _node_id(node)
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_tensor.dist_attr)
self._dist_tensors_for_graph[
serial_tensor_node_id] = new_dist_tensor
if node.is_op() and node.op() is not None:
dist_op = None
op_id = node.node.original_desc_id()
for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
):
if op_id == cur_op_id \
or op_id == cur_dist_op.serial_op.desc.original_id():
dist_op = cur_dist_op
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
serial_op_node_id = _node_id(node)
new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def copy_dist_attr_from_graph_to_program(self):
assert self._is_initialized, \
"Both program and graph must be initialized."
updated_tensors = {}
# all_nodes = self._serial_graph.all_nodes()
...@@ -532,18 +658,24 @@ class DistributedContext:
dims_mapping[i] = -1
def validate_dist_attr_for_program(self):
if not self._is_initialized:
assert False, \
"Program must be initialized before validating its distributed attributes"
for block in self.serial_main_program.blocks:
for tensor in block.vars.values():
dist_tensor = self.get_dist_tensor_for_program(tensor)
assert dist_tensor is not None, \
"Tensor {} does not have a distributed attribute.".format(
dist_tensor.serial_tensor.name)
if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
assert dist_op is not None, \
"Operator {} does not have a distributed attribute.".format(
dist_op.serial_op.type)
if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format(
dist_op.serial_op.type, dist_tensor.dist_attr)
...@@ -554,10 +686,12 @@ class DistributedContext:
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k in [
"_original_serial_main_program", "_original_serial_startup_program", \
"_serial_main_program", "_serial_startup_program", "_serial_graph", \
"_dist_main_programs", "_dist_startup_programs", \
"_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
"_serial_ordered_op_nodes"]:
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
......
...@@ -118,11 +118,10 @@ class Engine:
losses = to_list(self._loss(*(outputs + labels)))
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation or self._default_strategy:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
# print(serial_main_prog)
self._feed_vars[mode] = {"inputs": inputs, "labels": labels} self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
self._fetch_vars[mode] = { self._fetch_vars[mode] = {
......
...@@ -18,15 +18,15 @@ from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_containers = {}
_g_elementwise_ops = [
"elementwise", "gelu", "dropout", "cast", "gather", "concat"
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
def is_elementwise_op(op_type):
for eltwise_op in _g_elementwise_ops:
if eltwise_op in op_type:
return True
else:
return False
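With this change is_elementwise_op does substring matching against the entries of _g_elementwise_ops instead of an exact membership test, so every elementwise_* kernel is covered by the single "elementwise" entry. A small illustration (not part of the diff):

print(is_elementwise_op("elementwise_add"))  # True: "elementwise" is a substring
print(is_elementwise_op("elementwise_mul"))  # True: covered without an explicit entry
print(is_elementwise_op("dropout"))          # True: listed directly
print(is_elementwise_op("matmul_v2"))        # False: no entry matches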
...@@ -156,7 +156,9 @@ def register_distributed_operator_impl(op_type, dist_impl):
assert False, "Must register distributed operator registry first."
def find_best_compatible_distributed_operator_impl(dist_op,
fwd=True,
partial=True):
""" """
Here just return the first compatible implemention. Here just return the first compatible implemention.
This will be improved by cost model in the future. This will be improved by cost model in the future.
...@@ -168,6 +170,7 @@ def find_best_compatible_distributed_operator_impl(dist_op, fwd=True): ...@@ -168,6 +170,7 @@ def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
dist_op_default_impl_container = get_distributed_operator_impl_container( dist_op_default_impl_container = get_distributed_operator_impl_container(
"default") "default")
compatible_impls = [] compatible_impls = []
if partial:
if fwd:
# First, find impls in the corresponding container
if dist_op_impl_container:
...@@ -198,9 +201,24 @@ def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
compatible_impls.extend(
dist_op_default_impl_container.get_output_compatible_impls(
dist_op))
else:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_compatible_impls(dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_compatible_impls(dist_op))
if compatible_impls:
# For now, just return the first compatible impl
# best_compatible_impl = compatible_impls[0]
best_compatible_impl = compatible_impls
else:
best_compatible_impl = None
return best_compatible_impl
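Despite its name, the function now returns the whole list of compatible implementations (or None when nothing matches) rather than a single impl, and the completer iterates over that list. A simplified caller sketch (the real completer additionally backs up and restores the dist_attr between attempts):

op_dist_impls = find_best_compatible_distributed_operator_impl(dist_op, fwd=True)
if op_dist_impls is not None:
    for op_dist_impl in op_dist_impls:
        op_dist_impl.update_dims_mapping(dist_op)
        if op_dist_impl.is_auto_compatible(dist_op):
            break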
......
...@@ -53,6 +53,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
batch_dim_mappings = []
input_names = op_desc.input_names()
xshape_arg_names = []
if "XShape" in input_names:
...@@ -64,14 +65,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping:
if mapping != -1:
return False
continue
# if len(dims_mapping) < 1:
# continue
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
...@@ -79,12 +80,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
if compute_compatible_dim_mapping(batch_dim_mappings) is None:
return False
return True
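The batch-dimension mappings collected above (the first dimension for ordinary arguments, the second for XShape arguments) must be unifiable, otherwise the op is rejected. A rough stand-in sketch of the compatibility rule assumed here (the real helper is compute_compatible_dim_mapping in the auto_parallel utils): mappings agree when they are equal or replicated (-1).

def _compatible_dim_mapping_sketch(dim_mappings):
    # Illustration only, not the real utils implementation.
    result = -1
    for m in dim_mappings:
        if m == -1:
            continue
        if result == -1:
            result = m
        elif result != m:
            return None  # two different shard axes cannot be reconciled
    return result

print(_compatible_dim_mapping_sketch([0, -1, 0]))  # 0: consistent batch sharding
print(_compatible_dim_mapping_sketch([0, 1]))      # None: conflicting shard axes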
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
output_names = op_desc.output_names()
batch_dim_mappings = []
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
...@@ -95,14 +103,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping:
if mapping != -1:
return False
continue
# if len(dims_mapping) < 1:
# continue
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
...@@ -110,6 +118,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
if compute_compatible_dim_mapping(batch_dim_mappings) is None:
return False
return True
def is_auto_compatible(self, dist_op):
...@@ -123,9 +137,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
xshape_arg_names = op_desc.input("XShape")
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
...@@ -150,9 +167,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
xshape_arg_names = op_desc.output("XShape") xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
...@@ -229,7 +249,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
if compatible_dim_mapping is None:
return False
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
......
...@@ -52,21 +52,46 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if not is_elementwise_op(op_desc.type()):
return True
else:
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
if idx < len(dims_mapping):
dim_mappings.append(dims_mapping[-(idx + 1)])
if compute_compatible_dim_mapping(dim_mappings) is None:
return False
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if not is_elementwise_op(op_desc.type()):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
dims_mapping_list.append(dims_mapping)
if compute_compatible_dims_mapping(dims_mapping_list) is None:
return False
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if not is_elementwise_op(op_desc.type()):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
...@@ -127,7 +152,8 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if compatible_dims_mapping is None:
return False
for arg_name in input_arg_names: for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
......
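The new elementwise input check above merges the dims mappings of all inputs axis by axis, aligned from the right the way broadcasting aligns shapes. A self-contained sketch of that logic, using a minimal stand-in for `compute_compatible_dim_mapping` (the real helper lives in `..utils`; its exact behavior may differ):

```python
def compute_compatible_dim_mapping(dim_mappings):
    # Minimal stand-in: -1 (replicated) defers to any shard dim; two different
    # shard dims cannot be reconciled, so the merge yields None.
    result = -1
    for m in dim_mappings:
        if m == -1:
            continue
        if result == -1:
            result = m
        elif result != m:
            return None
    return result

def elementwise_inputs_compatible(dims_mapping_list):
    # Compare trailing axes across inputs of different ranks, mirroring
    # broadcasting: any axis whose mappings cannot be merged makes the op
    # incompatible with this impl.
    max_len = max(len(d) for d in dims_mapping_list)
    for idx in range(max_len):
        dim_mappings = [d[-(idx + 1)] for d in dims_mapping_list if idx < len(d)]
        if compute_compatible_dim_mapping(dim_mappings) is None:
            return False
    return True

assert elementwise_inputs_compatible([[0, -1], [-1, -1, -1]])   # broadcastable; shard and replicate merge
assert not elementwise_inputs_compatible([[0], [-1, 1]])        # last axis sharded on two different mesh dims
```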
...@@ -95,7 +95,8 @@ def _update_dims_mapping_for_matmul(dist_op): ...@@ -95,7 +95,8 @@ def _update_dims_mapping_for_matmul(dist_op):
broadcast_x_dims_mapping, broadcast_y_dims_mapping, broadcast_x_dims_mapping, broadcast_y_dims_mapping,
broadcast_out_dims_mapping broadcast_out_dims_mapping
]) ])
assert compatible_dims_mapping is not None, "There is no compatible dim mapping." if compatible_dims_mapping is None:
return False
for i in range(x_dims_mapping_len - 2): for i in range(x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - x_dims_mapping_len) new_idx = i + (out_dims_mapping_len - x_dims_mapping_len)
......
...@@ -117,7 +117,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -117,7 +117,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
compatible_dim_mapping = compute_compatible_dim_mapping( compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings) batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping." if compatible_dim_mapping is None:
return False
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
......
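The default, matmul, and p_norm hunks above all swap a hard assert for a soft failure, so an `update_dims_mapping`-style helper can report "no compatible mapping" instead of aborting completion. A standalone sketch of the calling pattern this enables (names are illustrative; the real completer snapshots the whole dist_attr with `copy.deepcopy` before attempting an update):

```python
import copy

def try_update(op_dist_attr, compatible_dim_mapping):
    # compatible_dim_mapping is whatever the compatibility merge returned;
    # None now means "report failure to the caller" rather than "assert".
    snapshot = copy.deepcopy(op_dist_attr)
    if compatible_dim_mapping is None:
        return snapshot, False          # caller keeps the original attributes
    op_dist_attr["batch_dim"] = compatible_dim_mapping
    return op_dist_attr, True

attr, ok = try_update({"batch_dim": -1}, 0)
assert ok and attr["batch_dim"] == 0
attr, ok = try_update({"batch_dim": -1}, None)   # no compatible mapping found
assert not ok and attr["batch_dim"] == -1
```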
...@@ -17,6 +17,7 @@ from .common import DistributedOperatorImpl ...@@ -17,6 +17,7 @@ from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0 from .dist_default import DistributedDefaultImpl0
...@@ -47,6 +48,29 @@ class DistributedSliceImpl(DistributedOperatorImpl): ...@@ -47,6 +48,29 @@ class DistributedSliceImpl(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, dist_op): def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
in_name = op_desc.input('Input')[0]
out_name = op_desc.output('Out')[0]
axes = op_desc.attr('axes')
decrease_axis = op_desc.attr('decrease_axis')
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ref_indices = []
for i in range(len(in_dims_mapping)):
if i not in decrease_axis:
ref_indices.append(i)
if ref_indices == []:
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
else:
for i in range(len(out_dims_mapping)):
ref_index = ref_indices[i]
if ref_index in axes and is_dim_shard(out_dims_mapping[i]):
return False
return True return True
def is_compatible(self, dist_op): def is_compatible(self, dist_op):
...@@ -95,16 +119,29 @@ class DistributedSliceImpl(DistributedOperatorImpl): ...@@ -95,16 +119,29 @@ class DistributedSliceImpl(DistributedOperatorImpl):
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ref_dims_mapping = [] ref_dims_mapping = []
ref_indices = []
for i in range(len(in_dims_mapping)): for i in range(len(in_dims_mapping)):
if i not in decrease_axis: if i not in decrease_axis:
ref_dims_mapping.append(in_dims_mapping[i]) ref_dims_mapping.append(in_dims_mapping[i])
ref_indices.append(i)
if ref_dims_mapping == []: if ref_dims_mapping == []:
ref_dims_mapping = [-1] ref_dims_mapping = [-1]
assert len(ref_dims_mapping) == len(out_dims_mapping)
assert ref_dims_mapping[0] == out_dims_mapping[0]
changed = False
else:
assert len(ref_dims_mapping) == len(out_dims_mapping) assert len(ref_dims_mapping) == len(out_dims_mapping)
for i in range(len(out_dims_mapping)): for i in range(len(out_dims_mapping)):
if out_dims_mapping[i] != ref_dims_mapping[i]: compatible_dim_mapping = compute_compatible_dim_mapping(
out_dims_mapping[i] = ref_dims_mapping[i] [out_dims_mapping[i], ref_dims_mapping[i]])
if compatible_dim_mapping is None:
continue
if ref_dims_mapping[i] != compatible_dim_mapping:
in_dims_mapping[ref_indices[i]] = compatible_dim_mapping
changed = True
if out_dims_mapping[i] != compatible_dim_mapping:
out_dims_mapping[i] = compatible_dim_mapping
changed = True changed = True
return changed return changed
......
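The new `is_output_compatible` above maps each output axis back to an input axis (skipping `decrease_axis`) and rejects sharding along any axis that is actually sliced. A standalone sketch of that check in plain Python, with illustrative names:

```python
def slice_output_compatible(in_rank, out_dims_mapping, axes, decrease_axis):
    # Output axes correspond, in order, to the input axes that were not decreased.
    ref_indices = [i for i in range(in_rank) if i not in decrease_axis]
    if not ref_indices:
        # Everything was decreased: the output keeps a single replicated axis.
        return len(out_dims_mapping) == 1 and out_dims_mapping[0] == -1
    for i, ref_index in enumerate(ref_indices):
        if ref_index in axes and out_dims_mapping[i] != -1:
            return False                 # sliced axes must stay replicated
    return True

assert slice_output_compatible(3, [0, -1, -1], axes=[2], decrease_axis=[])       # shard on a non-sliced axis
assert not slice_output_compatible(3, [-1, -1, 0], axes=[2], decrease_axis=[])   # sharded along the sliced axis
assert slice_output_compatible(2, [-1], axes=[0, 1], decrease_axis=[0, 1])       # fully decreased output
```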
...@@ -230,7 +230,7 @@ class AutoParallelizer: ...@@ -230,7 +230,7 @@ class AutoParallelizer:
g_process_group_map = copy.deepcopy(_g_process_group_map) g_process_group_map = copy.deepcopy(_g_process_group_map)
_g_process_group_map.clear() _g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, []) _g_process_group_map[0] = ProcessGroup(0, [])
for process_mesh in dist_context._process_meshes: for process_mesh in self._dist_context._process_meshes:
_g_process_group_map[0].add_ranks(process_mesh.processes) _g_process_group_map[0].add_ranks(process_mesh.processes)
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map
......
...@@ -138,7 +138,6 @@ class MetricRecords(object): ...@@ -138,7 +138,6 @@ class MetricRecords(object):
def from_state(cls, state): def from_state(cls, state):
records = cls(state["direction"]) records = cls(state["direction"])
records.records = [MetricRecord.from_state(r) for r in state["records"]] records.records = [MetricRecord.from_state(r) for r in state["records"]]
print("here 1", records.records)
return records return records
......
...@@ -159,11 +159,11 @@ def print_program_with_dist_attr(program, dist_context=None): ...@@ -159,11 +159,11 @@ def print_program_with_dist_attr(program, dist_context=None):
from .dist_context import set_default_distributed_context from .dist_context import set_default_distributed_context
if dist_context is None: if dist_context is None:
dist_context = get_default_distributed_context() dist_context = get_default_distributed_context()
print(program) print(program, flush=True)
else: else:
original_default_context = get_default_distributed_context() original_default_context = get_default_distributed_context()
set_default_distributed_context(dist_context) set_default_distributed_context(dist_context)
print(program) print(program, flush=True)
set_default_distributed_context(original_default_context) set_default_distributed_context(original_default_context)
lock.release() lock.release()
......
...@@ -350,11 +350,12 @@ class RecomputePass(PassBase): ...@@ -350,11 +350,12 @@ class RecomputePass(PassBase):
for _, op_desc in reversed(list(enumerate(segment_descs))): for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_desc = main_block.desc._insert_op(idx) rc_desc = main_block.desc._insert_op(idx)
rc_desc.copy_from(op_desc) rc_desc.copy_from(op_desc)
rc_desc.set_original_id(rc_desc.id())
rc_op = Operator(main_block, rc_desc) rc_op = Operator(main_block, rc_desc)
main_block.ops.insert(idx, rc_op) main_block.ops.insert(idx, rc_op)
# set recomputed ops' dist attr # set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id( fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
rc_desc.original_id()) op_desc.original_id())
assert fwd_op_dist_attr is not None assert fwd_op_dist_attr is not None
self.set_op_dist_attr(rc_op, fwd_op_dist_attr, self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
var_name_dict) var_name_dict)
......
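The recompute fix above stamps an original id onto each inserted copy and, crucially, looks the dist attr up with the forward op's original id rather than the copy's. A tiny dict-based illustration (not Paddle code, registry name hypothetical) of why that key matters:

```python
# Dist attrs are keyed by the forward ops' original ids; a freshly inserted
# recompute copy carries a new id of its own, so resolving through op_desc
# (the op being copied) finds the attr, while resolving through the copy's
# own id would miss.
dist_attr_by_original_id = {101: "fwd_matmul_dist_attr"}   # hypothetical registry

fwd_op_original_id = 101     # id of the op being recomputed
rc_op_original_id = 202      # fresh id stamped onto the recomputed copy

assert dist_attr_by_original_id.get(rc_op_original_id) is None                   # old lookup: misses
assert dist_attr_by_original_id[fwd_op_original_id] == "fwd_matmul_dist_attr"    # fixed lookup
```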
...@@ -3,18 +3,23 @@ ...@@ -3,18 +3,23 @@
if(WITH_DISTRIBUTE AND WITH_GPU) if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS})
set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS}) py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_gpt_planner ENVS ${dist_ENVS}) py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_gpt_planner ENVS ${dist_ENVS})
set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240) set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240)
py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS}) py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS})
set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS}) py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS}) py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS})
set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_while_op_partition MODULES test_while_op_partition ENVS ${dist_ENVS})
py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS}) py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS})
py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS}) py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS}) py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS})
......
...@@ -66,7 +66,6 @@ class TestDistReshape(unittest.TestCase): ...@@ -66,7 +66,6 @@ class TestDistReshape(unittest.TestCase):
for rank in range(2): for rank in range(2):
dist_main_prog, dist_context = parallelizer(make_program_dp2, rank) dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
ops = dist_main_prog.global_block().ops ops = dist_main_prog.global_block().ops
print_program_with_dist_attr(dist_main_prog, dist_context)
for idx, op in enumerate(ops): for idx, op in enumerate(ops):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "reshape2" assert op_dist_attr.impl_type == "reshape2"
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import paddle import paddle
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
...@@ -85,14 +86,9 @@ class TestDistSlice(unittest.TestCase): ...@@ -85,14 +86,9 @@ class TestDistSlice(unittest.TestCase):
for op in ops: for op in ops:
axes = op.desc.attr('axes') axes = op.desc.attr('axes')
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if axes[0] == 0:
assert op_dist_attr.impl_type == "default"
else:
assert op_dist_attr.impl_type == "slice" assert op_dist_attr.impl_type == "slice"
for out in op.output_arg_names: for out in op.output_arg_names:
var_dims_mapping = op_dist_attr.get_output_dims_mapping( var_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
out)
assert var_dims_mapping[0] == 0
def test_dist_slice_serial(self): def test_dist_slice_serial(self):
dist_main_prog, dist_context = parallelizer(make_program_serial, 0) dist_main_prog, dist_context = parallelizer(make_program_serial, 0)
......
...@@ -23,12 +23,13 @@ import paddle.nn.functional as F ...@@ -23,12 +23,13 @@ import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
...@@ -283,138 +284,142 @@ def get_program(): ...@@ -283,138 +284,142 @@ def get_program():
def completion(train_program, start_program, dist_context): def completion(train_program, start_program, dist_context):
blocks = train_program.blocks # blocks = train_program.blocks
# completion tensors # # completion tensors
for block in blocks: # for block in blocks:
for op in block.ops: # for op in block.ops:
if op.type == "layer_norm": # if op.type == "layer_norm":
for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
out_var = block.vars[out_name] # out_var = block.vars[out_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( # tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var) # out_var)
if tensor_dist_attr: # if tensor_dist_attr:
continue # continue
tensor_dist_attr = TensorDistributedAttribute() # tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh # tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1] # tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program( # dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr) # out_var, tensor_dist_attr)
elif op.type == "elementwise_sub": # elif op.type == "elementwise_sub":
for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
out_var = block.vars[out_name] # out_var = block.vars[out_name]
tensor_dist_attr = TensorDistributedAttribute() # tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh # tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1, -1, -1] # tensor_dist_attr.dims_mapping = [-1, -1, -1]
dist_context.set_tensor_dist_attr_for_program( # dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr) # out_var, tensor_dist_attr)
elif op.type == "matmul_v2": # elif op.type == "matmul_v2":
col = False # col = False
for in_name in op.input_arg_names: # for in_name in op.input_arg_names:
if ".w_" not in in_name: # if ".w_" not in in_name:
continue # continue
if in_name not in block.vars: # if in_name not in block.vars:
in_var = blocks[0].vars[in_name] # in_var = blocks[0].vars[in_name]
else: # else:
in_var = block.vars[in_name] # in_var = block.vars[in_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( # tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var) # in_var)
assert tensor_dist_attr is not None # assert tensor_dist_attr is not None
if tensor_dist_attr.dims_mapping == [-1, 0]: # if tensor_dist_attr.dims_mapping == [-1, 0]:
col = True # col = True
for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
out_var = block.vars[out_name] # out_var = block.vars[out_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( # tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var) # out_var)
if tensor_dist_attr: # if tensor_dist_attr:
continue # continue
tensor_dist_attr = TensorDistributedAttribute() # tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh # tensor_dist_attr.process_mesh = _g_process_mesh
if col: # if col:
tensor_dist_attr.dims_mapping = [-1, -1, 0] # tensor_dist_attr.dims_mapping = [-1, -1, 0]
else: # else:
tensor_dist_attr.dims_mapping = [-1, -1, -1] # tensor_dist_attr.dims_mapping = [-1, -1, -1]
dist_context.set_tensor_dist_attr_for_program( # dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr) # out_var, tensor_dist_attr)
elif op.type == "while": # elif op.type == "while":
out_name = op.desc.output("StepScopes")[0] # out_name = op.desc.output("StepScopes")[0]
out_var = block.vars[out_name] # out_var = block.vars[out_name]
tensor_dist_attr = TensorDistributedAttribute() # tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh # tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1] # tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(out_var, # dist_context.set_tensor_dist_attr_for_program(out_var,
tensor_dist_attr) # tensor_dist_attr)
# completion ops # # completion ops
for block in blocks: # for block in blocks:
for op in block.ops: # for op in block.ops:
op_dist_attr = OperatorDistributedAttribute() # op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = _g_process_mesh # op_dist_attr.process_mesh = _g_process_mesh
if op.type == "create_by_read" or op.type == "create_double_buffer_reader": # if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
for in_name in op.input_arg_names: # for in_name in op.input_arg_names:
op_dist_attr.set_input_dims_mapping(in_name, []) # op_dist_attr.set_input_dims_mapping(in_name, [])
for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
op_dist_attr.set_output_dims_mapping(out_name, []) # op_dist_attr.set_output_dims_mapping(out_name, [])
elif op.type == "read": # elif op.type == "read":
for in_name in op.input_arg_names: # for in_name in op.input_arg_names:
op_dist_attr.set_output_dims_mapping(in_name, []) # op_dist_attr.set_output_dims_mapping(in_name, [])
for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
out_var = block.vars[out_name] # out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program( # out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var) # out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) # op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
elif op.type == "while": # elif op.type == "while":
for in_name in op.input_arg_names: # for in_name in op.input_arg_names:
in_var = block.vars[in_name] # in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program( # in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var) # in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) # op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
if out_name == op.desc.output("StepScopes")[0]: # if out_name == op.desc.output("StepScopes")[0]:
op_dist_attr.set_output_dims_mapping(out_name, []) # op_dist_attr.set_output_dims_mapping(out_name, [])
else: # else:
out_var = block.vars[out_name] # out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program( # out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var) # out_var)
op_dist_attr.set_output_dist_attr(out_name, # op_dist_attr.set_output_dist_attr(out_name,
out_dist_attr) # out_dist_attr)
else: # else:
for in_name in op.input_arg_names: # for in_name in op.input_arg_names:
if in_name == "lod_tensor_blocking_queue_0": # if in_name == "lod_tensor_blocking_queue_0":
continue # continue
if in_name not in block.vars: # if in_name not in block.vars:
in_var = blocks[0].vars[in_name] # in_var = blocks[0].vars[in_name]
else: # else:
in_var = block.vars[in_name] # in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program( # in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var) # in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) # op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
if out_name not in block.vars: # if out_name not in block.vars:
out_var = blocks[0].vars[out_name] # out_var = blocks[0].vars[out_name]
else: # else:
out_var = block.vars[out_name] # out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program( # out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var) # out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) # op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
if op.type == "matmul_v2": # if op.type == "matmul_v2":
op_dist_attr.impl_type = "matmul_v2" # op_dist_attr.impl_type = "matmul_v2"
for in_name in op_dist_attr.inputs_dist_attrs.keys(): # for in_name in op_dist_attr.inputs_dist_attrs.keys():
in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name] # in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0: # if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
op_dist_attr.impl_idx = 0 # op_dist_attr.impl_idx = 0
else: # else:
op_dist_attr.impl_idx = 1 # op_dist_attr.impl_idx = 1
elif op.type == "fill_constant_batch_size_like": # elif op.type == "fill_constant_batch_size_like":
op_dist_attr.impl_type = "fill_constant_batch_size_like" # op_dist_attr.impl_type = "fill_constant_batch_size_like"
op_dist_attr.impl_idx = 0 # op_dist_attr.impl_idx = 0
else: # else:
op_dist_attr.impl_type = "default" # op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = 0 # op_dist_attr.impl_idx = 0
dist_context.set_op_dist_attr_for_program(op, op_dist_attr) # dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
# make_data_unshard(train_program, start_program, dist_context)
completer = Completer(dist_context)
train_program = completer.complete_forward_annotation(train_program)
make_data_unshard(train_program, start_program, dist_context) make_data_unshard(train_program, start_program, dist_context)
return train_program, start_program return train_program, start_program
......
...@@ -134,7 +134,6 @@ class TestMLPAutoParallelizer(unittest.TestCase): ...@@ -134,7 +134,6 @@ class TestMLPAutoParallelizer(unittest.TestCase):
for op in block.ops: for op in block.ops:
for attr_name in op.attr_names: for attr_name in op.attr_names:
self.assertTrue(suffix not in attr_name) self.assertTrue(suffix not in attr_name)
# print_program_with_dist_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program) self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program) self.assertIsNotNone(distributed_main_program)
......
...@@ -332,7 +332,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -332,7 +332,6 @@ class TestMLPReshard(unittest.TestCase):
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads) dist_context, dist_params_grads)
resharder.reshard() resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......