Unverified · Commit bed9aaea authored by Yulong Ao, committed by GitHub

[Auto Parallel] Improve the codes of the completion and distributed context (#40671)

* [Auto Parallel] Replace the old planner by the new partition tuner

* [Auto Parallel] Improve the completion and distributed context

* [Auto Parallel] Fix some bugs of the compatible check of some dist ops

* [Auto Parallel] Fix some bugs
Parent afcf6bd0
......@@ -123,6 +123,19 @@ def merge_process_mesh_two(pm1, pm2):
return merged_process_mesh
def _validate_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
process_mesh.topology):
return False
for i in range(len(process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
return True
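For reference, the new `_validate_dims_mapping` helper rejects a dims_mapping that is missing, that references a mesh dimension outside the process-mesh topology, or that maps the same mesh dimension more than once. A minimal illustration of those rules (the stand-in mesh class is hypothetical; only the `topology` attribute that the helper reads is modelled):

class _ToyMesh:
    # Hypothetical stand-in: only the `topology` attribute used by
    # _validate_dims_mapping is modelled here.
    def __init__(self, topology):
        self.topology = topology

mesh = _ToyMesh([2, 4])                              # a 2 x 4 process mesh
assert _validate_dims_mapping([0, -1, 1], mesh)      # each mesh dim used at most once
assert not _validate_dims_mapping([0, 0, -1], mesh)  # mesh dim 0 used twice
assert not _validate_dims_mapping([2, -1], mesh)     # 2 is outside the mesh rank
assert not _validate_dims_mapping(None, mesh)        # a missing mapping is invalid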
class Completer:
def __init__(self, dist_context):
assert dist_context is not None
......@@ -161,6 +174,9 @@ class Completer:
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if not _validate_dims_mapping(compatible_dims_mapping,
tensor_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.dims_mapping = compatible_dims_mapping
......@@ -182,6 +198,9 @@ class Completer:
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if not _validate_dims_mapping(compatible_dims_mapping,
tensor_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.dims_mapping = compatible_dims_mapping
......@@ -196,10 +215,12 @@ class Completer:
op_desc = op_node.op()
if op_desc.type() == "create_py_reader" \
or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "while" \
or op_desc.type() == "read":
return False
dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
original_op_dist_attr = copy.deepcopy(op_dist_attr)
if fwd:
for tensor_node in op_node.inputs:
if tensor_node.is_var() and tensor_node.var() is not None:
......@@ -223,18 +244,34 @@ class Completer:
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impl = find_best_compatible_distributed_operator_impl(
op_dist_impls = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
# op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
......@@ -258,18 +295,35 @@ class Completer:
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impl = find_best_compatible_distributed_operator_impl(
op_dist_impls = find_best_compatible_distributed_operator_impl(
dist_op, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
not_compatible = False
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
# op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False
return changed
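The branches above follow a try-and-roll-back pattern: the operator's distributed attribute is deep-copied before the candidate implementations are tried, each candidate may update the dims mappings, and the backup is restored whenever a candidate turns out not to be auto-compatible. A simplified, self-contained sketch of that pattern (not the Completer code itself; `dist_op` and the impl objects are assumed to expose the same methods used above):

import copy

def pick_first_auto_compatible_impl(dist_op, op_dist_impls):
    # Keep the first impl that is auto-compatible; undo the dims-mapping
    # updates of every impl that is not.
    backup = copy.deepcopy(dist_op.dist_attr)
    for impl in op_dist_impls:
        impl.update_dims_mapping(dist_op)
        if impl.is_auto_compatible(dist_op):
            dist_op.dist_attr.impl_type = impl.type
            dist_op.dist_attr.impl_idx = impl.idx
            return True
        dist_op.dist_attr = copy.deepcopy(backup)  # undo before the next try
    return False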
def _update_dims_mapping_between_graphs(self):
......@@ -279,17 +333,22 @@ class Completer:
parent_node)
child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
child_node)
if parent_node_dist_attr.process_mesh != child_node_dist_attr.process_mesh:
continue
parent_node_dims_mapping = parent_node_dist_attr.dims_mapping
child_node_dims_mapping = child_node_dist_attr.dims_mapping
compatible_dims_mapping = compute_compatible_dims_mapping(
[parent_node_dims_mapping, child_node_dims_mapping])
if not _validate_dims_mapping(compatible_dims_mapping,
parent_node_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) \
and (compatible_dims_mapping != parent_node_dims_mapping):
parent_node_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
if (compatible_dims_mapping is not None) \
and (compatible_dims_mapping != child_node_dims_mapping):
parent_node_dist_attr.dims_mapping = compatible_dims_mapping
child_node_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
return changed
......@@ -351,7 +410,7 @@ class Completer:
if compatible_process_mesh is not None \
and tensor_dist_attr.process_mesh != compatible_process_mesh:
tensor_dist_attr.process_mesh = compatible_process_mesh
# Set the process mesh of the op node's outputs
# Set the process mesh of the op node's outputs
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
......@@ -389,7 +448,8 @@ class Completer:
if _node_id(cur) in visited:
continue
# TODO: need more restrictions
for node in cur.inputs:
neighbors = cur.inputs + cur.outputs
for node in neighbors:
if node.is_var() and node.var() is not None:
if node.var().type() != core.VarDesc.VarType.READER \
and len(node.var().shape()) == 1:
......@@ -421,10 +481,29 @@ class Completer:
visited.add(_node_id(cur))
return related_nodes
def _make_dims_mapping_replicate(dist_attr):
if isinstance(dist_attr, TensorDistributedAttribute):
for i, _ in enumerate(dist_attr.dims_mapping):
dist_attr.dims_mapping[i] = -1
if isinstance(dist_attr, OperatorDistributedAttribute):
for arg_name in dist_attr.inputs_dist_attrs.keys():
new_dims_mapping = []
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
for _ in dims_mapping:
new_dims_mapping.append(-1)
dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
for arg_name in dist_attr.outputs_dist_attrs.keys():
new_dims_mapping = []
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
for _ in dims_mapping:
new_dims_mapping.append(-1)
dist_attr.set_output_dims_mapping(arg_name,
new_dims_mapping)
# Amend the process meshes related to while_op
for while_op_node, while_op_node_idx in self._while_op_nodes.values():
sub_graph_id = while_op_node.op()._block_attr_id("sub_block")
sub_graph = self._dist_context._serial_graph.get_sub_graph(
sub_graph = self._dist_context.serial_graph.get_sub_graph(
sub_graph_id)
sub_graph_nodes = list(sub_graph.all_nodes())
while_dist_op = self._dist_context.get_dist_op_for_graph(
......@@ -440,6 +519,7 @@ class Completer:
merged_process_mesh = merge_process_mesh_two(
merged_process_mesh, dist_attr.process_mesh)
while_op_dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(while_op_dist_attr)
# Step 2: set the related nodes of while_op to the process mesh of while_op
# Step 2.1: Find related nodes of cond var the graph of while_op
......@@ -480,6 +560,7 @@ class Completer:
tensor_dist_attr = self._dist_context.get_dist_attr_for_graph(
node)
tensor_dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(tensor_dist_attr)
# Step 3: set the process meshes of the inputs in while_op to the process meshes of the outside input nodes
while_op_inputs_dist_attrs = while_op_dist_attr.inputs_dist_attrs
......@@ -519,6 +600,25 @@ class Completer:
dist_attr = self._dist_context.get_dist_attr_for_graph(
array_node)
dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(dist_attr)
def _update_process_mesh_between_graphs(self):
for parent_node, child_node in self._node_pairs_between_graphs:
parent_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
parent_node)
child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
child_node)
parent_node_dist_attr.process_mesh = child_node_dist_attr.process_mesh
compatible_process_mesh = compute_compatible_process_mesh([
parent_node_dist_attr.process_mesh,
child_node_dist_attr.process_mesh
])
if compatible_process_mesh is not None \
and parent_node_dist_attr.process_mesh != compatible_process_mesh:
parent_node_dist_attr.process_mesh = compatible_process_mesh
if compatible_process_mesh is not None \
and child_node_dist_attr.process_mesh != compatible_process_mesh:
child_node_dist_attr.process_mesh = compatible_process_mesh
def _update_process_mesh(self):
ordered_op_nodes = self._dist_context._serial_ordered_op_nodes
......@@ -569,7 +669,7 @@ class Completer:
return None
for idx, op_node in enumerate(ordered_op_nodes[
idx_of_first_op_node_has_process_mesh + 1:]):
original_idx = idx_of_first_op_node_has_process_mesh + +idx + 1
original_idx = idx_of_first_op_node_has_process_mesh + idx + 1
nearest_op_node = ordered_op_nodes[original_idx - 1]
nearest_op_dist_attr = self._dist_context.get_dist_attr_for_graph(
nearest_op_node)
......@@ -585,6 +685,9 @@ class Completer:
# Step 3: adjust the process meshes for special ops
self._update_process_mesh_for_specials()
# Step 4: adjust the process meshes between graphs
self._update_process_mesh_between_graphs()
def _prepare(self):
self._while_op_nodes = {}
self._array_nodes = {}
......@@ -620,7 +723,7 @@ class Completer:
self._node_pairs_between_graphs.append(
(after_node, node))
def complete_forward_annotation(self, serial_main_program):
def complete_forward_annotation(self, serial_main_program=None):
""" Complete annotation for the partial annotated serial_main_program.
Arguments:
serial_main_program: partial annotated serial_main_program.
......@@ -628,15 +731,12 @@ class Completer:
serial_main_program: the completely annotated serial_main_program.
"""
# Use the default distributed context for completion if there is none
self._dist_context.serial_program = serial_main_program
# Initialize distributed attributes for all var and op node in serial_main_program
self._dist_context.init_dist_attr_for_program()
# print_program_with_dist_attr(serial_main_program, self._dist_context)
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
# Initialize distributed attributes for all var and op node in graph
self._dist_context.init_dist_attr_for_graph()
self._dist_context.initialize()
self._prepare()
......@@ -646,10 +746,9 @@ class Completer:
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
self._dist_context.clear_dist_info_for_graph()
# NOTE: [HighOrderGrad] update the distributed attributes of vars and ops for high-order gradients
self.complete_high_order_grad_annotation(serial_main_program)
self._complete_high_order_grad_annotation(serial_main_program)
# Do the validation check and amend the completion where needed
self._dist_context.amend_dist_attr_for_program()
......@@ -658,7 +757,7 @@ class Completer:
return serial_main_program
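With `serial_main_program` now optional, a typical call mirrors the updated unit tests further down; `dist_context` is assumed to be a `DistributedContext` that already holds the serial program:

from paddle.distributed.auto_parallel.completion import Completer

completer = Completer(dist_context)
# Either rely on the program attached to the distributed context ...
annotated_program = completer.complete_forward_annotation()
# ... or pass the partially annotated program explicitly, as before:
# annotated_program = completer.complete_forward_annotation(train_program)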
def complete_high_order_grad_annotation(self, serial_main_program):
def _complete_high_order_grad_annotation(self, serial_main_program):
"""
NOTE:
[HighOrderGrad] Complete the annotation of vars and ops only for high order gradient.
......@@ -818,6 +917,10 @@ class Completer:
def complete_backward_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the backward phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
def _is_grad_var_name(name):
if "@GRAD" in name:
......@@ -1036,8 +1139,12 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
def complete_update_annotation(self, serial_main_program):
def complete_update_annotation(self, serial_main_program=None):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
learning_rate_completed = False
......
......@@ -52,7 +52,7 @@ def append_op_output_suffix(name):
class TensorDistributedAttribute:
def __init__(self):
# The process mesh of the distributed operator attribute must be the same as
# The process mesh of the distributed operator attribute must be the same as
# the process meshes of all input and output distributed attributes
self._process_mesh = None
self._dims_mapping = None
......@@ -132,12 +132,29 @@ class TensorDistributedAttribute:
key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
# def reset(self, skip_dist_attr_field_names):
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self._process_mesh = None
# if skip_dist_attr_field_names is not None \
# and "dims_mapping" not in skip_dist_attr_field_names:
# for i in enumerate(self._dims_mapping):
# self._dims_mapping[i] = -1
# self._is_annotated = {}
def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False)
# def mark_annotated_all(self):
# for key in get_tensor_dist_attr_field_keys():
# self.mark_annotated(key)
def mark_annotated(self, dist_attr_field_name):
self._is_annotated[dist_attr_field_name] = True
# def unmark_annotated(self, dist_attr_field_name):
# self._is_annotated[dist_attr_field_name] = False
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
......@@ -195,7 +212,7 @@ class OperatorDistributedAttribute:
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
# In while op, the process mesh is not shared by all inputs and outputs
# In while op, the process mesh is not shared by all inputs and outputs
if self._op_type == "while":
return None
for dist_attr in self._inputs_dist_attrs.values():
......@@ -357,9 +374,25 @@ class OperatorDistributedAttribute:
"ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh
# def reset(self, skip_dist_attr_field_names):
# for tensor_dist_attr in self.inputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# for tensor_dist_attr in self.outputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self.process_mesh = None
# self.impl_type = "default"
# self.impl_idx = 0
# self._is_annotated = {}
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
# def mark_annotated_all(self):
# for key in get_op_dist_attr_field_keys():
# self.mark_annotated(key)
def mark_annotated(self, attr_name):
if attr_name == "process_mesh":
# Make sure process_mesh is annotated consistently
......
......@@ -14,9 +14,11 @@
import copy
from collections import defaultdict
import paddle.fluid
from paddle.fluid import framework
from paddle.fluid.framework import get_flags, set_flags
from paddle.fluid import core
from paddle.distributed.passes import PassContext
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_tensor import DistributedTensor
......@@ -54,26 +56,41 @@ class DistributedContext:
serial_main_prog=None,
serial_startup_prog=None,
dist_main_progs=None,
dist_startup_progs=None):
# Program related data members
self._serial_program = serial_main_prog
self._is_initialized_for_program = False
dist_startup_progs=None,
serial_loss=None,
serial_optimizer=None,
strategy=None):
# Data members related to original programs (unchanged)
self._original_serial_main_program = serial_main_prog
self._original_serial_startup_program = serial_startup_prog
self._original_serial_loss = serial_loss
self._original_serial_optimizer = serial_optimizer
if self._original_serial_main_program is None:
self._original_serial_main_program = paddle.fluid.default_main_program(
)
if self._original_serial_startup_program is None:
self._original_serial_startup_program = paddle.fluid.default_startup_program(
)
# Data members related to programs (changed)
self._serial_main_program = None
self._serial_startup_program = None
self._serial_loss = None
self._serial_optimizer = None
# Data members related to the program
self._dist_tensors_for_program = {}
self._dist_ops_for_program = {}
self._block_state = BlockState()
# Graph related data members
self._is_initialized_for_graph = False
# Data members related to the graph
self._serial_graph = None
self._dist_tensors_for_graph = {}
self._dist_ops_for_graph = {}
self._node_id_to_tensor_id = {}
self._node_id_to_op_id = {}
# Other data members
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
self._serial_ordered_nodes = []
self._tensor_id_to_tensor_node_ids = {}
# Data members related to the distributed programs
# Distributed programs
self._dist_main_programs = dist_main_progs
if not self._dist_main_programs:
......@@ -82,20 +99,71 @@ class DistributedContext:
if not self._dist_startup_programs:
self._dist_startup_programs = {}
# Distributed Strategy
self._strategy = strategy
# Pass Context
self._pass_context = PassContext()
# Distributed Operator Context
self._dist_op_context = DistributedOperatorContext()
# Other data members
self._process_meshes = []
self._serial_ordered_tensor_nodes = []
self._serial_ordered_op_nodes = []
self._serial_ordered_nodes = []
# self._tensor_id_to_tensor_node_ids = {}
self._is_initialized = False
@property
def serial_program(self):
return self._serial_program
def serial_main_program(self):
return self._serial_main_program
@serial_main_program.setter
def serial_main_program(self, program):
# if self._serial_main_program:
# print("WARNING: The program attached to this distributed context will be replaced by the new one.")
self._original_serial_main_program = program
self._serial_main_program = program
@property
def serial_startup_program(self):
return self._serial_startup_program
# @serial_startup_program.setter
# def serial_startup_program(self, serial_startup_program):
# self._serial_startup_program = serial_startup_program
@property
def serial_loss(self):
return self._serial_loss
# @serial_loss.setter
# def serial_loss(self, serial_loss):
# self._serial_loss = serial_loss
@property
def serial_optimizer(self):
return self._serial_optimizer
# @serial_optimizer.setter
# def serial_optimizer(self, serial_optimizer):
# self._serial_optimizer = serial_optimizer
@property
def strategy(self):
return self._strategy
# @strategy.setter
# def strategy(self, strategy):
# self._strategy = strategy
@property
def serial_graph(self):
return self._serial_graph
@serial_program.setter
def serial_program(self, program):
# assert self._serial_program is None, \
# "This distributed context has already been realted to a serial program"
self._serial_program = program
@property
def serial_ordered_nodes(self):
return self._serial_ordered_nodes
......@@ -104,6 +172,10 @@ class DistributedContext:
def process_meshes(self):
return self._process_meshes
@property
def pass_context(self):
return self._pass_context
@property
def dist_op_context(self):
return self._dist_op_context
......@@ -121,10 +193,64 @@ class DistributedContext:
return self._dist_startup_programs
@property
def is_annotation(self):
def has_annotation(self):
return len(self._dist_tensors_for_program) or len(
self._dist_ops_for_program)
def initialize(self):
if not self._is_initialized:
self._serial_main_program = self._original_serial_main_program.clone(
)
self._serial_startup_program = self._original_serial_startup_program.clone(
)
self._serial_main_program = self._original_serial_main_program
self._serial_startup_program = self._original_serial_startup_program
self._serial_loss = self._original_serial_loss
self._serial_optimizer = self._original_serial_optimizer
self._init_dist_attr_for_program()
self._tensors_ids = list(self._dist_tensors_for_program.keys())
self._ops_ids = list(self._dist_ops_for_program.keys())
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_main_program.desc))
self._init_dist_attr_for_graph()
self._is_initialized = True
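The one-shot `initialize()` call replaces the former `init_dist_attr_for_program()` and `init_dist_attr_for_graph()` steps. A minimal usage sketch, assuming a trivial static program (the program contents are illustrative only):

import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = paddle.static.nn.fc(x, size=8)

dist_ctx = DistributedContext(serial_main_prog=main_prog,
                              serial_startup_prog=startup_prog)
dist_ctx.initialize()  # builds program- and graph-level dist attrs in one call
assert dist_ctx.serial_main_program is not None
assert dist_ctx.serial_graph is not None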
# def reset(self,
# skip_dist_tensors=None,
# skip_dist_ops=None,
# skip_tensor_dist_attr_fields=None,
# skip_op_dist_attr_fields=None):
# self._serial_main_program = self._original_serial_main_program.clone()
# self._serial_startup_program = self._original_serial_startup_program.clone()
# new_tensors_ids = []
# for tensor_id, dist_tensor in self._dist_tensors_for_program.items():
# if tensor_id in self._tensors_ids:
# dist_tensor.dist_attr.reset(skip_tensor_dist_attr_fields)
# else:
# new_tensors_ids.append(tensor_id)
# for tensor_id in new_tensors_ids:
# self._dist_tensors_for_program.pop(tensor_id)
# new_ops_ids = []
# for op_id, dist_op in self._dist_ops_for_program.items():
# if op_id in self._ops_ids:
# dist_op.dist_attr.reset(skip_op_dist_attr_fields)
# else:
# new_ops_ids.append(op_id)
# for op_id in new_ops_ids:
# self._dist_ops_for_program.pop(op_id)
# self.copy_dist_attr_from_program_to_graph()
# self._dist_main_programs = {}
# self._dist_startup_programs = {}
# self._pass_context = PassContext()
# self._dist_op_context = DistributedOperatorContext()
# self._process_meshes = []
def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
'The type of process_mesh must be ProcessMesh.'
......@@ -133,12 +259,12 @@ class DistributedContext:
def add_dist_tensor_for_program(self, dist_tensor):
inner_serial_tensor = dist_tensor.serial_tensor
inner_serial_tensor_id = inner_serial_tensor.desc.id()
inner_serial_tensor_id = inner_serial_tensor.desc.original_id()
self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor
def add_dist_op_for_program(self, dist_op):
inner_serial_op = dist_op.serial_op
inner_serial_op_id = inner_serial_op.desc.id()
inner_serial_op_id = inner_serial_op.desc.original_id()
self._dist_ops_for_program[inner_serial_op_id] = dist_op
def get_dist_tensor_for_program(self, serial_tensor):
......@@ -215,18 +341,6 @@ class DistributedContext:
else:
return None
# def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
# assert serial_tensor_node.is_var() and \
# serial_tensor_node.var() is not None
# serial_tensor_id = serial_tensor_node.node.original_desc_id()
# dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
# assert dist_tensor is not None, \
# "The distributed tensor of the program has not been added to this context."
# serial_tensor_node_id = serial_tensor_node.id()
# new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
# dist_attr)
# self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
......@@ -259,17 +373,6 @@ class DistributedContext:
else:
return None
# def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
# assert serial_op_node.is_op() and \
# serial_op_node.op() is not None
# serial_op_id = serial_op_node.node.original_desc_id()
# dist_op = self._dist_ops_for_program.get(serial_op_id, None)
# assert dist_op is not None, \
# "The distributed operator of the program has not been added to this context."
# serial_op_node_id = serial_op_node.id()
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = _node_id(serial_node)
......@@ -288,15 +391,14 @@ class DistributedContext:
return None
return None
def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
if self._is_initialized_for_program:
return
def _init_dist_attr_for_program(self, no_default=False):
# Copy the dist tensors and dist ops annotated by users from the default context
default_ctx = get_default_distributed_context()
self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
for block in self._serial_program.blocks:
if not no_default:
default_ctx = get_default_distributed_context()
self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
else:
default_ctx = self
for block in self._serial_main_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
default_dist_tensor = default_ctx.get_dist_tensor_for_program(
......@@ -316,9 +418,8 @@ class DistributedContext:
if current_dist_op is None:
dist_op = DistributedOperator(op)
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True
def order_nodes_by_program_order(self):
def _order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if _node_id(node) == _node_id(target_node):
......@@ -328,7 +429,6 @@ class DistributedContext:
serial_ordered_tensor_nodes = []
serial_ordered_op_nodes = []
all_nodes = []
# for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for node in graph.all_nodes():
all_nodes.append(node)
......@@ -346,33 +446,35 @@ class DistributedContext:
new_serial_ordered_tensor_nodes = []
new_serial_ordered_op_nodes = []
new_serial_ordered_nodes = []
for op_node in serial_ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
and not _contains(new_serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node)
new_serial_ordered_nodes.extend(tensor_nodes)
new_serial_ordered_nodes.append(op_node)
new_serial_ordered_op_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
and not _contains(new_serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
new_serial_ordered_nodes.extend(tensor_nodes)
new_serial_ordered_tensor_nodes.sort(
key=lambda node: node.node.original_desc_id())
new_serial_ordered_op_nodes.sort(
key=lambda node: node.node.original_desc_id())
self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
self._serial_ordered_nodes = new_serial_ordered_nodes
assert len(self._serial_ordered_nodes) == len(
self._serial_ordered_tensor_nodes) + len(
self._serial_ordered_op_nodes)
......@@ -385,16 +487,9 @@ class DistributedContext:
"WARNING: there are some orphan tensors or ops which are not used in the execution."
)
def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
if self._is_initialized_for_graph:
return
# Convert program to graph
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
self.order_nodes_by_program_order()
def _init_dist_attr_for_graph(self):
# Order the graph nodes by program order and initialize their distributed attributes
self._order_nodes_by_program_order()
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
......@@ -428,7 +523,6 @@ class DistributedContext:
new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
self._is_initialized_for_graph = True
def clear_dist_info_for_program(self):
self._dist_tensors_for_program.clear()
......@@ -438,8 +532,40 @@ class DistributedContext:
self._dist_tensors_for_graph.clear()
self._dist_ops_for_graph.clear()
def copy_dist_attr_from_program_to_graph(self):
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
tensor_id = node.node.original_desc_id()
for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
):
if tensor_id == cur_tensor_id \
or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
dist_tensor = cur_dist_tensor
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
serial_tensor_node_id = _node_id(node)
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_tensor.dist_attr)
self._dist_tensors_for_graph[
serial_tensor_node_id] = new_dist_tensor
if node.is_op() and node.op() is not None:
dist_op = None
op_id = node.node.original_desc_id()
for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
):
if op_id == cur_op_id \
or op_id == cur_dist_op.serial_op.desc.original_id():
dist_op = cur_dist_op
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
serial_op_node_id = _node_id(node)
new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def copy_dist_attr_from_graph_to_program(self):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
assert self._is_initialized, \
"Both program and graph must be initialized."
updated_tensors = {}
# all_nodes = self._serial_graph.all_nodes()
......@@ -461,7 +587,7 @@ class DistributedContext:
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
# TODO: the completion algorithm will skip orphan tensors,
# TODO: the completion algorithm will skip orphan tensors,
# here we just set their process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes:
serial_tensor_id = orphan_node.var().id()
......@@ -532,18 +658,24 @@ class DistributedContext:
dims_mapping[i] = -1
def validate_dist_attr_for_program(self):
if not self._is_initialized_for_program:
if not self._is_initialized:
assert False, \
"Program must be initialized before validating its distributed attributes"
for block in self.serial_program.blocks:
for block in self.serial_main_program.blocks:
for tensor in block.vars.values():
dist_tensor = self.get_dist_tensor_for_program(tensor)
assert dist_tensor is not None, \
"Tensor {} does not have a distributed attribute.".format(
dist_tensor.serial_tensor.name)
if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
assert dist_op is not None, \
"Operator {} does not have a distributed attribute.".format(
dist_op.serial_op.type)
if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format(
dist_op.serial_op.type, dist_tensor.dist_attr)
......@@ -554,10 +686,12 @@ class DistributedContext:
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph" \
or k == "_dist_main_programs" or k == "_dist_startup_programs" \
or k == "_serial_ordered_nodes" or k == "_serial_ordered_tensor_nodes" \
or k == "_serial_ordered_op_nodes":
if k in [
"_original_serial_main_program", "_original_serial_startup_program", \
"_serial_main_program", "_serial_startup_program", "_serial_graph", \
"_dist_main_programs", "_dist_startup_programs", \
"_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
"_serial_ordered_op_nodes"]:
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
......
......@@ -118,11 +118,10 @@ class Engine:
losses = to_list(self._loss(*(outputs + labels)))
default_ctx = get_default_distributed_context()
if not default_ctx.is_annotation or self._default_strategy:
if not default_ctx.has_annotation or self._default_strategy:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
# print(serial_main_prog)
self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
self._fetch_vars[mode] = {
......
......@@ -18,16 +18,16 @@ from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_containers = {}
_g_elementwise_ops = [
"elementwise_add", "gelu", "dropout", "cast", "gather", "concat"
"elementwise", "gelu", "dropout", "cast", "gather", "concat"
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
def is_elementwise_op(op_type):
if op_type in _g_elementwise_ops:
return True
else:
return False
for eltwise_op in _g_elementwise_ops:
if eltwise_op in op_type:
return True
return False
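Because the membership test is now a substring match against the generic names above, every `elementwise_*` operator type matches the single "elementwise" entry. A quick illustration of the intended behavior:

assert is_elementwise_op("elementwise_add")  # matches the generic "elementwise" entry
assert is_elementwise_op("elementwise_mul")
assert is_elementwise_op("dropout")
assert not is_elementwise_op("matmul_v2")    # no listed name is a substring of this type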
class DistributedOperatorImplContainer:
......@@ -156,7 +156,9 @@ def register_distributed_operator_impl(op_type, dist_impl):
assert False, "Must register distributed operator registry first."
def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
def find_best_compatible_distributed_operator_impl(dist_op,
fwd=True,
partial=True):
"""
Here we return all compatible implementations for the given distributed operator.
This will be improved by a cost model in the future.
......@@ -168,39 +170,55 @@ def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
dist_op_default_impl_container = get_distributed_operator_impl_container(
"default")
compatible_impls = []
if fwd:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_input_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_input_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_input_compatible_impls(
dist_op))
if partial:
if fwd:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_input_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_input_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_input_compatible_impls(
dist_op))
else:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_output_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_output_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_output_compatible_impls(
dist_op))
else:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_output_compatible_impls(dist_op))
dist_op_impl_container.get_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_output_compatible_impls(
dist_op))
dist_op_eltwise_impl_container.get_compatible_impls(dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_output_compatible_impls(
dist_op))
dist_op_default_impl_container.get_compatible_impls(dist_op))
if compatible_impls:
# For now, just return all compatible impls without ranking them
best_compatible_impl = compatible_impls[0]
# best_compatible_impl = compatible_impls[0]
best_compatible_impl = compatible_impls
else:
best_compatible_impl = None
return best_compatible_impl
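Since the function now returns the whole list of compatible implementations (or None) rather than just the first one, callers iterate over the candidates, as the Completer does above. A condensed caller-side sketch (`dist_op` is assumed to be a DistributedOperator from the current context):

candidate_impls = find_best_compatible_distributed_operator_impl(
    dist_op, fwd=True, partial=True)  # candidates compatible with the inputs only
all_impls = find_best_compatible_distributed_operator_impl(
    dist_op, partial=False)           # candidates compatible with inputs and outputs
if candidate_impls:
    chosen = candidate_impls[0]       # a caller may still simply take the first one
    dist_op.dist_attr.impl_type = chosen.type
    dist_op.dist_attr.impl_idx = chosen.idx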
......
......@@ -53,6 +53,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
batch_dim_mappings = []
input_names = op_desc.input_names()
xshape_arg_names = []
if "XShape" in input_names:
......@@ -64,14 +65,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping:
if mapping != -1:
return False
# continue
# if len(dims_mapping) < 1:
# continue
continue
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
......@@ -79,12 +80,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
if compute_compatible_dim_mapping(batch_dim_mappings) is None:
return False
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
output_names = op_desc.output_names()
batch_dim_mappings = []
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
......@@ -95,14 +103,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping:
if mapping != -1:
return False
# continue
# if len(dims_mapping) < 1:
# continue
continue
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
......@@ -110,6 +118,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
if compute_compatible_dim_mapping(batch_dim_mappings) is None:
return False
return True
def is_auto_compatible(self, dist_op):
......@@ -123,9 +137,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
xshape_arg_names = op_desc.input("XShape")
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
......@@ -150,9 +167,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
......@@ -229,7 +249,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
if compatible_dim_mapping is None:
return False
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
......
......@@ -52,21 +52,46 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
if not is_elementwise_op(op_desc.type()):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
if idx < len(dims_mapping):
dim_mappings.append(dims_mapping[-(idx + 1)])
if compute_compatible_dim_mapping(dim_mappings) is None:
return False
return True
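The loop above aligns the inputs' dims mappings from the right (broadcast style) and checks, dimension by dimension, that the mapped mesh axes can be reconciled. Assuming the usual semantics of `compute_compatible_dim_mapping` (-1, i.e. replicated, is compatible with anything, while two different shard axes conflict), the check behaves like this standalone sketch:

def dims_compatible_right_aligned(dims_mapping_list):
    # Sketch only: -1 ("replicated") is compatible with any value; two inputs
    # sharding the same right-aligned dimension on different mesh axes conflict.
    max_len = max(len(m) for m in dims_mapping_list)
    for idx in range(max_len):
        seen = -1
        for dims_mapping in dims_mapping_list:
            if idx >= len(dims_mapping):
                continue
            value = dims_mapping[-(idx + 1)]
            if value == -1:
                continue
            if seen == -1:
                seen = value
            elif seen != value:
                return False
    return True

assert dims_compatible_right_aligned([[0, -1], [-1, -1, -1]])      # broadcast-compatible
assert not dims_compatible_right_aligned([[0, -1], [-1, 1, -1]])   # same dim sharded on two mesh axes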
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
if not is_elementwise_op(op_desc.type()):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
dims_mapping_list.append(dims_mapping)
if compute_compatible_dims_mapping(dims_mapping_list) is None:
return False
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if not is_elementwise_op(op_desc.type()):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
......@@ -127,7 +152,8 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
if compatible_dims_mapping is None:
return False
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
......
......@@ -95,7 +95,8 @@ def _update_dims_mapping_for_matmul(dist_op):
broadcast_x_dims_mapping, broadcast_y_dims_mapping,
broadcast_out_dims_mapping
])
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
if compatible_dims_mapping is None:
return False
for i in range(x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - x_dims_mapping_len)
......
......@@ -117,7 +117,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
if compatible_dim_mapping is None:
return False
for arg_name in op_desc.input_arg_names():
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -17,6 +17,7 @@ from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
......@@ -47,6 +48,29 @@ class DistributedSliceImpl(DistributedOperatorImpl):
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
in_name = op_desc.input('Input')[0]
out_name = op_desc.output('Out')[0]
axes = op_desc.attr('axes')
decrease_axis = op_desc.attr('decrease_axis')
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ref_indices = []
for i in range(len(in_dims_mapping)):
if i not in decrease_axis:
ref_indices.append(i)
if ref_indices == []:
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
else:
for i in range(len(out_dims_mapping)):
ref_index = ref_indices[i]
if ref_index in axes and is_dim_shard(out_dims_mapping[i]):
return False
return True
def is_compatible(self, dist_op):
......@@ -95,17 +119,30 @@ class DistributedSliceImpl(DistributedOperatorImpl):
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ref_dims_mapping = []
ref_indices = []
for i in range(len(in_dims_mapping)):
if i not in decrease_axis:
ref_dims_mapping.append(in_dims_mapping[i])
ref_indices.append(i)
if ref_dims_mapping == []:
ref_dims_mapping = [-1]
assert len(ref_dims_mapping) == len(out_dims_mapping)
for i in range(len(out_dims_mapping)):
if out_dims_mapping[i] != ref_dims_mapping[i]:
out_dims_mapping[i] = ref_dims_mapping[i]
changed = True
assert len(ref_dims_mapping) == len(out_dims_mapping)
assert ref_dims_mapping[0] == out_dims_mapping[0]
changed = False
else:
assert len(ref_dims_mapping) == len(out_dims_mapping)
for i in range(len(out_dims_mapping)):
compatible_dim_mapping = compute_compatible_dim_mapping(
[out_dims_mapping[i], ref_dims_mapping[i]])
if compatible_dim_mapping is None:
continue
if ref_dims_mapping[i] != compatible_dim_mapping:
in_dims_mapping[ref_indices[i]] = compatible_dim_mapping
changed = True
if out_dims_mapping[i] != compatible_dim_mapping:
out_dims_mapping[i] = compatible_dim_mapping
changed = True
return changed
......
......@@ -230,7 +230,7 @@ class AutoParallelizer:
g_process_group_map = copy.deepcopy(_g_process_group_map)
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
for process_mesh in dist_context._process_meshes:
for process_mesh in self._dist_context._process_meshes:
_g_process_group_map[0].add_ranks(process_mesh.processes)
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map
......
......@@ -138,7 +138,6 @@ class MetricRecords(object):
def from_state(cls, state):
records = cls(state["direction"])
records.records = [MetricRecord.from_state(r) for r in state["records"]]
print("here 1", records.records)
return records
......
......@@ -159,11 +159,11 @@ def print_program_with_dist_attr(program, dist_context=None):
from .dist_context import set_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
print(program)
print(program, flush=True)
else:
original_default_context = get_default_distributed_context()
set_default_distributed_context(dist_context)
print(program)
print(program, flush=True)
set_default_distributed_context(original_default_context)
lock.release()
......
......@@ -350,11 +350,12 @@ class RecomputePass(PassBase):
for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_desc = main_block.desc._insert_op(idx)
rc_desc.copy_from(op_desc)
rc_desc.set_original_id(rc_desc.id())
rc_op = Operator(main_block, rc_desc)
main_block.ops.insert(idx, rc_op)
# set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
rc_desc.original_id())
op_desc.original_id())
assert fwd_op_dist_attr is not None
self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
var_name_dict)
......
......@@ -3,18 +3,23 @@
if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS})
set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_gpt_planner ENVS ${dist_ENVS})
set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240)
py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS})
set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS})
set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_while_op_partition MODULES test_while_op_partition ENVS ${dist_ENVS})
py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS})
py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS})
......
......@@ -66,7 +66,6 @@ class TestDistReshape(unittest.TestCase):
for rank in range(2):
dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
ops = dist_main_prog.global_block().ops
print_program_with_dist_attr(dist_main_prog, dist_context)
for idx, op in enumerate(ops):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "reshape2"
......
......@@ -15,6 +15,7 @@
import unittest
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......@@ -85,14 +86,9 @@ class TestDistSlice(unittest.TestCase):
for op in ops:
axes = op.desc.attr('axes')
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if axes[0] == 0:
assert op_dist_attr.impl_type == "default"
else:
assert op_dist_attr.impl_type == "slice"
for out in op.output_arg_names:
var_dims_mapping = op_dist_attr.get_output_dims_mapping(
out)
assert var_dims_mapping[0] == 0
assert op_dist_attr.impl_type == "slice"
for out in op.output_arg_names:
var_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
def test_dist_slice_serial(self):
dist_main_prog, dist_context = parallelizer(make_program_serial, 0)
......
......@@ -23,12 +23,13 @@ import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......@@ -283,139 +284,143 @@ def get_program():
def completion(train_program, start_program, dist_context):
blocks = train_program.blocks
# completion tensors
for block in blocks:
for op in block.ops:
if op.type == "layer_norm":
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
if tensor_dist_attr:
continue
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "elementwise_sub":
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1, -1, -1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "matmul_v2":
col = False
for in_name in op.input_arg_names:
if ".w_" not in in_name:
continue
if in_name not in block.vars:
in_var = blocks[0].vars[in_name]
else:
in_var = block.vars[in_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
assert tensor_dist_attr is not None
if tensor_dist_attr.dims_mapping == [-1, 0]:
col = True
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
if tensor_dist_attr:
continue
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
if col:
tensor_dist_attr.dims_mapping = [-1, -1, 0]
else:
tensor_dist_attr.dims_mapping = [-1, -1, -1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "while":
out_name = op.desc.output("StepScopes")[0]
out_var = block.vars[out_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(out_var,
tensor_dist_attr)
# completion ops
for block in blocks:
for op in block.ops:
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = _g_process_mesh
if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
for in_name in op.input_arg_names:
op_dist_attr.set_input_dims_mapping(in_name, [])
for out_name in op.output_arg_names:
op_dist_attr.set_output_dims_mapping(out_name, [])
elif op.type == "read":
for in_name in op.input_arg_names:
op_dist_attr.set_output_dims_mapping(in_name, [])
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
elif op.type == "while":
for in_name in op.input_arg_names:
in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names:
if out_name == op.desc.output("StepScopes")[0]:
op_dist_attr.set_output_dims_mapping(out_name, [])
else:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name,
out_dist_attr)
else:
for in_name in op.input_arg_names:
if in_name == "lod_tensor_blocking_queue_0":
continue
if in_name not in block.vars:
in_var = blocks[0].vars[in_name]
else:
in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names:
if out_name not in block.vars:
out_var = blocks[0].vars[out_name]
else:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
if op.type == "matmul_v2":
op_dist_attr.impl_type = "matmul_v2"
for in_name in op_dist_attr.inputs_dist_attrs.keys():
in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_idx = 1
elif op.type == "fill_constant_batch_size_like":
op_dist_attr.impl_type = "fill_constant_batch_size_like"
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = 0
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
make_data_unshard(train_program, start_program, dist_context)
# blocks = train_program.blocks
# # completion tensors
# for block in blocks:
# for op in block.ops:
# if op.type == "layer_norm":
# for out_name in op.output_arg_names:
# out_var = block.vars[out_name]
# tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# if tensor_dist_attr:
# continue
# tensor_dist_attr = TensorDistributedAttribute()
# tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1]
# dist_context.set_tensor_dist_attr_for_program(
# out_var, tensor_dist_attr)
# elif op.type == "elementwise_sub":
# for out_name in op.output_arg_names:
# out_var = block.vars[out_name]
# tensor_dist_attr = TensorDistributedAttribute()
# tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1, -1, -1]
# dist_context.set_tensor_dist_attr_for_program(
# out_var, tensor_dist_attr)
# elif op.type == "matmul_v2":
# col = False
# for in_name in op.input_arg_names:
# if ".w_" not in in_name:
# continue
# if in_name not in block.vars:
# in_var = blocks[0].vars[in_name]
# else:
# in_var = block.vars[in_name]
# tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# in_var)
# assert tensor_dist_attr is not None
# if tensor_dist_attr.dims_mapping == [-1, 0]:
# col = True
# for out_name in op.output_arg_names:
# out_var = block.vars[out_name]
# tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# if tensor_dist_attr:
# continue
# tensor_dist_attr = TensorDistributedAttribute()
# tensor_dist_attr.process_mesh = _g_process_mesh
# if col:
# tensor_dist_attr.dims_mapping = [-1, -1, 0]
# else:
# tensor_dist_attr.dims_mapping = [-1, -1, -1]
# dist_context.set_tensor_dist_attr_for_program(
# out_var, tensor_dist_attr)
# elif op.type == "while":
# out_name = op.desc.output("StepScopes")[0]
# out_var = block.vars[out_name]
# tensor_dist_attr = TensorDistributedAttribute()
# tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1]
# dist_context.set_tensor_dist_attr_for_program(out_var,
# tensor_dist_attr)
# # completion ops
# for block in blocks:
# for op in block.ops:
# op_dist_attr = OperatorDistributedAttribute()
# op_dist_attr.process_mesh = _g_process_mesh
# if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
# for in_name in op.input_arg_names:
# op_dist_attr.set_input_dims_mapping(in_name, [])
# for out_name in op.output_arg_names:
# op_dist_attr.set_output_dims_mapping(out_name, [])
# elif op.type == "read":
# for in_name in op.input_arg_names:
# op_dist_attr.set_output_dims_mapping(in_name, [])
# for out_name in op.output_arg_names:
# out_var = block.vars[out_name]
# out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
# elif op.type == "while":
# for in_name in op.input_arg_names:
# in_var = block.vars[in_name]
# in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# in_var)
# op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
# for out_name in op.output_arg_names:
# if out_name == op.desc.output("StepScopes")[0]:
# op_dist_attr.set_output_dims_mapping(out_name, [])
# else:
# out_var = block.vars[out_name]
# out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# op_dist_attr.set_output_dist_attr(out_name,
# out_dist_attr)
# else:
# for in_name in op.input_arg_names:
# if in_name == "lod_tensor_blocking_queue_0":
# continue
# if in_name not in block.vars:
# in_var = blocks[0].vars[in_name]
# else:
# in_var = block.vars[in_name]
# in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# in_var)
# op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
# for out_name in op.output_arg_names:
# if out_name not in block.vars:
# out_var = blocks[0].vars[out_name]
# else:
# out_var = block.vars[out_name]
# out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
# if op.type == "matmul_v2":
# op_dist_attr.impl_type = "matmul_v2"
# for in_name in op_dist_attr.inputs_dist_attrs.keys():
# in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
# if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
# op_dist_attr.impl_idx = 0
# else:
# op_dist_attr.impl_idx = 1
# elif op.type == "fill_constant_batch_size_like":
# op_dist_attr.impl_type = "fill_constant_batch_size_like"
# op_dist_attr.impl_idx = 0
# else:
# op_dist_attr.impl_type = "default"
# op_dist_attr.impl_idx = 0
# dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
# make_data_unshard(train_program, start_program, dist_context)
completer = Completer(dist_context)
train_program = completer.complete_forward_annotation(train_program)
make_data_unshard(train_program, start_program, dist_context)
return train_program, start_program
......
......@@ -134,7 +134,6 @@ class TestMLPAutoParallelizer(unittest.TestCase):
for op in block.ops:
for attr_name in op.attr_names:
self.assertTrue(suffix not in attr_name)
# print_program_with_dist_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program)
......
......@@ -332,7 +332,6 @@ class TestMLPReshard(unittest.TestCase):
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......