未验证 提交 bed9aaea 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Improve the codes of the completion and distributed context (#40671)

* [Auto Parallel] Replace the old planner by the new partition tuner

* [Auto Parallel] Improve the completion and distributed context

* [Auto Parallel] Fix some bugs of the compatible check of some dist ops

* [Auto Parallel] Fix some bugs
上级 afcf6bd0
......@@ -123,6 +123,19 @@ def merge_process_mesh_two(pm1, pm2):
return merged_process_mesh
def _validate_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
process_mesh.topology):
return False
for i in range(len(process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
return True
class Completer:
def __init__(self, dist_context):
assert dist_context is not None
......@@ -161,6 +174,9 @@ class Completer:
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if not _validate_dims_mapping(compatible_dims_mapping,
tensor_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.dims_mapping = compatible_dims_mapping
......@@ -182,6 +198,9 @@ class Completer:
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if not _validate_dims_mapping(compatible_dims_mapping,
tensor_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.dims_mapping = compatible_dims_mapping
......@@ -196,10 +215,12 @@ class Completer:
op_desc = op_node.op()
if op_desc.type() == "create_py_reader" \
or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "while" \
or op_desc.type() == "read":
return False
dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
original_op_dist_attr = copy.deepcopy(op_dist_attr)
if fwd:
for tensor_node in op_node.inputs:
if tensor_node.is_var() and tensor_node.var() is not None:
......@@ -223,18 +244,34 @@ class Completer:
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implemenetations from the distributed operator
op_dist_impl = find_best_compatible_distributed_operator_impl(
op_dist_impls = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
# op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
......@@ -258,18 +295,35 @@ class Completer:
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implemenetations from the distributed operator
op_dist_impl = find_best_compatible_distributed_operator_impl(
op_dist_impls = find_best_compatible_distributed_operator_impl(
dist_op, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
not_compatible = False
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
# op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False
return changed
def _update_dims_mapping_between_graphs(self):
......@@ -279,17 +333,22 @@ class Completer:
parent_node)
child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
child_node)
if parent_node_dist_attr.process_mesh != child_node_dist_attr.process_mesh:
continue
parent_node_dims_mapping = parent_node_dist_attr.dims_mapping
child_node_dims_mapping = child_node_dist_attr.dims_mapping
compatible_dims_mapping = compute_compatible_dims_mapping(
[parent_node_dims_mapping, child_node_dims_mapping])
if not _validate_dims_mapping(compatible_dims_mapping,
parent_node_dist_attr.process_mesh):
return False
if (compatible_dims_mapping is not None) \
and (compatible_dims_mapping != parent_node_dims_mapping):
parent_node_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
if (compatible_dims_mapping is not None) \
and (compatible_dims_mapping != child_node_dims_mapping):
parent_node_dist_attr.dims_mapping = compatible_dims_mapping
child_node_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
return changed
......@@ -351,7 +410,7 @@ class Completer:
if compatible_process_mesh is not None \
and tensor_dist_attr.process_mesh != compatible_process_mesh:
tensor_dist_attr.process_mesh = compatible_process_mesh
# Set the process mesh of the op node's outputs
# Set the process mesh of the op node's outputs
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
......@@ -389,7 +448,8 @@ class Completer:
if _node_id(cur) in visited:
continue
# TODO: need more restrictions
for node in cur.inputs:
neighbors = cur.inputs + cur.outputs
for node in neighbors:
if node.is_var() and node.var() is not None:
if node.var().type() != core.VarDesc.VarType.READER \
and len(node.var().shape()) == 1:
......@@ -421,10 +481,29 @@ class Completer:
visited.add(_node_id(cur))
return related_nodes
def _make_dims_mapping_replicate(dist_attr):
if isinstance(dist_attr, TensorDistributedAttribute):
for i, _ in enumerate(dist_attr.dims_mapping):
dist_attr.dims_mapping[i] = -1
if isinstance(dist_attr, OperatorDistributedAttribute):
for arg_name in dist_attr.inputs_dist_attrs.keys():
new_dims_mapping = []
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
for _ in dims_mapping:
new_dims_mapping.append(-1)
dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
for arg_name in dist_attr.outputs_dist_attrs.keys():
new_dims_mapping = []
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
for _ in dims_mapping:
new_dims_mapping.append(-1)
dist_attr.set_output_dims_mapping(arg_name,
new_dims_mapping)
# Amend the process meshes related to while_op
for while_op_node, while_op_node_idx in self._while_op_nodes.values():
sub_graph_id = while_op_node.op()._block_attr_id("sub_block")
sub_graph = self._dist_context._serial_graph.get_sub_graph(
sub_graph = self._dist_context.serial_graph.get_sub_graph(
sub_graph_id)
sub_graph_nodes = list(sub_graph.all_nodes())
while_dist_op = self._dist_context.get_dist_op_for_graph(
......@@ -440,6 +519,7 @@ class Completer:
merged_process_mesh = merge_process_mesh_two(
merged_process_mesh, dist_attr.process_mesh)
while_op_dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(while_op_dist_attr)
# Step 2: set the related nodes of while_op to the process mesh of while_op
# Step 2.1: Find related nodes of cond var the graph of while_op
......@@ -480,6 +560,7 @@ class Completer:
tensor_dist_attr = self._dist_context.get_dist_attr_for_graph(
node)
tensor_dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(tensor_dist_attr)
# Step 3: set the process meshes of the inputs in while_op to the process meshes of the outside input nodes
while_op_inputs_dist_attrs = while_op_dist_attr.inputs_dist_attrs
......@@ -519,6 +600,25 @@ class Completer:
dist_attr = self._dist_context.get_dist_attr_for_graph(
array_node)
dist_attr.process_mesh = merged_process_mesh
_make_dims_mapping_replicate(dist_attr)
def _update_process_mesh_between_graphs(self):
for parent_node, child_node in self._node_pairs_between_graphs:
parent_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
parent_node)
child_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
child_node)
parent_node_dist_attr.process_mesh = child_node_dist_attr.process_mesh
compatible_process_mesh = compute_compatible_process_mesh([
parent_node_dist_attr.process_mesh,
child_node_dist_attr.process_mesh
])
if compatible_process_mesh is not None \
and parent_node_dist_attr.process_mesh != compatible_process_mesh:
parent_node_dist_attr.process_mesh = compatible_process_mesh
if compatible_process_mesh is not None \
and child_node_dist_attr.process_mesh != compatible_process_mesh:
child_node_dist_attr.process_mesh = compatible_process_mesh
def _update_process_mesh(self):
ordered_op_nodes = self._dist_context._serial_ordered_op_nodes
......@@ -569,7 +669,7 @@ class Completer:
return None
for idx, op_node in enumerate(ordered_op_nodes[
idx_of_first_op_node_has_process_mesh + 1:]):
original_idx = idx_of_first_op_node_has_process_mesh + +idx + 1
original_idx = idx_of_first_op_node_has_process_mesh + idx + 1
nearest_op_node = ordered_op_nodes[original_idx - 1]
nearest_op_dist_attr = self._dist_context.get_dist_attr_for_graph(
nearest_op_node)
......@@ -585,6 +685,9 @@ class Completer:
# Step 3: adjust the process meshes for special ops
self._update_process_mesh_for_specials()
# Step 4: adjust the process meshes between graphs
self._update_process_mesh_between_graphs()
def _prepare(self):
self._while_op_nodes = {}
self._array_nodes = {}
......@@ -620,7 +723,7 @@ class Completer:
self._node_pairs_between_graphs.append(
(after_node, node))
def complete_forward_annotation(self, serial_main_program):
def complete_forward_annotation(self, serial_main_program=None):
""" Complete annotation for the partial annotated serial_main_program.
Arguments:
serial_main_program: partial annotated serial_main_program.
......@@ -628,15 +731,12 @@ class Completer:
serial_main_program: completed annotated serial_main_program.
"""
# Use the default distribted context for completeion if there is no one
self._dist_context.serial_program = serial_main_program
# Initialize distributed attributes for all var and op node in serial_main_program
self._dist_context.init_dist_attr_for_program()
# print_program_with_dist_attr(serial_main_program, self._dist_context)
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
# Initialize distributed attributes for all var and op node in graph
self._dist_context.init_dist_attr_for_graph()
self._dist_context.initialize()
self._prepare()
......@@ -646,10 +746,9 @@ class Completer:
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
self._dist_context.clear_dist_info_for_graph()
# NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient
self.complete_high_order_grad_annotation(serial_main_program)
self._complete_high_order_grad_annotation(serial_main_program)
# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()
......@@ -658,7 +757,7 @@ class Completer:
return serial_main_program
def complete_high_order_grad_annotation(self, serial_main_program):
def _complete_high_order_grad_annotation(self, serial_main_program):
"""
NOTE:
[HighOrderGrad] Complete the annotation of vars and ops only for high order gradient.
......@@ -818,6 +917,10 @@ class Completer:
def complete_backward_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the backward phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
def _is_grad_var_name(name):
if "@GRAD" in name:
......@@ -1036,8 +1139,12 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
def complete_update_annotation(self, serial_main_program):
def complete_update_annotation(self, serial_main_program=None):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
learning_rate_completed = False
......
......@@ -52,7 +52,7 @@ def append_op_output_suffix(name):
class TensorDistributedAttribute:
def __init__(self):
# The process mesh of distributed operator attribute must is the same as
# The process mesh of distributed operator attribute must is the same as
# the process meshes of all input and output distributed attributed
self._process_mesh = None
self._dims_mapping = None
......@@ -132,12 +132,29 @@ class TensorDistributedAttribute:
key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
# def reset(self, skip_dist_attr_field_names):
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self._process_mesh = None
# if skip_dist_attr_field_names is not None \
# and "dims_mapping" not in skip_dist_attr_field_names:
# for i in enumerate(self._dims_mapping):
# self._dims_mapping[i] = -1
# self._is_annotated = {}
def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False)
# def mark_annotated_all(self):
# for key in get_tensor_dist_attr_field_keys():
# self.mark_annotated(key)
def mark_annotated(self, dist_attr_field_name):
self._is_annotated[dist_attr_field_name] = True
# def unmark_annotated(self, dist_attr_field_name):
# self._is_annotated[dist_attr_field_name] = False
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
......@@ -195,7 +212,7 @@ class OperatorDistributedAttribute:
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
# In while op, the proess mesh is not shared by all inputs and outputs
# In while op, the proess mesh is not shared by all inputs and outputs
if self._op_type == "while":
return None
for dist_attr in self._inputs_dist_attrs.values():
......@@ -357,9 +374,25 @@ class OperatorDistributedAttribute:
"ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh
# def reset(self, skip_dist_attr_field_names):
# for tensor_dist_attr in self.inputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# for tensor_dist_attr in self.outputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names)
# if skip_dist_attr_field_names is not None \
# and "process_mesh" not in skip_dist_attr_field_names:
# self.process_mesh = None
# self.impl_type = "default"
# self.impl_idx = 0
# self._is_annotated = {}
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
# def mark_annotated_all(self):
# for key in get_op_dist_attr_field_keys():
# self.mark_annotated(key)
def mark_annotated(self, attr_name):
if attr_name == "process_mesh":
# Make sure proscess_mesh be annotated consistently
......
......@@ -118,11 +118,10 @@ class Engine:
losses = to_list(self._loss(*(outputs + labels)))
default_ctx = get_default_distributed_context()
if not default_ctx.is_annotation or self._default_strategy:
if not default_ctx.has_annotation or self._default_strategy:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
# print(serial_main_prog)
self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
self._fetch_vars[mode] = {
......
......@@ -18,16 +18,16 @@ from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_containers = {}
_g_elementwise_ops = [
"elementwise_add", "gelu", "dropout", "cast", "gather", "concat"
"elementwise", "gelu", "dropout", "cast", "gather", "concat"
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
def is_elementwise_op(op_type):
if op_type in _g_elementwise_ops:
return True
else:
return False
for eltwise_op in _g_elementwise_ops:
if eltwise_op in op_type:
return True
return False
class DistributedOperatorImplContainer:
......@@ -156,7 +156,9 @@ def register_distributed_operator_impl(op_type, dist_impl):
assert False, "Must register distributed operator registry first."
def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
def find_best_compatible_distributed_operator_impl(dist_op,
fwd=True,
partial=True):
"""
Here just return the first compatible implemention.
This will be improved by cost model in the future.
......@@ -168,39 +170,55 @@ def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
dist_op_default_impl_container = get_distributed_operator_impl_container(
"default")
compatible_impls = []
if fwd:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_input_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_input_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_input_compatible_impls(
dist_op))
if partial:
if fwd:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_input_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_input_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_input_compatible_impls(
dist_op))
else:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_output_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_output_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_output_compatible_impls(
dist_op))
else:
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_output_compatible_impls(dist_op))
dist_op_impl_container.get_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_output_compatible_impls(
dist_op))
dist_op_eltwise_impl_container.get_compatible_impls(dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_output_compatible_impls(
dist_op))
dist_op_default_impl_container.get_compatible_impls(dist_op))
if compatible_impls:
# For now, just return the first compatible impl
best_compatible_impl = compatible_impls[0]
# best_compatible_impl = compatible_impls[0]
best_compatible_impl = compatible_impls
else:
best_compatible_impl = None
return best_compatible_impl
......
......@@ -53,6 +53,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
batch_dim_mappings = []
input_names = op_desc.input_names()
xshape_arg_names = []
if "XShape" in input_names:
......@@ -64,14 +65,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping:
if mapping != -1:
return False
# continue
# if len(dims_mapping) < 1:
# continue
continue
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
......@@ -79,12 +80,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
if compute_compatible_dim_mapping(batch_dim_mappings) is None:
return False
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
output_names = op_desc.output_names()
batch_dim_mappings = []
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
......@@ -95,14 +103,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping:
if mapping != -1:
return False
# continue
# if len(dims_mapping) < 1:
# continue
continue
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
......@@ -110,6 +118,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
if len(dims_mapping) >= 2:
batch_dim_mappings.append(dims_mapping[1])
if compute_compatible_dim_mapping(batch_dim_mappings) is None:
return False
return True
def is_auto_compatible(self, dist_op):
......@@ -123,9 +137,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
xshape_arg_names = op_desc.input("XShape")
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
......@@ -150,9 +167,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if serial_tensor.is_parameter:
for mapping in dims_mapping:
if mapping != -1:
return False
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
......@@ -229,7 +249,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
if compatible_dim_mapping is None:
return False
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
......
......@@ -52,21 +52,46 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
if not is_elementwise_op(op_desc.type()):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
if idx < len(dims_mapping):
dim_mappings.append(dims_mapping[-(idx + 1)])
if compute_compatible_dim_mapping(dim_mappings) is None:
return False
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
if not is_elementwise_op(op_desc.type()):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
dims_mapping_list.append(dims_mapping)
if compute_compatible_dims_mapping(dims_mapping_list) is None:
return False
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if not is_elementwise_op(op_desc.type()):
return False
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
......@@ -127,7 +152,8 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
if compatible_dims_mapping is None:
return False
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
......
......@@ -95,7 +95,8 @@ def _update_dims_mapping_for_matmul(dist_op):
broadcast_x_dims_mapping, broadcast_y_dims_mapping,
broadcast_out_dims_mapping
])
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
if compatible_dims_mapping is None:
return False
for i in range(x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - x_dims_mapping_len)
......
......@@ -117,7 +117,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
if compatible_dim_mapping is None:
return False
for arg_name in op_desc.input_arg_names():
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -17,6 +17,7 @@ from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
......@@ -47,6 +48,29 @@ class DistributedSliceImpl(DistributedOperatorImpl):
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
in_name = op_desc.input('Input')[0]
out_name = op_desc.output('Out')[0]
axes = op_desc.attr('axes')
decrease_axis = op_desc.attr('decrease_axis')
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ref_indices = []
for i in range(len(in_dims_mapping)):
if i not in decrease_axis:
ref_indices.append(i)
if ref_indices == []:
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
else:
for i in range(len(out_dims_mapping)):
ref_index = ref_indices[i]
if ref_index in axes and is_dim_shard(out_dims_mapping[i]):
return False
return True
def is_compatible(self, dist_op):
......@@ -95,17 +119,30 @@ class DistributedSliceImpl(DistributedOperatorImpl):
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ref_dims_mapping = []
ref_indices = []
for i in range(len(in_dims_mapping)):
if i not in decrease_axis:
ref_dims_mapping.append(in_dims_mapping[i])
ref_indices.append(i)
if ref_dims_mapping == []:
ref_dims_mapping = [-1]
assert len(ref_dims_mapping) == len(out_dims_mapping)
for i in range(len(out_dims_mapping)):
if out_dims_mapping[i] != ref_dims_mapping[i]:
out_dims_mapping[i] = ref_dims_mapping[i]
changed = True
assert len(ref_dims_mapping) == len(out_dims_mapping)
assert ref_dims_mapping[0] == out_dims_mapping[0]
changed = False
else:
assert len(ref_dims_mapping) == len(out_dims_mapping)
for i in range(len(out_dims_mapping)):
compatible_dim_mapping = compute_compatible_dim_mapping(
[out_dims_mapping[i], ref_dims_mapping[i]])
if compatible_dim_mapping is None:
continue
if ref_dims_mapping[i] != compatible_dim_mapping:
in_dims_mapping[ref_indices[i]] = compatible_dim_mapping
changed = True
if out_dims_mapping[i] != compatible_dim_mapping:
out_dims_mapping[i] = compatible_dim_mapping
changed = True
return changed
......
......@@ -230,7 +230,7 @@ class AutoParallelizer:
g_process_group_map = copy.deepcopy(_g_process_group_map)
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
for process_mesh in dist_context._process_meshes:
for process_mesh in self._dist_context._process_meshes:
_g_process_group_map[0].add_ranks(process_mesh.processes)
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map
......
......@@ -138,7 +138,6 @@ class MetricRecords(object):
def from_state(cls, state):
records = cls(state["direction"])
records.records = [MetricRecord.from_state(r) for r in state["records"]]
print("here 1", records.records)
return records
......
......@@ -159,11 +159,11 @@ def print_program_with_dist_attr(program, dist_context=None):
from .dist_context import set_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
print(program)
print(program, flush=True)
else:
original_default_context = get_default_distributed_context()
set_default_distributed_context(dist_context)
print(program)
print(program, flush=True)
set_default_distributed_context(original_default_context)
lock.release()
......
......@@ -350,11 +350,12 @@ class RecomputePass(PassBase):
for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_desc = main_block.desc._insert_op(idx)
rc_desc.copy_from(op_desc)
rc_desc.set_original_id(rc_desc.id())
rc_op = Operator(main_block, rc_desc)
main_block.ops.insert(idx, rc_op)
# set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
rc_desc.original_id())
op_desc.original_id())
assert fwd_op_dist_attr is not None
self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
var_name_dict)
......
......@@ -3,18 +3,23 @@
if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS})
set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_gpt_planner ENVS ${dist_ENVS})
set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240)
py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS})
set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS})
set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS})
py_test_modules(test_while_op_partition MODULES test_while_op_partition ENVS ${dist_ENVS})
py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS})
py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS})
......
......@@ -66,7 +66,6 @@ class TestDistReshape(unittest.TestCase):
for rank in range(2):
dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
ops = dist_main_prog.global_block().ops
print_program_with_dist_attr(dist_main_prog, dist_context)
for idx, op in enumerate(ops):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "reshape2"
......
......@@ -15,6 +15,7 @@
import unittest
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......@@ -85,14 +86,9 @@ class TestDistSlice(unittest.TestCase):
for op in ops:
axes = op.desc.attr('axes')
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if axes[0] == 0:
assert op_dist_attr.impl_type == "default"
else:
assert op_dist_attr.impl_type == "slice"
for out in op.output_arg_names:
var_dims_mapping = op_dist_attr.get_output_dims_mapping(
out)
assert var_dims_mapping[0] == 0
assert op_dist_attr.impl_type == "slice"
for out in op.output_arg_names:
var_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
def test_dist_slice_serial(self):
dist_main_prog, dist_context = parallelizer(make_program_serial, 0)
......
......@@ -23,12 +23,13 @@ import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
......@@ -283,139 +284,143 @@ def get_program():
def completion(train_program, start_program, dist_context):
blocks = train_program.blocks
# completion tensors
for block in blocks:
for op in block.ops:
if op.type == "layer_norm":
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
if tensor_dist_attr:
continue
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "elementwise_sub":
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1, -1, -1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "matmul_v2":
col = False
for in_name in op.input_arg_names:
if ".w_" not in in_name:
continue
if in_name not in block.vars:
in_var = blocks[0].vars[in_name]
else:
in_var = block.vars[in_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
assert tensor_dist_attr is not None
if tensor_dist_attr.dims_mapping == [-1, 0]:
col = True
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
if tensor_dist_attr:
continue
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
if col:
tensor_dist_attr.dims_mapping = [-1, -1, 0]
else:
tensor_dist_attr.dims_mapping = [-1, -1, -1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "while":
out_name = op.desc.output("StepScopes")[0]
out_var = block.vars[out_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(out_var,
tensor_dist_attr)
# completion ops
for block in blocks:
for op in block.ops:
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = _g_process_mesh
if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
for in_name in op.input_arg_names:
op_dist_attr.set_input_dims_mapping(in_name, [])
for out_name in op.output_arg_names:
op_dist_attr.set_output_dims_mapping(out_name, [])
elif op.type == "read":
for in_name in op.input_arg_names:
op_dist_attr.set_output_dims_mapping(in_name, [])
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
elif op.type == "while":
for in_name in op.input_arg_names:
in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names:
if out_name == op.desc.output("StepScopes")[0]:
op_dist_attr.set_output_dims_mapping(out_name, [])
else:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name,
out_dist_attr)
else:
for in_name in op.input_arg_names:
if in_name == "lod_tensor_blocking_queue_0":
continue
if in_name not in block.vars:
in_var = blocks[0].vars[in_name]
else:
in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names:
if out_name not in block.vars:
out_var = blocks[0].vars[out_name]
else:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
if op.type == "matmul_v2":
op_dist_attr.impl_type = "matmul_v2"
for in_name in op_dist_attr.inputs_dist_attrs.keys():
in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_idx = 1
elif op.type == "fill_constant_batch_size_like":
op_dist_attr.impl_type = "fill_constant_batch_size_like"
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = 0
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
make_data_unshard(train_program, start_program, dist_context)
# blocks = train_program.blocks
# # completion tensors
# for block in blocks:
# for op in block.ops:
# if op.type == "layer_norm":
# for out_name in op.output_arg_names:
# out_var = block.vars[out_name]
# tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# if tensor_dist_attr:
# continue
# tensor_dist_attr = TensorDistributedAttribute()
# tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1]
# dist_context.set_tensor_dist_attr_for_program(
# out_var, tensor_dist_attr)
# elif op.type == "elementwise_sub":
# for out_name in op.output_arg_names:
# out_var = block.vars[out_name]
# tensor_dist_attr = TensorDistributedAttribute()
# tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1, -1, -1]
# dist_context.set_tensor_dist_attr_for_program(
# out_var, tensor_dist_attr)
# elif op.type == "matmul_v2":
# col = False
# for in_name in op.input_arg_names:
# if ".w_" not in in_name:
# continue
# if in_name not in block.vars:
# in_var = blocks[0].vars[in_name]
# else:
# in_var = block.vars[in_name]
# tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# in_var)
# assert tensor_dist_attr is not None
# if tensor_dist_attr.dims_mapping == [-1, 0]:
# col = True
# for out_name in op.output_arg_names:
# out_var = block.vars[out_name]
# tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# if tensor_dist_attr:
# continue
# tensor_dist_attr = TensorDistributedAttribute()
# tensor_dist_attr.process_mesh = _g_process_mesh
# if col:
# tensor_dist_attr.dims_mapping = [-1, -1, 0]
# else:
# tensor_dist_attr.dims_mapping = [-1, -1, -1]
# dist_context.set_tensor_dist_attr_for_program(
# out_var, tensor_dist_attr)
# elif op.type == "while":
# out_name = op.desc.output("StepScopes")[0]
# out_var = block.vars[out_name]
# tensor_dist_attr = TensorDistributedAttribute()
# tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1]
# dist_context.set_tensor_dist_attr_for_program(out_var,
# tensor_dist_attr)
# # completion ops
# for block in blocks:
# for op in block.ops:
# op_dist_attr = OperatorDistributedAttribute()
# op_dist_attr.process_mesh = _g_process_mesh
# if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
# for in_name in op.input_arg_names:
# op_dist_attr.set_input_dims_mapping(in_name, [])
# for out_name in op.output_arg_names:
# op_dist_attr.set_output_dims_mapping(out_name, [])
# elif op.type == "read":
# for in_name in op.input_arg_names:
# op_dist_attr.set_output_dims_mapping(in_name, [])
# for out_name in op.output_arg_names:
# out_var = block.vars[out_name]
# out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
# elif op.type == "while":
# for in_name in op.input_arg_names:
# in_var = block.vars[in_name]
# in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# in_var)
# op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
# for out_name in op.output_arg_names:
# if out_name == op.desc.output("StepScopes")[0]:
# op_dist_attr.set_output_dims_mapping(out_name, [])
# else:
# out_var = block.vars[out_name]
# out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# op_dist_attr.set_output_dist_attr(out_name,
# out_dist_attr)
# else:
# for in_name in op.input_arg_names:
# if in_name == "lod_tensor_blocking_queue_0":
# continue
# if in_name not in block.vars:
# in_var = blocks[0].vars[in_name]
# else:
# in_var = block.vars[in_name]
# in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# in_var)
# op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
# for out_name in op.output_arg_names:
# if out_name not in block.vars:
# out_var = blocks[0].vars[out_name]
# else:
# out_var = block.vars[out_name]
# out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
# out_var)
# op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
# if op.type == "matmul_v2":
# op_dist_attr.impl_type = "matmul_v2"
# for in_name in op_dist_attr.inputs_dist_attrs.keys():
# in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
# if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
# op_dist_attr.impl_idx = 0
# else:
# op_dist_attr.impl_idx = 1
# elif op.type == "fill_constant_batch_size_like":
# op_dist_attr.impl_type = "fill_constant_batch_size_like"
# op_dist_attr.impl_idx = 0
# else:
# op_dist_attr.impl_type = "default"
# op_dist_attr.impl_idx = 0
# dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
# make_data_unshard(train_program, start_program, dist_context)
completer = Completer(dist_context)
train_program = completer.complete_forward_annotation(train_program)
make_data_unshard(train_program, start_program, dist_context)
return train_program, start_program
......
......@@ -134,7 +134,6 @@ class TestMLPAutoParallelizer(unittest.TestCase):
for op in block.ops:
for attr_name in op.attr_names:
self.assertTrue(suffix not in attr_name)
# print_program_with_dist_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program)
......
......@@ -332,7 +332,6 @@ class TestMLPReshard(unittest.TestCase):
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册