Unverified · Commit e5cda6fa authored by Yulong Ao, committed by GitHub

[Auto Parallel] Use the new completion algorithm (#39086)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem

* [Auto Parallel] Update the completion algorithm

* Fix the bug of auto_searcher unittest
Parent f68ef9d2
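As a quick orientation, here is a minimal, hedged sketch of how the new Completer entry point is driven; it mirrors the updated unit tests and parallelizer changes in this commit. mlp_pretrain_forward is the helper defined in the test file below; the sharding annotations it applies are assumed rather than shown here.

import paddle
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

paddle.enable_static()
train_program = paddle.static.Program()
start_program = paddle.static.Program()
# mlp_pretrain_forward builds a partially annotated serial program (see the test file in this diff).
train_program, start_program = mlp_pretrain_forward(train_program, start_program)

dist_context = DistributedContext()
completer = Completer(dist_context)
# Propagate process_mesh/dims_mapping from the user's partial annotations to every tensor and op.
complete_train_program = completer.complete_forward_annotation(train_program)
assert dist_context.validate_dist_attr_for_program()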
...@@ -15,12 +15,6 @@
 from .interface import shard_tensor  # noqa: F401
 from .interface import shard_op  # noqa: F401
 from .process_mesh import ProcessMesh
-# from .interface import set_shard_mask  # noqa: F401
-# from .interface import set_offload_device  # noqa: F401
-# from .interface import set_pipeline_stage  # noqa: F401
-# from .interface import ProcessMesh  # noqa: F401
-from .completion import complete_annotation  # noqa: F401
-from .completion import complete_backward_annotation  # noqa: F401
 from .reshard import reshard  # noqa: F401
 from .cost_model import estimate_cost
......
...@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 from copy import deepcopy
+import time
 from paddle.fluid import core
 from paddle.fluid import framework
-from .utils import compute_compatible_process_mesh
-from .utils import compute_compatible_dim_mapping
-from .utils import compute_compatible_dims_mapping
 from .utils import print_program_with_dist_attr
 from .operators import find_best_compatible_distributed_operator_impl
 from .dist_context import get_default_distributed_context
...@@ -29,865 +28,602 @@ from .dist_attribute import TensorDistributedAttribute
 from .dist_attribute import OperatorDistributedAttribute
 from paddle.distributed.fleet.meta_optimizers.common import OpRole
ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
def compute_compatible_process_mesh(process_mesh_list):
"""Compute the compatible process mesh given a list of process meshes."""
if not process_mesh_list:
return None
def is_elementwise_like_op(op_type): def _compute_compatible_process_mesh_two(pm1, pm2):
if op_type in ELEMENTWISE_LIKE_OP_LIST: if pm1 is None:
return True return True, pm2
else: if pm2 is None:
return False return True, pm1
if pm1 == pm2:
return True, pm1
if pm1.processes == pm2.processes:
if len(pm1.topology) >= len(pm2.topology):
return True, pm1
else:
return True, pm2
process_set1 = set(pm1.processes)
process_set2 = set(pm2.processes)
if process_set1.issubset(process_set2):
return True, pm2
if process_set2.issubset(process_set1):
return True, pm1
return False, None
compatible_result = None
for process_mesh in process_mesh_list:
compatible, compatible_result = _compute_compatible_process_mesh_two(
compatible_result, process_mesh)
if not compatible:
return None
return copy.deepcopy(compatible_result)
def compute_compatible_dim_mapping(dim_mapping_list):
"""Compute the compatible dim mapping given a list of dim mapping."""
if not dim_mapping_list:
return None
def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True): def _compute_compatible_dim_mapping_two(dm1, dm2):
""" if dm1 == -1:
Update tensor's process mesh by using its predecessor's process mesh if in the forward direction, return True, dm2
and by using its successor's process mesh if in the backward direction. Note: only the equal if dm2 == -1:
process meshes are compatible for now. return True, dm1
if dm1 == dm2:
return True, dm1
return False, None
compatible_result = -1
for mapping in dim_mapping_list:
compatible, compatible_result = _compute_compatible_dim_mapping_two(
compatible_result, mapping)
if not compatible:
return None
return compatible_result
def compute_compatible_dims_mapping(dims_mapping_list):
"""Compute the compatible dims mapping given a list of dims mapping.
Each of dims mapping is also a list.
""" """
changed = False if not dims_mapping_list:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node) return None
if tensor_dist_attr.is_annotated("process_mesh"): length = len(dims_mapping_list[0])
return changed for dims_mapping in dims_mapping_list:
tensor_process_mesh = tensor_dist_attr.process_mesh if dims_mapping is None:
if fwd: return None
inputs_process_meshes = [] if len(dims_mapping) != length:
for pred_op_node in tensor_node.inputs: return None
if pred_op_node.op() is not None: compatible_result = []
op_dist_attr = dist_context.get_op_dist_attr_for_graph( for dim_mappings in zip(*dims_mapping_list):
pred_op_node) compatible_dim_mapping = compute_compatible_dim_mapping(
op_process_mesh = op_dist_attr.process_mesh list(dim_mappings))
inputs_process_meshes.append(op_process_mesh) if compatible_dim_mapping is None:
compatible_process_mesh = compute_compatible_process_mesh( return None
inputs_process_meshes) compatible_result.append(compatible_dim_mapping)
if compatible_process_mesh is not None and tensor_process_mesh is None: return compatible_result
tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True
else: class Completer:
outputs_process_meshes = [] def __init__(self, dist_context):
for succ_op_node in tensor_node.outputs: assert dist_context is not None
if succ_op_node.op() is not None: self._dist_context = dist_context
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
succ_op_node) def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
op_process_mesh = op_dist_attr.process_mesh changed = False
outputs_process_meshes.append(op_process_mesh) if (not tensor_node.is_var()) or (tensor_node.var() is None):
compatible_process_mesh = compute_compatible_process_mesh( return False
outputs_process_meshes) tensor_desc = tensor_node.var()
if compatible_process_mesh is not None and tensor_process_mesh is None: # Skip reader tensor
tensor_dist_attr.process_mesh = compatible_process_mesh if tensor_desc.type() == core.VarDesc.VarType.READER:
changed = True return False
return changed tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
assert tensor_dist_attr is not None

    def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
        changed = False
        if (not tensor_node.is_var()) or (tensor_node.var() is None):
            return False
        tensor_desc = tensor_node.var()
        # Skip reader tensor
        if tensor_desc.type() == core.VarDesc.VarType.READER:
            return False
        tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
            tensor_node)
        assert tensor_dist_attr is not None
        if tensor_dist_attr.is_annotated("dims_mapping"):
            return False
        tensor_dims_mapping = tensor_dist_attr.dims_mapping
        if fwd:
            dims_mapping_list = []
            for pred_op_node in tensor_node.inputs:
                if pred_op_node.op() is not None:
                    if pred_op_node.op().type() == "create_py_reader" \
                        or pred_op_node.op().type() == "create_double_buffer_reader" \
                        or pred_op_node.op().type() == "read":
                        continue
                    op_dist_attr = self._dist_context.get_op_dist_attr_for_graph(
                        pred_op_node)
                    if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
                        op_dims_mapping = op_dist_attr.get_output_dims_mapping(
                            tensor_desc.name())
                        dims_mapping_list.append(op_dims_mapping)
            dims_mapping_list.append(tensor_dims_mapping)
            compatible_dims_mapping = compute_compatible_dims_mapping(
                dims_mapping_list)
            if (compatible_dims_mapping is not None) and \
                (compatible_dims_mapping != tensor_dims_mapping):
                tensor_dist_attr.dims_mapping = compatible_dims_mapping
                changed = True
        else:
            dims_mapping_list = []
            for succ_op_node in tensor_node.outputs:
                if succ_op_node.op() is not None:
                    if succ_op_node.op().type() == "create_py_reader" \
                        or succ_op_node.op().type() == "create_double_buffer_reader" \
                        or succ_op_node.op().type() == "read":
                        continue
                    op_dist_attr = self._dist_context.get_op_dist_attr_for_graph(
                        succ_op_node)
                    if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
                        op_dims_mapping = op_dist_attr.get_input_dims_mapping(
                            tensor_desc.name())
                        dims_mapping_list.append(op_dims_mapping)
            dims_mapping_list.append(tensor_dims_mapping)
            compatible_dims_mapping = compute_compatible_dims_mapping(
                dims_mapping_list)
            if (compatible_dims_mapping is not None) and \
                (compatible_dims_mapping != tensor_dims_mapping):
                tensor_dist_attr.dims_mapping = compatible_dims_mapping
                changed = True
        return changed

    def _update_op_node_dims_mapping(self, op_node, fwd=True):
        changed = False
        if (not op_node.is_op()) or (op_node.op() is None):
            return False
        # Skip reader op
        op_desc = op_node.op()
        if op_desc.type() == "create_py_reader" \
            or op_desc.type() == "create_double_buffer_reader" \
            or op_desc.type() == "read":
            return False
        dist_op = self._dist_context.get_dist_op_for_graph(op_node)
        op_dist_attr = dist_op.dist_attr
        if fwd:
            for tensor_node in op_node.inputs:
                if tensor_node.var() is not None:
                    if tensor_node.var().type() == core.VarDesc.VarType.READER:
                        continue
                    tensor_desc = tensor_node.var()
                    if op_dist_attr.is_annotated_input_dims_mapping(
                            tensor_desc.name()):
                        continue
                    tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                        tensor_node)
                    if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
                        tensor_dims_mapping = tensor_dist_attr.dims_mapping
                        op_dims_mapping = op_dist_attr.get_input_dims_mapping(
                            tensor_desc.name())
                        compatible_dims_mapping = compute_compatible_dims_mapping(
                            [op_dims_mapping, tensor_dims_mapping])
                        if (compatible_dims_mapping is not None) and \
                            (compatible_dims_mapping != op_dims_mapping):
                            op_dist_attr.set_input_dims_mapping(
                                tensor_desc.name(), compatible_dims_mapping)
                            changed = True
            # Find the most compatible implementation from the distributed operator
            op_dist_impl = find_best_compatible_distributed_operator_impl(
                dist_op, fwd=True)
            assert op_dist_impl is not None, "Cannot find the dist op implementation."
            dim_changed = op_dist_impl.update_dims_mapping(dist_op)
            if dim_changed:
                changed = True
            if op_dist_impl.is_auto_compatible(dist_op):
                if op_dist_impl.type == "elementwise":
                    op_dist_attr.impl_type = "default"
                else:
                    op_dist_attr.impl_type = op_dist_impl.type
                op_dist_attr.impl_idx = op_dist_impl.idx
        else:
            for tensor_node in op_node.outputs:
                if tensor_node.var() is not None:
                    if tensor_node.var().type() == core.VarDesc.VarType.READER:
                        continue
                    tensor_desc = tensor_node.var()
                    if op_dist_attr.is_annotated_output_dims_mapping(
                            tensor_desc.name()):
                        continue
                    tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                        tensor_node)
                    if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
                        tensor_dims_mapping = tensor_dist_attr.dims_mapping
                        op_dims_mapping = op_dist_attr.get_output_dims_mapping(
                            tensor_desc.name())
                        compatible_dims_mapping = compute_compatible_dims_mapping(
                            [op_dims_mapping, tensor_dims_mapping])
                        if (compatible_dims_mapping is not None) and \
                            (compatible_dims_mapping != op_dims_mapping):
                            op_dist_attr.set_output_dims_mapping(
                                tensor_desc.name(), compatible_dims_mapping)
                            changed = True
            # Find the most compatible implementation from the distributed operator
            op_dist_impl = find_best_compatible_distributed_operator_impl(
                dist_op, fwd=False)
            assert op_dist_impl is not None, "Cannot find the dist op implementation."
            dim_changed = op_dist_impl.update_dims_mapping(dist_op)
            if dim_changed:
                changed = True
            if op_dist_impl.is_auto_compatible(dist_op):
                if op_dist_impl.type == "elementwise":
                    op_dist_attr.impl_type = "default"
                else:
                    op_dist_attr.impl_type = op_dist_impl.type
                op_dist_attr.impl_idx = op_dist_impl.idx
        return changed

    def _update_process_mesh(self):
        def _find_nearset_node(nodes, idx):
            for node in reversed(nodes[:idx]):
                node_dist_attr = self._dist_context.get_dist_attr_for_graph(
                    node)
                if node_dist_attr.process_mesh is not None:
                    return node

        total_reach_fix_point = False
        while not total_reach_fix_point:
            total_changed = False
            for is_fwd in [True, False]:
                all_nodes = self._dist_context.serial_ordered_nodes \
                    if is_fwd else reversed(self._dist_context.serial_ordered_nodes)
                reach_fix_point = False
                while not reach_fix_point:
                    changed = False
                    for idx, node in enumerate(all_nodes):
                        nearest_node = _find_nearset_node(
                            self._dist_context.serial_ordered_nodes, idx)
                        if nearest_node is None:
                            continue
                        nearest_node_dis_attr = self._dist_context.get_dist_attr_for_graph(
                            nearest_node)
                        nearest_process_mesh = nearest_node_dis_attr.process_mesh
                        cur_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
                            node)
                        cur_process_mesh = cur_node_dist_attr.process_mesh
                        compatible_process_mesh = compute_compatible_process_mesh(
                            [cur_process_mesh, nearest_process_mesh])
                        if compatible_process_mesh is not None \
                            and cur_process_mesh != compatible_process_mesh:
                            cur_node_dist_attr.process_mesh = compatible_process_mesh
                            changed = True
                    if changed:
                        reach_fix_point = False
                        total_changed = True
                    else:
                        reach_fix_point = True
            if total_changed:
                total_reach_fix_point = False
            else:
                total_reach_fix_point = True

    def _update_dims_mapping(self):
        # Complete dims_mapping for each node
        reach_fix_point = False
        while not reach_fix_point:
            changed = False
            for is_fwd in [True, False]:
                all_nodes = self._dist_context.serial_ordered_nodes \
                    if is_fwd else reversed(self._dist_context.serial_ordered_nodes)
                for node in all_nodes:
                    if node.is_var() and node.var() is not None:
                        tensor_changed = self._update_tensor_node_dims_mapping(
                            node, fwd=is_fwd)
                        if tensor_changed:
                            changed = True
                    if node.is_op() and node.op() is not None:
                        op_changed = self._update_op_node_dims_mapping(
                            node, fwd=is_fwd)
                        if op_changed:
                            changed = True
            if changed:
                reach_fix_point = False
            else:
                reach_fix_point = True

    def complete_forward_annotation(self, serial_main_program):
        """Complete the annotation of the partially annotated serial_main_program.

        Arguments:
            serial_main_program: partially annotated serial_main_program.

        Returns:
            serial_main_program: completed annotated serial_main_program.
        """

        # Use the default distributed context for completion if none is provided
        self._dist_context.serial_program = serial_main_program

        # Initialize distributed attributes for all var and op nodes in serial_main_program
        self._dist_context.init_dist_attr_for_program()

        # Initialize distributed attributes for all var and op nodes in the graph
        self._dist_context.init_dist_attr_for_graph()

        self._update_process_mesh()

        # Complete dims_mapping for each node
        self._update_dims_mapping()

        # Copy the corresponding distributed attributes from the graph to serial_main_program
        self._dist_context.copy_dist_attr_from_graph_to_program()
        self._dist_context.clear_dist_info_for_graph()

        # print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context)
        # Do the validation check and amend some completion
        self._dist_context.amend_dist_attr_for_program()

        # print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context)
        self._dist_context.validate_dist_attr_for_program()

        return serial_main_program

    def complete_backward_annotation(self, serial_main_program):
        """Complete the annotation of vars and ops in the backward phase for the parallel program."""

        def _is_grad_var_name(name):
            if "@GRAD" in name:
                return True
            return False

        def _get_forward_varname_from_grad_varname(grad_var_name):
            assert _is_grad_var_name(
                grad_var_name), "[{}] is not a grad var name.".format(
                    grad_var_name)
            return grad_var_name[:grad_var_name.find("@GRAD")]

        def _get_op_by_id(ops, id):
            for op in ops:
                if op.desc.id() == id:
                    return op
            return None

        first_backward_op_idx = -1
        for idx, op in enumerate(serial_main_program.global_block().ops):
            if int(op.attr('op_role')) == int(
                    int(core.op_proto_and_checker_maker.OpRole.Backward) | int(
                        core.op_proto_and_checker_maker.OpRole.Loss)):
                assert op.type == "fill_constant"
                first_backward_op_idx = idx
                break

        assert first_backward_op_idx >= 0, "No backward procedure found in this program."

        ops = list(serial_main_program.global_block().ops)
        vars = serial_main_program.global_block().vars
        dist_op_context = self._dist_context.dist_op_context

        for idx in range(first_backward_op_idx, len(ops)):

            # complete the initial grad loss op
            if idx == first_backward_op_idx:
                assert ops[idx].type == "fill_constant"
                assert len(
                    ops[idx].input_arg_names
                ) == 0, "first backward op should have NO input, but got [{}]".format(
                    len(ops[idx].input_arg_names))
                assert len(
                    ops[idx].output_arg_names
                ) == 1, "first backward op should have only ONE output, but got [{}]".format(
                    len(ops[idx].output_arg_names))

                grad_var = vars[ops[idx].output_arg_names[0]]
                forward_var_name = _get_forward_varname_from_grad_varname(
                    grad_var.name)
                forward_var = vars[forward_var_name]

                # TODO complete other attributes for the grad var
                tensor_dist_attr = TensorDistributedAttribute()
                process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
                    forward_var).process_mesh
                dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
                    forward_var).dims_mapping
                tensor_dist_attr.dims_mapping = dims_mapping
                tensor_dist_attr.process_mesh = process_mesh
                self._dist_context.set_tensor_dist_attr_for_program(
                    grad_var, tensor_dist_attr)

                op_dist_attr = OperatorDistributedAttribute()
                op_dist_attr.process_mesh = process_mesh
                op_dist_attr.set_output_dims_mapping(grad_var.name,
                                                     dims_mapping)
                self._dist_context.set_op_dist_attr_for_program(ops[idx],
                                                                op_dist_attr)
                continue

            # complete the annotation of grad op (xxx_grad op or sum op)
            # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
            grad_op = ops[idx]
            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
                # TODO support the case where one forward op corresponds to multiple xxx_grad ops
                forward_op = _get_op_by_id(
                    ops[:first_backward_op_idx],
                    dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
                assert forward_op is not None

                # op dist attr
                forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                    forward_op)
                forward_op_process_mesh = forward_op_dist_attr.process_mesh
                grad_op_dist_attr = OperatorDistributedAttribute()
                grad_op_dist_attr.process_mesh = forward_op_process_mesh

                # var
                for input_name in grad_op.input_arg_names:
                    input_var = vars[input_name]
                    ref_dims_mapping = None
                    if "@GRAD" in input_name:
                        forward_name = _get_forward_varname_from_grad_varname(
                            input_name)
                        ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
                            forward_name)
                    else:
                        if forward_op_dist_attr.get_input_dims_mapping(
                                input_name):
                            ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
                                input_name)
                        else:
                            ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
                                input_name)

                    assert ref_dims_mapping is not None, "[{}]'s dims mapping is NONE".format(
                        input_var.name)
                    grad_op_dist_attr.set_input_dims_mapping(input_name,
                                                             ref_dims_mapping)

                for output_name in grad_op.desc.output_names():
                    assert len(grad_op.desc.output(output_name)) in [0, 1]
                    if _is_grad_var_name(output_name):
                        input_name = _get_forward_varname_from_grad_varname(
                            output_name)
                    else:
                        assert grad_op.type in [
                            "cast", "c_identity", "c_allreduce_sum"
                        ]
                        input_name = "X"
                    assert input_name in forward_op.desc.input_names(
                    ), "var [{}] is in op [{}]'s output but [{}] could not be found in its forward op".format(
                        output_name, grad_op.type, input_name)
                    if len(grad_op.desc.output(output_name)) == 1:
                        # tensor dist attr
                        output_var = vars[grad_op.desc.output(output_name)[0]]
                        forward_name = _get_forward_varname_from_grad_varname(
                            output_var.name)
                        ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
                            forward_name)

                        output_var_dist_attr = TensorDistributedAttribute()
                        output_var_dist_attr.dims_mapping = ref_dims_mapping
                        output_var_dist_attr.process_mesh = forward_op_process_mesh
                        self._dist_context.set_tensor_dist_attr_for_program(
                            output_var, output_var_dist_attr)

                        grad_op_dist_attr.set_output_dims_mapping(
                            output_var.name, ref_dims_mapping)

                self._dist_context.set_op_dist_attr_for_program(
                    grad_op, grad_op_dist_attr)

            # only the sum op that merges multiple versions of a grad has no corresponding mapping in grad_op_id_to_op_id
            else:
                assert grad_op.type == "sum", "got unexpected op [{}]".format(
                    str(grad_op.type))
                assert all(map(_is_grad_var_name, grad_op.input_arg_names))
                assert len(grad_op.output_arg_names) == 1

                ref_forward_var_name = _get_forward_varname_from_grad_varname(
                    grad_op.output_arg_names[0])
                forward_var = vars[ref_forward_var_name]
                ref_forward_var_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
                    forward_var).dims_mapping
                ref_forward_var_process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
                    forward_var).process_mesh

                # output
                tensor_dist_attr = TensorDistributedAttribute()
                tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping
                tensor_dist_attr.process_mesh = ref_forward_var_process_mesh
                self._dist_context.set_tensor_dist_attr_for_program(
                    vars[grad_op.output_arg_names[0]], tensor_dist_attr)

                # op
                grad_op_dist_attr = OperatorDistributedAttribute()
                grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh
                for var_name in grad_op.input_arg_names:
                    assert _get_forward_varname_from_grad_varname(
                        var_name) == ref_forward_var_name
                    grad_op_dist_attr.set_input_dims_mapping(
                        var_name, ref_forward_var_dims_mapping)

                grad_op_dist_attr.set_output_dims_mapping(
                    grad_op.output_arg_names[0], ref_forward_var_dims_mapping)
                self._dist_context.set_op_dist_attr_for_program(
                    grad_op, grad_op_dist_attr)

    def complete_update_annotation(self, serial_main_program):
        """Complete the annotation of vars and ops in the update phase for the parallel program."""
        ops = list(serial_main_program.global_block().ops)
        vars = serial_main_program.global_block().vars
        learning_rate_completed = False

        for idx in range(len(ops)):
            # complete the annotation of the optimizer op.
            # TODO to add attribute for moment var
            op = ops[idx]
            if int(op.attr('op_role')) == int(OpRole.Optimize):

                if "Grad" in op.input_names and "Param" in ops[idx].input_names:
                    assert len(op.input(
                        "Param")) == 1, "Only support one-to-one now."
                    assert len(op.input(
                        "Grad")) == 1, "Only support one-to-one now."
                    param = vars[op.input("Param")[0]]
                    grad_var = vars[op.input("Grad")[0]]

                    param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                        param)
                    assert param_dist_attr is not None
                    ref_process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
                        param).process_mesh
                    assert ref_process_mesh is not None
                    ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
                        param).dims_mapping
                    assert ref_dims_mapping is not None
                    op_dist_attr = OperatorDistributedAttribute()
                    op_dist_attr.process_mesh = ref_process_mesh
                    op_dist_attr.set_input_dims_mapping(grad_var.name,
                                                        ref_dims_mapping)
                    op_dist_attr.set_input_dims_mapping(param.name,
                                                        ref_dims_mapping)
                    op_dist_attr.set_output_dims_mapping(param.name,
                                                         ref_dims_mapping)
                    learning_var = vars[op.input("LearningRate")[0]]
                    op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
                    op_dist_attr.set_output_dims_mapping(learning_var.name,
                                                         [-1])

                    if not learning_rate_completed:
                        learning_rate_completed = True
                        var_dist_attr = TensorDistributedAttribute()
                        var_dist_attr.process_mesh = ref_process_mesh
                        var_dist_attr.dims_mapping = [-1]
                        self._dist_context.set_tensor_dist_attr_for_program(
                            learning_var, var_dist_attr)

                    for input_name in op.desc.input_names():

                        if input_name in [
                                'Param', 'Grad', 'LearningRate', "SkipUpdate",
                                "Beta1Tensor", "Beta2Tensor", "EpsilonTensor",
                                "MasterParam"
                        ]:
                            continue

                        assert len(op.desc.input(input_name)) == 1
                        input_var = vars[op.desc.input(input_name)[0]]
                        input_var_attr = TensorDistributedAttribute()

                        if "Beta1Pow" in input_name or "Beta2Pow" in input_name:
                            input_var_attr.dims_mapping = [-1]
                            op_dist_attr.set_input_dims_mapping(input_var.name,
                                                                [-1])
                            op_dist_attr.set_output_dims_mapping(input_var.name,
                                                                 [-1])
                        else:
                            assert "Moment" in input_name
                            input_var_attr.dims_mapping = ref_dims_mapping
                            op_dist_attr.set_input_dims_mapping(
                                input_var.name, ref_dims_mapping)
                            op_dist_attr.set_output_dims_mapping(
                                input_var.name, ref_dims_mapping)

                        input_var_attr.process_mesh = ref_process_mesh
                        self._dist_context.set_tensor_dist_attr_for_program(
                            input_var, input_var_attr)

                    self._dist_context.set_op_dist_attr_for_program(
                        op, op_dist_attr)
                    continue
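The completion pass above relies on the three module-level compatibility helpers. Below is a small, hedged illustration of the merge rules they implement, assuming they remain importable from paddle.distributed.auto_parallel.completion at this commit; -1 means a tensor dimension is not sharded, and any conflict collapses the result to None.

from paddle.distributed.auto_parallel.completion import (
    compute_compatible_dim_mapping, compute_compatible_dims_mapping)

# A dim mapping of -1 (replicated) is compatible with anything.
assert compute_compatible_dim_mapping([-1, 0, -1]) == 0
# Two different mesh dimensions for the same tensor dimension conflict.
assert compute_compatible_dim_mapping([0, 1]) is None

# Per-dimension merge over whole dims_mapping lists.
assert compute_compatible_dims_mapping([[0, -1], [-1, -1]]) == [0, -1]
# Any per-dimension conflict, or a length mismatch, makes the lists incompatible.
assert compute_compatible_dims_mapping([[0, -1], [1, -1]]) is None
assert compute_compatible_dims_mapping([[0, -1], [0]]) is None

This per-dimension rule is exactly what _update_tensor_node_dims_mapping and _update_op_node_dims_mapping apply when they reconcile a tensor's dims_mapping with the ops around it.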
...@@ -247,23 +247,23 @@ class DistributedContext:
         # new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
         # self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

-    # def get_dist_attr_for_graph(self, serial_node):
-    #     if serial_node.is_var() and serial_node.var() is not None:
-    #         serial_tensor_node_id = serial_node.id()
-    #         dist_tensor = self._dist_tensors_for_graph.get(
-    #             serial_tensor_node_id, None)
-    #         if dist_tensor:
-    #             return dist_tensor.dist_attr
-    #         else:
-    #             return None
-    #     if serial_node.is_op() and serial_node.op() is not None:
-    #         serial_op_node_id = serial_node.id()
-    #         dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
-    #         if dist_op:
-    #             return dist_op.dist_attr
-    #         else:
-    #             return None
-    #     return None
+    def get_dist_attr_for_graph(self, serial_node):
+        if serial_node.is_var() and serial_node.var() is not None:
+            serial_tensor_node_id = serial_node.id()
+            dist_tensor = self._dist_tensors_for_graph.get(
+                serial_tensor_node_id, None)
+            if dist_tensor:
+                return dist_tensor.dist_attr
+            else:
+                return None
+        if serial_node.is_op() and serial_node.op() is not None:
+            serial_op_node_id = serial_node.id()
+            dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
+            if dist_op:
+                return dist_op.dist_attr
+            else:
+                return None
+        return None

     def init_dist_attr_for_program(self):
         assert self._serial_program, \
......
...@@ -32,7 +32,7 @@ from paddle.distributed.passes import new_pass, PassContext
 from .dist_context import DistributedContext
 from .dist_context import get_default_distributed_context
 from .dist_context import set_default_distributed_context
-from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
+from .completion import Completer
 from .partitioner import Partitioner
 from .process_group import get_all_process_groups
 from .process_group import get_process_group
...@@ -130,8 +130,8 @@ class AutoParallelizer:
                 no_grad_set,
                 callbacks,
                 distop_context=self._dist_context.dist_op_context)
-            complete_backward_annotation(
-                main_program, dist_context=self._dist_context)
+            self._completer = Completer(self._dist_context)
+            self._completer.complete_backward_annotation(main_program)

         return params_grads
...@@ -142,8 +142,8 @@ class AutoParallelizer:
             params_grads)

         # update completion
-        complete_update_annotation(
-            main_program, dist_context=self._dist_context)
+        self._completer = Completer(self._dist_context)
+        self._completer.complete_update_annotation(main_program)

         return optimize_ops
...@@ -179,8 +179,9 @@ class AutoParallelizer:
             # Annotation completion
             self._dist_context = DistributedContext()
             _logger.info("Start annotation dist attr.")
-            completed_main_program = complete_annotation(serial_main_program,
-                                                         self._dist_context)
+            self._completer = Completer(self._dist_context)
+            completed_main_program = self._completer.complete_forward_annotation(
+                serial_main_program)
         else:
             completed_main_program = serial_main_program
             self._dist_context = copy.deepcopy(dist_context)
......
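Taken together, the parallelizer changes above run the three completion phases against the same DistributedContext. A hedged sketch of that call order follows; the serial program, backward pass, and optimizer ops are assumed to be built elsewhere, as AutoParallelizer does.

from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

dist_context = DistributedContext()
completer = Completer(dist_context)

# 1. Forward completion on the user-annotated serial program (assumed to exist).
completed_main_program = completer.complete_forward_annotation(serial_main_program)

# 2. After the backward pass has been appended, annotate the new grad vars and ops.
completer.complete_backward_annotation(completed_main_program)

# 3. After the optimizer ops have been added, annotate the update phase.
completer.complete_update_annotation(completed_main_program)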
...@@ -27,6 +27,7 @@ import paddle.tensor as tensor
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
...@@ -154,10 +155,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_mlp_mp(self):
...@@ -171,10 +171,9 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -171,10 +171,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program, train_program, start_program = mlp_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_dp_mp(self): def test_mlp_dp_mp(self):
...@@ -189,10 +188,9 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -189,10 +188,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program, train_program, start_program = mlp_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
# def test_mlp_misc(self): # def test_mlp_misc(self):
...@@ -212,8 +210,8 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -212,8 +210,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
# train_program, start_program = mlp_pretrain_forward(train_program, # train_program, start_program = mlp_pretrain_forward(train_program,
# start_program) # start_program)
# # pdb.set_trace() # # pdb.set_trace()
# complete_train_program = auto.complete_annotation(train_program, # completer = Completer(dist_context)
# dist_context) # complete_train_program = auto.completer.complete_forward_annotation(train_program)
# # print_program_with_dist_attr(complete_train_program, # # print_program_with_dist_attr(complete_train_program,
# # dist_context) # # dist_context)
# dist_context.finalize_distributed_attr_for_program( # dist_context.finalize_distributed_attr_for_program(
...@@ -423,8 +421,9 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -423,8 +421,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program, train_program, start_program = attn_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
train_program)
# print_program_with_dist_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
...@@ -440,10 +439,9 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -440,10 +439,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program, train_program, start_program = attn_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_attn_dp_mp(self): def test_attn_dp_mp(self):
...@@ -458,10 +456,9 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -458,10 +456,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program, train_program, start_program = attn_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
...@@ -747,10 +744,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -747,10 +744,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program, train_program, start_program = decoder_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_decoder_mp(self): def test_decoder_mp(self):
...@@ -764,10 +760,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -764,10 +760,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program, train_program, start_program = decoder_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_decoder_dp_mp(self): def test_decoder_dp_mp(self):
...@@ -782,10 +777,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -782,10 +777,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program, train_program, start_program = decoder_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
......
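All of the completion unit tests above switch to the same pattern: build the serial forward program, run the Completer over it, then validate the distributed attributes. A condensed, self-contained sketch of that pattern follows; the helper name and the build_forward argument are illustrative, mirroring helpers such as mlp_pretrain_forward in the tests.

import paddle
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

def run_completion_check(build_forward):
    paddle.enable_static()
    train_program = paddle.static.Program()
    start_program = paddle.static.Program()
    dist_context = DistributedContext()
    # build_forward constructs and annotates the serial forward program.
    train_program, start_program = build_forward(train_program, start_program)
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(train_program)
    # Each test then asserts that the completed annotations are consistent.
    return dist_context.validate_dist_attr_for_program()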
@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed.fleet import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -817,10 +818,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
 dist_context = DistributedContext()
 train_program, start_program = gpt_pretrain_forward(train_program,
 start_program)
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
-# print_program_with_dist_attr(complete_train_program,
-# dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 self.assertTrue(dist_context.validate_dist_attr_for_program())
 def test_gpt_mp(self):
@@ -834,10 +834,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
 dist_context = DistributedContext()
 train_program, start_program = gpt_pretrain_forward(train_program,
 start_program)
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
-# print_program_with_dist_attr(complete_train_program,
-# dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 self.assertTrue(dist_context.validate_dist_attr_for_program())
 def test_gpt_dp_mp(self):
@@ -852,10 +851,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
 dist_context = DistributedContext()
 train_program, start_program = gpt_pretrain_forward(train_program,
 start_program)
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
-# print_program_with_dist_attr(complete_train_program,
-# dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 self.assertTrue(dist_context.validate_dist_attr_for_program())
......
@@ -23,6 +23,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
 parallelizer._dist_context = dist_context
 # serial forward & backward completion
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 params_grads = parallelizer._generate_backward(
 complete_train_program,
......
@@ -18,6 +18,7 @@ import unittest
 import paddle
 from paddle.fluid import core
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -42,8 +43,9 @@ def get_dist_prog(train_program,
 parallelizer._dist_context = dist_context
 # serial forward & backward completion
-complete_train_program = auto.complete_annotation(
-train_program, dist_context
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program
 ) if complete_train_program is None else complete_train_program
 # parallelizer._apply_serial_forward_pass(complete_train_program,
......
@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed import fleet
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -433,6 +434,12 @@ class MLPLayer(nn.Layer):
 out = F.gelu(out, approximate=True)
 out = self.linear1(out)
+auto.shard_tensor(
+out,
+dist_attr={
+"process_mesh": _global_process_mesh[1],
+"dims_mapping": [0, -1]
+})
 out = self.linear2(out)
 out = F.gelu(out, approximate=True)
 out = self.linear3(out)
@@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
 parallelizer._dist_context = dist_context
 # auto completion
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 params_grads = parallelizer._generate_backward(
 complete_train_program,
......
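Besides switching to the Completer, the MLPLayer hunk above adds an explicit auto.shard_tensor annotation between linear1 and linear2. A small standalone sketch of that annotation style follows; the process meshes, tensor shape, and rank ids are illustrative assumptions, and only the shard_tensor call itself mirrors the diff.

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

# Illustrative two-mesh setup; the test defines its own _global_process_mesh.
_global_process_mesh = [
    auto.ProcessMesh([[0, 1], [2, 3]]),
    auto.ProcessMesh([[4, 5], [6, 7]]),
]

out = paddle.static.data(name="out", shape=[8, 16], dtype="float32")
auto.shard_tensor(
    out,
    dist_attr={
        "process_mesh": _global_process_mesh[1],
        # dims_mapping[i] names the mesh axis that shards tensor axis i; -1 means replicated.
        "dims_mapping": [0, -1]
    })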
@@ -28,6 +28,7 @@ import paddle.tensor as tensor
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -49,8 +50,9 @@ def get_programs(annotated_func):
 global _global_process_mesh
 dist_context.process_mesh = _global_process_mesh
 train_program, start_program = annotated_func(train_program, start_program)
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 rank_id = 3
 dist_strategy = fleet.DistributedStrategy()
......
@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -881,8 +882,9 @@ class TestGPTPartitioner(unittest.TestCase):
 dist_context.process_mesh = _global_process_mesh
 train_program, startup_program, loss = gpt_pretrain_forward(
 train_program, startup_program)
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 # serial backward pass
 params_grads = parallelizer._generate_backward(
@@ -913,8 +915,9 @@ class TestGPTPartitioner(unittest.TestCase):
 "w") as fw:
 fw.write(str(auto_parallel_startup_prog))
 # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
-# from paddle.distributed.auto_parallel.completion import complete_backward_annotation
-# complete_backward_annotation(auto_parallel_main_prog)
+# from paddle.distributed.auto_parallel.completion import Completer
+# completer = Completer()
+# completer.complete_forward_annotation(auto_parallel_main_prog)
 # fw.write(str(auto_parallel_main_prog))
 nrank = 4
 # col parallel
......
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -152,8 +153,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
 parallelizer._dist_context = dist_context
 # serial forward & backward completion
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 params_grads = parallelizer._generate_backward(
 complete_train_program,
@@ -299,7 +301,6 @@ class TestMLPReshard(unittest.TestCase):
 for key in list(_g_process_group_map.keys()):
 del _g_process_group_map[key]
 reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
-# print_program_with_dist_attr(dist_main_prog, dist_context)
 # check send and recv result
 self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -116,8 +117,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
 parallelizer._dist_context = dist_context
 # serial forward & backward completion
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 params_grads = parallelizer._generate_backward(
 complete_train_program,
......
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -132,8 +133,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
 parallelizer._dist_context = dist_context
 # serial forward & backward completion
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 params_grads = parallelizer._generate_backward(
 complete_train_program,
@@ -263,8 +265,9 @@ class TestMLPReshard(unittest.TestCase):
 dist_context = DistributedContext()
 dist_strategy = fleet.DistributedStrategy()
 partitioner = Partitioner(dist_context, rank_id)
-complete_train_program = auto.complete_annotation(train_program,
-dist_context)
+completer = Completer(dist_context)
+complete_train_program = completer.complete_forward_annotation(
+train_program)
 partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
 complete_train_program, startup_program, [])
 reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
......
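Taken together, the reshard tests above follow one pipeline: complete the serial program, partition it for a rank, then insert the reshard communication. A condensed sketch of that flow using only calls visible in these hunks; the function name is illustrative and the reshard import path is assumed from the package layout used elsewhere in this commit.

from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard

def build_rank_program(train_program, startup_program, rank_id):
    dist_context = DistributedContext()
    # Serial forward completion.
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(train_program)
    # Partition the completed program for this rank.
    partitioner = Partitioner(dist_context, rank_id)
    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
        complete_train_program, startup_program, [])
    # Insert the send/recv ops needed across process meshes.
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
    return dist_main_prog, dist_startup_prog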
@@ -154,7 +154,7 @@ class TestMLPSearcher(unittest.TestCase):
 ops = train_program.global_block().ops
 vars = train_program.global_block().vars
 from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
-from paddle.distributed.auto_parallel.completion import is_elementwise_like_op
+from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
 from paddle.distributed.auto_parallel.dist_op import DistributedOperator
 for op in ops:
@@ -163,7 +163,7 @@ class TestMLPSearcher(unittest.TestCase):
 if dist_op_impl_container is None:
 op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
 dist_op = DistributedOperator(op, op_dist_attr)
-if is_elementwise_like_op(op.type):
+if is_elementwise_op(op.type):
 changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
 dist_op)
 self.assertFalse(changed)
......
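The searcher hunks above only track a relocation: the elementwise check is now imported as is_elementwise_op from operators.common instead of is_elementwise_like_op from the completion module. A trivial sketch of the updated usage; the op type string is just an example, in the test it comes from op.type.

from paddle.distributed.auto_parallel.operators.common import is_elementwise_op

op_type = "elementwise_add"  # illustrative op type
if is_elementwise_op(op_type):
    # Such ops take the elementwise-like dims_mapping update path in the searcher test.
    print(op_type, "is treated as an elementwise op")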