Unverified commit e5cda6fa, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Use the new completion algorithm (#39086)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem

* [Auto Parallel] Update the completion algorithm

* Fix the bug of auto_searcher unittest
Parent f68ef9d2
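In short, this commit moves completion from the old module-level entry points to a Completer class. A minimal usage sketch, mirroring the updated unit tests in this diff (train_program stands for any partially annotated serial program):

from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

# Previously: complete_train_program = auto.complete_annotation(train_program, dist_context)
dist_context = DistributedContext()
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(train_program)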
......@@ -15,12 +15,6 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .process_mesh import ProcessMesh
# from .interface import set_shard_mask # noqa: F401
# from .interface import set_offload_device # noqa: F401
# from .interface import set_pipeline_stage # noqa: F401
# from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
from .cost_model import estimate_cost
......
......@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from copy import deepcopy
import time
from paddle.fluid import core
from paddle.fluid import framework
from .utils import compute_compatible_process_mesh
from .utils import compute_compatible_dim_mapping
from .utils import compute_compatible_dims_mapping
from .utils import print_program_with_dist_attr
from .operators import find_best_compatible_distributed_operator_impl
from .dist_context import get_default_distributed_context
......@@ -29,241 +28,92 @@ from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from paddle.distributed.fleet.meta_optimizers.common import OpRole
ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
def is_elementwise_like_op(op_type):
if op_type in ELEMENTWISE_LIKE_OP_LIST:
return True
else:
return False
def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True):
"""
Update tensor's process mesh by using its predecessor's process mesh if in the forward direction,
and by using its successor's process mesh if in the backward direction. Note: only the equal
process meshes are compatible for now.
"""
changed = False
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node)
if tensor_dist_attr.is_annotated("process_mesh"):
return changed
tensor_process_mesh = tensor_dist_attr.process_mesh
if fwd:
inputs_process_meshes = []
for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None:
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
pred_op_node)
op_process_mesh = op_dist_attr.process_mesh
inputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True
else:
outputs_process_meshes = []
for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None:
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
succ_op_node)
op_process_mesh = op_dist_attr.process_mesh
outputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True
return changed
def update_op_node_process_mesh(dist_context, op_node, fwd=True):
"""
Update op's process mesh by using its predecessor's process mesh if in the forward direction,
and by using its successor's process mesh if in the backward direction. Note: only the equal
process meshes are compatible for now.
"""
changed = False
op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node)
if op_dist_attr.is_annotated("process_mesh"):
return changed
op_process_mesh = op_dist_attr.process_mesh
if fwd:
inputs_process_meshes = []
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.process_mesh
inputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.process_mesh = compatible_process_mesh
changed = True
else:
outputs_process_meshes = []
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.process_mesh
outputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.process_mesh = compatible_process_mesh
changed = True
return changed
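The two update functions above implement the rule this commit later replaces: sweep the graph forward and backward, copying process meshes between adjacent nodes, until a fixed point is reached. A standalone sketch of that fixed-point loop (plain Python, not Paddle code; the tiny graph and mesh values are made up):

def propagate(meshes, preds):
    # Forward direction only: a node with no mesh adopts a predecessor's mesh.
    changed = True
    while changed:  # iterate until a fixed point is reached
        changed = False
        for node, pred_list in preds.items():
            if meshes[node] is None:
                known = [meshes[p] for p in pred_list if meshes[p] is not None]
                if known:
                    meshes[node] = known[0]  # stands in for the "compatible" mesh
                    changed = True
    return meshes

meshes = {"a": [0, 1], "b": None, "c": None}
preds = {"a": [], "b": ["a"], "c": ["b"]}
assert propagate(meshes, preds)["c"] == [0, 1]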
def compute_compatible_process_mesh(process_mesh_list):
    """Compute the compatible process mesh given a list of process meshes."""
    if not process_mesh_list:
        return None

    def _compute_compatible_process_mesh_two(pm1, pm2):
        if pm1 is None:
            return True, pm2
        if pm2 is None:
            return True, pm1
        if pm1 == pm2:
            return True, pm1
        if pm1.processes == pm2.processes:
            if len(pm1.topology) >= len(pm2.topology):
                return True, pm1
            else:
                return True, pm2
        process_set1 = set(pm1.processes)
        process_set2 = set(pm2.processes)
        if process_set1.issubset(process_set2):
            return True, pm2
        if process_set2.issubset(process_set1):
            return True, pm1
        return False, None

    compatible_result = None
    for process_mesh in process_mesh_list:
        compatible, compatible_result = _compute_compatible_process_mesh_two(
            compatible_result, process_mesh)
        if not compatible:
            return None
    return copy.deepcopy(compatible_result)


def update_op_dims_mapping_by_default_dist_impl(dist_context, op_node):
    """Each operator has a default distributed operator implementation, which only allows sharding along the batch dimension."""
    changed = False
    if (not op_node.is_op()) or (op_node.op() is None):
        return False
    op_desc = op_node.op()
    dist_op = dist_context.get_dist_op_for_graph(op_node)
    op_dist_attr = dist_op.dist_attr
    # The following statement will be replaced by a more elegant way
    if op_desc.type() == "shape" or op_desc.type() == "slice":
        return False
    output_names = op_desc.output_names()
    xshape_arg_names = []
    if "XShape" in output_names:
        xshape_arg_names = op_desc.output("XShape")
    batch_dim_mappings = []
    for arg_name in op_desc.input_arg_names():
        serial_tensor = dist_op.get_serial_input(arg_name)
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
        if len(dims_mapping) > 1:
            for idx, mapping in enumerate(dims_mapping[1:]):
                assert mapping == -1, \
                    "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
                        .format(op_desc.type(), idx, mapping)
        batch_dim_mappings.append(dims_mapping[0])
    for arg_name in op_desc.output_arg_names():
        serial_tensor = dist_op.get_serial_output(arg_name)
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        if arg_name not in xshape_arg_names:
            if len(dims_mapping) > 1:
                for idx, mapping in enumerate(dims_mapping[1:]):
                    assert mapping == -1, \
                        "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
                            .format(op_desc.type(), idx, mapping)
            batch_dim_mappings.append(dims_mapping[0])
        else:
            assert dims_mapping[0] == -1, \
                "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\
                    .format(op_desc.type(), mapping)
            if len(dims_mapping) > 2:
                for idx, mapping in enumerate(dims_mapping[2:]):
                    assert mapping == -1, \
                        "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\
                            .format(op_desc.type(), idx, mapping)
            batch_dim_mappings.append(dims_mapping[1])
    compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
    assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
    for arg_name in op_desc.input_arg_names():
        serial_tensor = dist_op.get_serial_input(arg_name)
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
        if compatible_dim_mapping != dims_mapping[0]:
            dims_mapping[0] = compatible_dim_mapping
            changed = True
    for arg_name in op_desc.output_arg_names():
        serial_tensor = dist_op.get_serial_output(arg_name)
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        if arg_name not in xshape_arg_names:
            if compatible_dim_mapping != dims_mapping[0]:
                dims_mapping[0] = compatible_dim_mapping
                changed = True
        else:
            if compatible_dim_mapping != dims_mapping[1]:
                dims_mapping[1] = compatible_dim_mapping
                changed = True
    return changed


def compute_compatible_dim_mapping(dim_mapping_list):
    """Compute the compatible dim mapping given a list of dim mappings."""
    if not dim_mapping_list:
        return None

    def _compute_compatible_dim_mapping_two(dm1, dm2):
        if dm1 == -1:
            return True, dm2
        if dm2 == -1:
            return True, dm1
        if dm1 == dm2:
            return True, dm1
        return False, None

    compatible_result = -1
    for mapping in dim_mapping_list:
        compatible, compatible_result = _compute_compatible_dim_mapping_two(
            compatible_result, mapping)
        if not compatible:
            return None
    return compatible_result


def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_context, op_node):
    """Element-wise operators can be sharded along any dimension (but broadcasting must be handled with care)."""
    changed = False
    if (not op_node.is_op()) or (op_node.op() is None):
        return False
    op_desc = op_node.op()
    op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node)
    input_arg_names = op_desc.input_arg_names()
    input_dims_mapping_dict = {}
    input_dims_mapping_lens = {}
    max_dims_mapping_len = -1
    for arg_name in input_arg_names:
        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
        if max_dims_mapping_len < len(dims_mapping):
            max_dims_mapping_len = len(dims_mapping)
        input_dims_mapping_dict[arg_name] = dims_mapping
        input_dims_mapping_lens[arg_name] = len(dims_mapping)
    dims_mapping_list = []
    for arg_name in input_arg_names:
        if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
            new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
            for i in range(input_dims_mapping_lens[arg_name]):
                new_idx = (max_dims_mapping_len -
                           input_dims_mapping_lens[arg_name]) + i
                new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
            dims_mapping_list.append(new_dims_mapping)
        else:
            dims_mapping_list.append(input_dims_mapping_dict[arg_name])
    output_arg_names = op_desc.output_arg_names()
    for arg_name in output_arg_names:
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        assert len(dims_mapping) == max_dims_mapping_len
        dims_mapping_list.append(dims_mapping)
    compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
    assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
    for arg_name in input_arg_names:
        if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
            new_dims_mapping = [
                -1 for _ in range(input_dims_mapping_lens[arg_name])
            ]
            for i in range(input_dims_mapping_lens[arg_name]):
                new_idx = (max_dims_mapping_len -
                           input_dims_mapping_lens[arg_name]) + i
                new_dims_mapping[i] = compatible_dims_mapping[new_idx]
            if new_dims_mapping != input_dims_mapping_dict[arg_name]:
                op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
                changed = True
        else:
            if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
                op_dist_attr.set_input_dims_mapping(arg_name,
                                                    compatible_dims_mapping)
                changed = True
    for arg_name in output_arg_names:
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        if compatible_dims_mapping != dims_mapping:
            op_dist_attr.set_output_dims_mapping(arg_name,
                                                 compatible_dims_mapping)
            changed = True
    return changed


def compute_compatible_dims_mapping(dims_mapping_list):
    """Compute the compatible dims mapping given a list of dims mappings.
    Each dims mapping is itself a list.
    """
    if not dims_mapping_list:
        return None
    length = len(dims_mapping_list[0])
    for dims_mapping in dims_mapping_list:
        if dims_mapping is None:
            return None
        if len(dims_mapping) != length:
            return None
    compatible_result = []
    for dim_mappings in zip(*dims_mapping_list):
        compatible_dim_mapping = compute_compatible_dim_mapping(
            list(dim_mappings))
        if compatible_dim_mapping is None:
            return None
        compatible_result.append(compatible_dim_mapping)
    return compatible_result
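The compatibility rule these helpers share is simple: -1 (replicated) is compatible with anything, and two sharded mappings are compatible only if they are equal. A standalone sketch (plain Python, not Paddle code):

def compatible_dim(dm1, dm2):
    if dm1 == -1:
        return dm2
    if dm2 == -1:
        return dm1
    return dm1 if dm1 == dm2 else None  # None means incompatible

assert compatible_dim(-1, 0) == 0    # replicated vs sharded on mesh dim 0
assert compatible_dim(0, 0) == 0     # identical shardings agree
assert compatible_dim(0, 1) is None  # conflicting shardings are incompatible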
class Completer:
def __init__(self, dist_context):
assert dist_context is not None
self._dist_context = dist_context
def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
if (not tensor_node.is_var()) or (tensor_node.var() is None):
return False
......@@ -271,7 +121,8 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
# Skip reader tensor
if tensor_desc.type() == core.VarDesc.VarType.READER:
return False
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node)
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
assert tensor_dist_attr is not None
if tensor_dist_attr.is_annotated("dims_mapping"):
return False
......@@ -284,8 +135,9 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
or pred_op_node.op().type() == "create_double_buffer_reader" \
or pred_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_dist_attr = self._dist_context.get_op_dist_attr_for_graph(
pred_op_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
dims_mapping_list.append(op_dims_mapping)
......@@ -304,8 +156,9 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
or succ_op_node.op().type() == "create_double_buffer_reader" \
or succ_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_dist_attr = self._dist_context.get_op_dist_attr_for_graph(
succ_op_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
dims_mapping_list.append(op_dims_mapping)
......@@ -318,8 +171,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
changed = True
return changed
def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
def _update_op_node_dims_mapping(self, op_node, fwd=True):
changed = False
if (not op_node.is_op()) or (op_node.op() is None):
return False
......@@ -329,7 +181,7 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "read":
return False
dist_op = dist_context.get_dist_op_for_graph(op_node)
dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
if fwd:
for tensor_node in op_node.inputs:
......@@ -340,8 +192,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
if op_dist_attr.is_annotated_input_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
tensor_dims_mapping = tensor_dist_attr.dims_mapping
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
......@@ -349,8 +202,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
[op_dims_mapping, tensor_dims_mapping])
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_input_dims_mapping(tensor_desc.name(),
compatible_dims_mapping)
op_dist_attr.set_input_dims_mapping(
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementation of the distributed operator
op_dist_impl = find_best_compatible_distributed_operator_impl(
......@@ -374,8 +227,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
if op_dist_attr.is_annotated_output_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
tensor_dims_mapping = tensor_dist_attr.dims_mapping
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
......@@ -401,185 +255,67 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
op_dist_attr.impl_idx = op_dist_impl.idx
return changed
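The _update_process_mesh method defined next drops the per-edge propagation in favor of a simpler rule: a node whose process mesh is still undecided adopts a mesh compatible with that of the nearest preceding node (in the serial node order) that already has one. A standalone sketch of that lookup (plain Python, not Paddle code):

def find_nearest_mesh(meshes, idx):
    # Scan backwards from position idx for the closest already-decided mesh.
    for mesh in reversed(meshes[:idx]):
        if mesh is not None:
            return mesh
    return None

meshes = [[0, 1], None, None, [0, 1, 2, 3]]
assert find_nearest_mesh(meshes, 1) == [0, 1]
assert find_nearest_mesh(meshes, 3) == [0, 1]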
    def _update_process_mesh(self):
        def _find_nearest_node(nodes, idx):
            for node in reversed(nodes[:idx]):
                node_dist_attr = self._dist_context.get_dist_attr_for_graph(
                    node)
                if node_dist_attr.process_mesh is not None:
                    return node

        total_reach_fix_point = False
        while not total_reach_fix_point:
            total_changed = False
            for is_fwd in [True, False]:
                all_nodes = self._dist_context.serial_ordered_nodes \
                    if is_fwd else reversed(self._dist_context.serial_ordered_nodes)
                reach_fix_point = False
                while not reach_fix_point:
                    changed = False
                    for idx, node in enumerate(all_nodes):
                        nearest_node = _find_nearest_node(
                            self._dist_context.serial_ordered_nodes, idx)
                        if nearest_node is None:
                            continue
                        nearest_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
                            nearest_node)
                        nearest_process_mesh = nearest_node_dist_attr.process_mesh
                        cur_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
                            node)
                        cur_process_mesh = cur_node_dist_attr.process_mesh
                        compatible_process_mesh = compute_compatible_process_mesh(
                            [cur_process_mesh, nearest_process_mesh])
                        if compatible_process_mesh is not None \
                            and cur_process_mesh != compatible_process_mesh:
                            cur_node_dist_attr.process_mesh = compatible_process_mesh
                            changed = True
                    if changed:
                        reach_fix_point = False
                        total_changed = True
                    else:
                        reach_fix_point = True
            if total_changed:
                total_reach_fix_point = False
            else:
                total_reach_fix_point = True

    def _update_dims_mapping(self):
        # Complete dims_mapping for each node
        reach_fix_point = False
        while not reach_fix_point:
            changed = False
            for is_fwd in [True, False]:
                all_nodes = self._dist_context.serial_ordered_nodes \
                    if is_fwd else reversed(self._dist_context.serial_ordered_nodes)
                for node in all_nodes:
                    if node.is_var() and node.var() is not None:
                        tensor_changed = self._update_tensor_node_dims_mapping(
                            node, fwd=is_fwd)
                        if tensor_changed:
                            changed = True
                    if node.is_op() and node.op() is not None:
                        op_changed = self._update_op_node_dims_mapping(
                            node, fwd=is_fwd)
                        if op_changed:
                            changed = True
            if changed:
                reach_fix_point = False
            else:
                reach_fix_point = True

    def complete_forward_annotation(self, serial_main_program):
        """ Complete annotation for the partial annotated serial_main_program.

        Arguments:
            serial_main_program: partial annotated serial_main_program.

        Returns:
            serial_main_program: completed annotated serial_main_program.
        """
        self._dist_context.serial_program = serial_main_program

        # Initialize distributed attributes for all var and op node in serial_main_program
        self._dist_context.init_dist_attr_for_program()

        # Initialize distributed attributes for all var and op node in graph
        self._dist_context.init_dist_attr_for_graph()

        # Complete process mesh for each node
        self._update_process_mesh()

        # Complete dims_mapping for each node
        self._update_dims_mapping()

        # Copy the corresponding distributed attribute from graph to serial_main_program
        self._dist_context.copy_dist_attr_from_graph_to_program()
        self._dist_context.clear_dist_info_for_graph()
        # print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context)

        # Do the validation check and amend some completion
        self._dist_context.amend_dist_attr_for_program()
        # print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context)
        self._dist_context.validate_dist_attr_for_program()

        return serial_main_program


# Removed: the old module-level completion entry point.
def complete_annotation(program, dist_context=None):
    """ Complete annotation for the partial annotated program.

    Arguments:
        program: partial annotated program.
        dist_context: the distributed context used to store the distributed attributes of the program.
            If not provided, the default one will be used.

    Returns:
        program: completed annotated program.
    """
    # Use the default distributed context for completion if none is provided
    if dist_context is None:
        dist_context = get_default_distributed_context()
        dist_context.serial_program = program
    else:
        dist_context.serial_program = program

    # print_program_with_dist_attr(program, dist_context)

    # Initialize distributed attributes for all var and op node in program
    dist_context.init_dist_attr_for_program()

    # Initialize distributed attributes for all var and op node in graph
    dist_context.init_dist_attr_for_graph()

    # Complete process mesh for each node
    all_nodes = list(dist_context.serial_graph.all_nodes())

    def sort_key_fun(node):
        first = -1
        if node.is_op():
            first = 0
        else:
            first = 1
        second = -1
        if node.is_op() and node.op() is not None:
            second = node.op().id()
        if node.is_var() and node.var() is not None:
            second = node.var().id()
        return (first, second)

    all_nodes.sort(key=sort_key_fun)

    total_reach_fix_point = False
    while not total_reach_fix_point:
        total_changed = False
        reach_fwd_fix_point = False
        reach_bwd_fix_point = False
        while not reach_fwd_fix_point:
            changed = False
            for node in all_nodes:
                if node.is_var() and node.var() is not None:
                    tensor_changed = update_tensor_node_process_mesh(
                        dist_context, node, fwd=True)
                    if tensor_changed:
                        changed = True
                if node.is_op() and node.op() is not None:
                    op_changed = update_op_node_process_mesh(
                        dist_context, node, fwd=True)
                    if op_changed:
                        changed = True
            if changed:
                reach_fwd_fix_point = False
                total_changed = True
            else:
                reach_fwd_fix_point = True
        while not reach_bwd_fix_point:
            changed = False
            for node in all_nodes:
                if node.is_var() and node.var() is not None:
                    tensor_changed = update_tensor_node_process_mesh(
                        dist_context, node, fwd=False)
                    if tensor_changed:
                        changed = True
                if node.is_op() and node.op() is not None:
                    op_changed = update_op_node_process_mesh(
                        dist_context, node, fwd=False)
                    if op_changed:
                        changed = True
            if changed:
                reach_bwd_fix_point = False
                total_changed = True
            else:
                reach_bwd_fix_point = True
        if total_changed:
            total_reach_fix_point = False
        else:
            total_reach_fix_point = True

    # Validate the completion of process meshes (this should be moved to a proper location)
    is_wrong = False
    for node in all_nodes:
        if node.is_var() and node.var() is not None:
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
                node)
            if tensor_dist_attr.process_mesh is None:
                msg_str = ""
                for op_node in node.inputs:
                    if op_node.op() is not None:
                        op_dist_attr = dist_context.get_op_dist_attr_for_graph(
                            op_node)
                        msg_str += "{} [{}], ".format(
                            op_node.op().type(),
                            op_dist_attr.process_mesh)
                    else:
                        msg_str += "{} [{}], ".format(op_node.name(), None)
                for op_node in node.outputs:
                    if op_node.op() is not None:
                        op_dist_attr = dist_context.get_op_dist_attr_for_graph(
                            op_node)
                        msg_str += "{} [{}], ".format(
                            op_node.op().type(),
                            op_dist_attr.process_mesh)
                    else:
                        msg_str += "{} [{}], ".format(op_node.name(), None)
                msg_str = "Cannot decide the ProcessMesh of {} among {}. Please use the shard_tensor api to annotate it explicitly.".format(
                    node.var().name(), msg_str[:-2])
                is_wrong = True
                print(msg_str)
        if node.is_op() and node.op() is not None:
            op_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
            if op_dist_attr.process_mesh is None:
                msg_str = ""
                for tensor_node in node.inputs:
                    if tensor_node.var() is not None:
                        tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
                            tensor_node)
                        msg_str += "{} [{}], ".format(
                            tensor_node.var().name(),
                            tensor_dist_attr.process_mesh)
                    else:
                        msg_str += "{} [{}], ".format(tensor_node.name(), None)
                for tensor_node in node.outputs:
                    if tensor_node.var() is not None:
                        tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
                            tensor_node)
                        msg_str += "{} [{}], ".format(
                            tensor_node.var().name(),
                            tensor_dist_attr.process_mesh)
                    else:
                        msg_str += "{} [{}], ".format(tensor_node.name(), None)
                msg_str = "Cannot decide the ProcessMesh of {} among {}. Please use the shard_op api to annotate it explicitly.".format(
                    node.op().type(), msg_str[:-2])
                is_wrong = True
                print(msg_str)
        if node.is_op() and node.op() is None:
            print("op is None", node.name())
    if is_wrong:
        assert False, "Cannot complete the process_meshes of the program."

    # Complete dims_mapping for each node
    reach_fix_point = False
    while not reach_fix_point:
        changed = False
        for node in all_nodes:
            if node.is_var() and node.var() is not None:
                tensor_changed = update_tensor_node_dims_mapping(
                    dist_context, node, fwd=True)
                if tensor_changed:
                    changed = True
            if node.is_op() and node.op() is not None:
                op_changed = update_op_node_dims_mapping(
                    dist_context, node, fwd=True)
                if op_changed:
                    changed = True
        for node in reversed(all_nodes):
            if node.is_var() and node.var() is not None:
                tensor_changed = update_tensor_node_dims_mapping(
                    dist_context, node, fwd=False)
                if tensor_changed:
                    changed = True
            if node.is_op() and node.op() is not None:
                op_changed = update_op_node_dims_mapping(
                    dist_context, node, fwd=False)
                if op_changed:
                    changed = True
        if changed:
            reach_fix_point = False
        else:
            reach_fix_point = True

    # Copy the corresponding distributed attribute from graph to program
    dist_context.copy_dist_attr_from_graph_to_program()
    dist_context.clear_dist_info_for_graph()

    # Do the validation check and amend some completion
    dist_context.amend_dist_attr_for_program()
    # print_program_with_dist_attr(program, dist_context)
    dist_context.validate_dist_attr_for_program()

    return program


def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
def complete_backward_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the backward phase for parallel program."""
def _is_grad_var_name(name):
......@@ -610,7 +370,8 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
def _get_forward_varname_from_grad_varname(grad_var_name):
assert _is_grad_var_name(
grad_var_name), "[{}] is not a grad varnme.".format(grad_var_name)
grad_var_name), "[{}] is not a grad varnme.".format(
grad_var_name)
return grad_var_name[:grad_var_name.find("@GRAD")]
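The helper above relies on the naming convention that a gradient var name is its forward var name plus an "@GRAD" suffix. A standalone illustration (plain Python; the var name is hypothetical):

def forward_name(grad_var_name):
    assert "@GRAD" in grad_var_name, "not a grad varname"
    return grad_var_name[:grad_var_name.find("@GRAD")]

assert forward_name("linear_0.w_0@GRAD") == "linear_0.w_0"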
def _get_op_by_id(ops, id):
......@@ -619,11 +380,8 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
return op
return None
if dist_context is None:
dist_context = get_default_distributed_context()
first_backward_op_idx = -1
for idx, op in enumerate(auto_parallel_main_prog.global_block().ops):
for idx, op in enumerate(serial_main_program.global_block().ops):
if int(op.attr('op_role')) == int(
int(core.op_proto_and_checker_maker.OpRole.Backward) | int(
core.op_proto_and_checker_maker.OpRole.Loss)):
......@@ -633,9 +391,9 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
assert first_backward_op_idx >= 0, "No backward procedure found in this program."
ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
dist_op_context = dist_context.dist_op_context
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
dist_op_context = self._dist_context.dist_op_context
for idx in range(first_backward_op_idx, len(ops)):
......@@ -658,19 +416,21 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
# TODO complete other attributes for grad var
tensor_dist_attr = TensorDistributedAttribute()
process_mesh = dist_context.get_tensor_dist_attr_for_program(
process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).process_mesh
dims_mapping = dist_context.get_tensor_dist_attr_for_program(
dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).dims_mapping
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.process_mesh = process_mesh
dist_context.set_tensor_dist_attr_for_program(grad_var,
tensor_dist_attr)
self._dist_context.set_tensor_dist_attr_for_program(
grad_var, tensor_dist_attr)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = process_mesh
op_dist_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
dist_context.set_op_dist_attr_for_program(ops[idx], op_dist_attr)
op_dist_attr.set_output_dims_mapping(grad_var.name,
dims_mapping)
self._dist_context.set_op_dist_attr_for_program(ops[idx],
op_dist_attr)
continue
# complete the annotation of grad op (xxx_grad op or sum op)
......@@ -684,7 +444,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
assert forward_op is not None
# op dist attr
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
forward_op)
forward_op_process_mesh = forward_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute()
......@@ -700,7 +460,8 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
forward_name)
else:
if forward_op_dist_attr.get_input_dims_mapping(input_name):
if forward_op_dist_attr.get_input_dims_mapping(
input_name):
ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
input_name)
else:
......@@ -736,14 +497,14 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
output_var_dist_attr = TensorDistributedAttribute()
output_var_dist_attr.dims_mapping = ref_dims_mapping
output_var_dist_attr.process_mesh = forward_op_process_mesh
dist_context.set_tensor_dist_attr_for_program(
self._dist_context.set_tensor_dist_attr_for_program(
output_var, output_var_dist_attr)
grad_op_dist_attr.set_output_dims_mapping(output_var.name,
ref_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(
output_var.name, ref_dims_mapping)
dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr)
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
# only the sum op used to merge multiple versions of a grad var has no corresponding mapping in grad_op_id_to_op_id
else:
......@@ -755,16 +516,16 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
ref_forward_var_name = _get_forward_varname_from_grad_varname(
grad_op.output_arg_names[0])
forward_var = vars[ref_forward_var_name]
ref_forward_var_dims_mapping = dist_context.get_tensor_dist_attr_for_program(
ref_forward_var_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).dims_mapping
ref_forward_var_process_mesh = dist_context.get_tensor_dist_attr_for_program(
ref_forward_var_process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).process_mesh
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping
tensor_dist_attr.process_mesh = ref_forward_var_process_mesh
dist_context.set_tensor_dist_attr_for_program(
self._dist_context.set_tensor_dist_attr_for_program(
vars[grad_op.output_arg_names[0]], tensor_dist_attr)
# op
......@@ -778,18 +539,13 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
grad_op_dist_attr.set_output_dims_mapping(
grad_op.output_arg_names[0], ref_forward_var_dims_mapping)
dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr)
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
def complete_update_annotation(auto_parallel_main_prog, dist_context):
def complete_update_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
if dist_context is None:
dist_context = get_default_distributed_context()
ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
learning_rate_completed = False
for idx in range(len(ops)):
......@@ -798,28 +554,6 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
# TODO: add attribute for moment var
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if op.type == "clip_by_norm":
param_grad = vars[op.input("X")[0]]
param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param_grad)
assert param_grad_dist_attr is not None
ref_process_mesh = param_grad_dist_attr.process_mesh
ref_dims_mapping = param_grad_dist_attr.dims_mapping
out = vars[op.output("Out")[0]]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = ref_process_mesh
out_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(out,
out_dist_attr)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dist_attr(param_grad.name,
param_grad_dist_attr)
op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert len(op.input(
......@@ -829,13 +563,13 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]]
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
param)
assert param_dist_attr is not None
ref_process_mesh = dist_context.get_tensor_dist_attr_for_program(
ref_process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
param).process_mesh
assert ref_process_mesh is not None
ref_dims_mapping = dist_context.get_tensor_dist_attr_for_program(
ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
param).dims_mapping
assert ref_dims_mapping is not None
op_dist_attr = OperatorDistributedAttribute()
......@@ -848,15 +582,16 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
ref_dims_mapping)
learning_var = vars[op.input("LearningRate")[0]]
op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
op_dist_attr.set_output_dims_mapping(learning_var.name, [-1])
op_dist_attr.set_output_dims_mapping(learning_var.name,
[-1])
if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistributedAttribute()
var_dist_attr.process_mesh = ref_process_mesh
var_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(learning_var,
var_dist_attr)
self._dist_context.set_tensor_dist_attr_for_program(
learning_var, var_dist_attr)
for input_name in op.desc.input_names():
......@@ -880,14 +615,15 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
else:
assert "Moment" in input_name
input_var_attr.dims_mapping = ref_dims_mapping
op_dist_attr.set_input_dims_mapping(input_var.name,
ref_dims_mapping)
op_dist_attr.set_output_dims_mapping(input_var.name,
ref_dims_mapping)
op_dist_attr.set_input_dims_mapping(
input_var.name, ref_dims_mapping)
op_dist_attr.set_output_dims_mapping(
input_var.name, ref_dims_mapping)
input_var_attr.process_mesh = ref_process_mesh
dist_context.set_tensor_dist_attr_for_program(
self._dist_context.set_tensor_dist_attr_for_program(
input_var, input_var_attr)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
self._dist_context.set_op_dist_attr_for_program(
op, op_dist_attr)
continue
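The update-phase rule applied above, in one line: optimizer inputs tied to a parameter (Param, Grad, and the Moment states) reuse the parameter's dims_mapping, while scalar-like state such as LearningRate is replicated ([-1]). A standalone, simplified illustration (plain Python, not Paddle code):

param_dims_mapping = [0, -1]  # parameter sharded along mesh dimension 0

def dims_mapping_for(input_name):
    if input_name == "LearningRate":
        return [-1]                # effectively a scalar: replicate it
    return param_dims_mapping      # Param, Grad and Moment follow the parameter

assert dims_mapping_for("LearningRate") == [-1]
assert dims_mapping_for("Moment1") == [0, -1]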
......@@ -247,23 +247,23 @@ class DistributedContext:
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None
def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = serial_node.id()
dist_tensor = self._dist_tensors_for_graph.get(
serial_tensor_node_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
if serial_node.is_op() and serial_node.op() is not None:
serial_op_node_id = serial_node.id()
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
return None
def init_dist_attr_for_program(self):
assert self._serial_program, \
......
......@@ -32,7 +32,7 @@ from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
from .completion import Completer
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
......@@ -130,8 +130,8 @@ class AutoParallelizer:
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
self._completer = Completer(self._dist_context)
self._completer.complete_backward_annotation(main_program)
return params_grads
......@@ -142,8 +142,8 @@ class AutoParallelizer:
params_grads)
# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
return optimize_ops
......@@ -179,8 +179,9 @@ class AutoParallelizer:
# Annotation completion
self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(serial_main_program,
self._dist_context)
self._completer = Completer(self._dist_context)
completed_main_program = self._completer.complete_forward_annotation(
serial_main_program)
else:
completed_main_program = serial_main_program
self._dist_context = copy.deepcopy(dist_context)
......
......@@ -27,6 +27,7 @@ import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
......@@ -154,10 +155,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_mp(self):
......@@ -171,10 +171,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_dp_mp(self):
......@@ -189,10 +188,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
# def test_mlp_misc(self):
......@@ -212,8 +210,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
# train_program, start_program = mlp_pretrain_forward(train_program,
# start_program)
# # pdb.set_trace()
# complete_train_program = auto.complete_annotation(train_program,
# dist_context)
# completer = Completer(dist_context)
# complete_train_program = completer.complete_forward_annotation(train_program)
# # print_program_with_dist_attr(complete_train_program,
# # dist_context)
# dist_context.finalize_distributed_attr_for_program(
......@@ -423,8 +421,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program())
......@@ -440,10 +439,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_attn_dp_mp(self):
......@@ -458,10 +456,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
......@@ -747,10 +744,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_decoder_mp(self):
......@@ -764,10 +760,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_decoder_dp_mp(self):
......@@ -782,10 +777,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
......
......@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed.fleet import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
......@@ -817,10 +818,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_gpt_mp(self):
......@@ -834,10 +834,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_gpt_dp_mp(self):
......@@ -852,10 +851,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
......
......@@ -23,6 +23,7 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
......@@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......
......@@ -18,6 +18,7 @@ import unittest
import paddle
from paddle.fluid import core
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
......@@ -42,8 +43,9 @@ def get_dist_prog(train_program,
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(
train_program, dist_context
) if complete_train_program is None else complete_train_program
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program
) if complete_train_program is None else complete_train_program
# parallelizer._apply_serial_forward_pass(complete_train_program,
......
......@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
......@@ -433,6 +434,12 @@ class MLPLayer(nn.Layer):
out = F.gelu(out, approximate=True)
out = self.linear1(out)
auto.shard_tensor(
out,
dist_attr={
"process_mesh": _global_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear2(out)
out = F.gelu(out, approximate=True)
out = self.linear3(out)
......@@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......
......@@ -28,6 +28,7 @@ import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
......@@ -49,8 +50,9 @@ def get_programs(annotated_func):
global _global_process_mesh
dist_context.process_mesh = _global_process_mesh
train_program, start_program = annotated_func(train_program, start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
rank_id = 3
dist_strategy = fleet.DistributedStrategy()
......
......@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
......@@ -881,8 +882,9 @@ class TestGPTPartitioner(unittest.TestCase):
dist_context.process_mesh = _global_process_mesh
train_program, startup_program, loss = gpt_pretrain_forward(
train_program, startup_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
# serial backward pass
params_grads = parallelizer._generate_backward(
......@@ -913,8 +915,9 @@ class TestGPTPartitioner(unittest.TestCase):
"w") as fw:
fw.write(str(auto_parallel_startup_prog))
# with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
# from paddle.distributed.auto_parallel.completion import complete_backward_annotation
# complete_backward_annotation(auto_parallel_main_prog)
# from paddle.distributed.auto_parallel.completion import Completer
# completer = Completer()
# completer.complete_forward_annotation(auto_parallel_main_prog)
# fw.write(str(auto_parallel_main_prog))
nrank = 4
# col parallel
......
......@@ -22,6 +22,7 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
......@@ -152,8 +153,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......@@ -299,7 +301,6 @@ class TestMLPReshard(unittest.TestCase):
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
......@@ -22,6 +22,7 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
......@@ -116,8 +117,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......
......@@ -22,6 +22,7 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
......@@ -132,8 +133,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......@@ -263,8 +265,9 @@ class TestMLPReshard(unittest.TestCase):
dist_context = DistributedContext()
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_context, rank_id)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, [])
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
......
......@@ -154,7 +154,7 @@ class TestMLPSearcher(unittest.TestCase):
ops = train_program.global_block().ops
vars = train_program.global_block().vars
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.distributed.auto_parallel.completion import is_elementwise_like_op
from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
for op in ops:
......@@ -163,7 +163,7 @@ class TestMLPSearcher(unittest.TestCase):
if dist_op_impl_container is None:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_op = DistributedOperator(op, op_dist_attr)
if is_elementwise_like_op(op.type):
if is_elementwise_op(op.type):
changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_op)
self.assertFalse(changed)
......