Unverified commit e5cda6fa, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Use the new completion algorithm (#39086)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem

* [Auto Parallel] Update the completion algorithm

* Fix the bug of auto_searcher unittest
Parent commit: f68ef9d2
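Taken together, the changes below replace the module-level helpers complete_annotation / complete_backward_annotation / complete_update_annotation with a Completer class bound to a DistributedContext. A minimal usage sketch of the new API, assuming a serial main program built by the caller (the helper name complete_with_new_api is hypothetical and not part of this commit):

from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

def complete_with_new_api(serial_main_program):
    # One completer per distributed context, mirroring what AutoParallelizer does below.
    dist_context = DistributedContext()
    completer = Completer(dist_context)
    # Forward completion returns the annotated main program.
    completed_main_program = completer.complete_forward_annotation(serial_main_program)
    # After backward ops exist (appended by the parallelizer), grad vars/ops are
    # annotated in place with completer.complete_backward_annotation(...), and
    # optimizer ops with completer.complete_update_annotation(...).
    return completed_main_program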
@@ -15,12 +15,6 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .process_mesh import ProcessMesh
# from .interface import set_shard_mask # noqa: F401
# from .interface import set_offload_device # noqa: F401
# from .interface import set_pipeline_stage # noqa: F401
# from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
from .cost_model import estimate_cost
@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from copy import deepcopy
import time
from paddle.fluid import core
from paddle.fluid import framework
from .utils import compute_compatible_process_mesh
from .utils import compute_compatible_dim_mapping
from .utils import compute_compatible_dims_mapping
from .utils import print_program_with_dist_attr
from .operators import find_best_compatible_distributed_operator_impl
from .dist_context import get_default_distributed_context
@@ -29,241 +28,92 @@ from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from paddle.distributed.fleet.meta_optimizers.common import OpRole
ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
def is_elementwise_like_op(op_type):
if op_type in ELEMENTWISE_LIKE_OP_LIST:
return True
else:
return False
def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True):
"""
Update tensor's process mesh by using its predecessor's process mesh if in the forward direction,
and by using its successor's process mesh if in the backward direction. Note: only the equal
process meshes are compatible for now.
"""
changed = False
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node)
if tensor_dist_attr.is_annotated("process_mesh"):
return changed
tensor_process_mesh = tensor_dist_attr.process_mesh
if fwd:
inputs_process_meshes = []
for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None:
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
pred_op_node)
op_process_mesh = op_dist_attr.process_mesh
inputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True
else:
outputs_process_meshes = []
for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None:
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
succ_op_node)
op_process_mesh = op_dist_attr.process_mesh
outputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True
return changed
def update_op_node_process_mesh(dist_context, op_node, fwd=True):
"""
Update op's process mesh by using its predecessor's process mesh if in the forward direction,
and by using its successor's process mesh if in the backward direction. Note: only the equal
process meshes are compatible for now.
"""
changed = False
op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node)
if op_dist_attr.is_annotated("process_mesh"):
return changed
op_process_mesh = op_dist_attr.process_mesh
if fwd:
inputs_process_meshes = []
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.process_mesh
inputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.process_mesh = compatible_process_mesh
changed = True
else:
outputs_process_meshes = []
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.process_mesh
outputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.process_mesh = compatible_process_mesh
changed = True
return changed
def compute_compatible_process_mesh(process_mesh_list):
"""Compute the compatible process mesh given a list of process meshes."""
if not process_mesh_list:
return None
def _compute_compatible_process_mesh_two(pm1, pm2):
if pm1 is None:
return True, pm2
if pm2 is None:
return True, pm1
if pm1 == pm2:
return True, pm1
if pm1.processes == pm2.processes:
if len(pm1.topology) >= len(pm2.topology):
return True, pm1
else:
return True, pm2
process_set1 = set(pm1.processes)
process_set2 = set(pm2.processes)
if process_set1.issubset(process_set2):
return True, pm2
if process_set2.issubset(process_set1):
return True, pm1
return False, None
compatible_result = None
for process_mesh in process_mesh_list:
compatible, compatible_result = _compute_compatible_process_mesh_two(
compatible_result, process_mesh)
if not compatible:
return None
return copy.deepcopy(compatible_result)
def update_op_dims_mapping_by_default_dist_impl(dist_context, op_node):
"""Each operator has a default distributed operator, only allowed to be sharded in batch dimension."""
changed = False
if (not op_node.is_op()) or (op_node.op() is None):
return False
op_desc = op_node.op()
dist_op = dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegent way
if op_desc.type() == "shape" or op_desc.type() == "slice":
return False
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
else:
assert dims_mapping[0] == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\
.format(op_desc.type(), mapping)
if len(dims_mapping) > 2:
for idx, mapping in enumerate(dims_mapping[2:]):
assert mapping == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
return changed
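The merge rule implemented by _compute_compatible_process_mesh_two can be paraphrased with a self-contained sketch; FakeMesh below is only an illustrative stand-in exposing the two attributes the rule inspects (the real ProcessMesh class lives elsewhere in auto_parallel):

from collections import namedtuple

# Illustrative stand-in: a process mesh is a list of ranks plus a topology shape.
FakeMesh = namedtuple("FakeMesh", ["processes", "topology"])

def merge_mesh(pm1, pm2):
    """Mirror of the nested helper above; returns None when the meshes conflict."""
    if pm1 is None or pm2 is None:
        return pm1 or pm2
    if pm1 == pm2:
        return pm1
    if pm1.processes == pm2.processes:
        # Same ranks: keep the mesh with the higher-dimensional topology.
        return pm1 if len(pm1.topology) >= len(pm2.topology) else pm2
    if set(pm1.processes) <= set(pm2.processes):
        return pm2
    if set(pm2.processes) <= set(pm1.processes):
        return pm1
    return None

pm_1d = FakeMesh(processes=[0, 1, 2, 3], topology=[4])
pm_2d = FakeMesh(processes=[0, 1, 2, 3], topology=[2, 2])
assert merge_mesh(pm_1d, pm_2d) is pm_2d                                # same ranks, deeper topology wins
assert merge_mesh(FakeMesh([0, 1], [2]), pm_2d) is pm_2d                # subset of ranks is absorbed
assert merge_mesh(FakeMesh([0, 1], [2]), FakeMesh([2, 3], [2])) is None  # disjoint ranks conflict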
def compute_compatible_dim_mapping(dim_mapping_list):
"""Compute the compatible dim mapping given a list of dim mapping."""
if not dim_mapping_list:
return None
def _compute_compatible_dim_mapping_two(dm1, dm2):
if dm1 == -1:
return True, dm2
if dm2 == -1:
return True, dm1
if dm1 == dm2:
return True, dm1
return False, None
compatible_result = -1
for mapping in dim_mapping_list:
compatible, compatible_result = _compute_compatible_dim_mapping_two(
compatible_result, mapping)
if not compatible:
return None
return compatible_result
def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_context, op_node):
"""Element-wise operator can be sharded in any way (but should take care of broadcasting)."""
changed = False
if (not op_node.is_op()) or (op_node.op() is None):
return False
op_desc = op_node.op()
op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node)
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
input_dims_mapping_lens = {}
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
input_dims_mapping_dict[arg_name] = dims_mapping
input_dims_mapping_lens[arg_name] = len(dims_mapping)
dims_mapping_list = []
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(input_dims_mapping_dict[arg_name])
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(input_dims_mapping_lens[arg_name])
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
return changed
def compute_compatible_dims_mapping(dims_mapping_list):
"""Compute the compatible dims mapping given a list of dims mapping.
Each of dims mapping is also a list.
"""
if not dims_mapping_list:
return None
length = len(dims_mapping_list[0])
for dims_mapping in dims_mapping_list:
if dims_mapping is None:
return None
if len(dims_mapping) != length:
return None
compatible_result = []
for dim_mappings in zip(*dims_mapping_list):
compatible_dim_mapping = compute_compatible_dim_mapping(
list(dim_mappings))
if compatible_dim_mapping is None:
return None
compatible_result.append(compatible_dim_mapping)
return compatible_result
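The dims-mapping convention used by the helpers above (and throughout this file) is: an entry of -1 means the corresponding tensor axis is replicated, and a non-negative entry names the process-mesh dimension that axis is sharded along. A self-contained sketch of the per-dimension merge (merge_dim is an illustrative stand-in for _compute_compatible_dim_mapping_two, not a name from this commit):

def merge_dim(dm1, dm2):
    """-1 (replicated) yields to any sharding; equal shardings agree; anything else conflicts."""
    if dm1 == -1:
        return dm2
    if dm2 == -1:
        return dm1
    return dm1 if dm1 == dm2 else None  # None marks an unresolvable conflict

# Element-wise merge of two dims mappings of the same rank, as compute_compatible_dims_mapping does.
assert [merge_dim(a, b) for a, b in zip([-1, 0], [1, 0])] == [1, 0]
# Two axes sharded along different mesh dimensions cannot be reconciled.
assert merge_dim(0, 1) is None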
class Completer:
def __init__(self, dist_context):
assert dist_context is not None
self._dist_context = dist_context
def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
if (not tensor_node.is_var()) or (tensor_node.var() is None):
return False
@@ -271,7 +121,8 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
# Skip reader tensor
if tensor_desc.type() == core.VarDesc.VarType.READER:
return False
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node)
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
assert tensor_dist_attr is not None
if tensor_dist_attr.is_annotated("dims_mapping"):
return False
@@ -284,8 +135,9 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
or pred_op_node.op().type() == "create_double_buffer_reader" \
or pred_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_dist_attr = self._dist_context.get_op_dist_attr_for_graph(
pred_op_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
dims_mapping_list.append(op_dims_mapping)
@@ -304,8 +156,9 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
or succ_op_node.op().type() == "create_double_buffer_reader" \
or succ_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_dist_attr = self._dist_context.get_op_dist_attr_for_graph(
succ_op_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
dims_mapping_list.append(op_dims_mapping)
@@ -318,8 +171,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
changed = True
return changed
def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
def _update_op_node_dims_mapping(self, op_node, fwd=True):
changed = False
if (not op_node.is_op()) or (op_node.op() is None):
return False
@@ -329,7 +181,7 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "read":
return False
dist_op = dist_context.get_dist_op_for_graph(op_node)
dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
if fwd:
for tensor_node in op_node.inputs:
@@ -340,8 +192,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
if op_dist_attr.is_annotated_input_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
tensor_dims_mapping = tensor_dist_attr.dims_mapping
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
@@ -349,8 +202,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
[op_dims_mapping, tensor_dims_mapping])
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_input_dims_mapping(tensor_desc.name(),
compatible_dims_mapping)
op_dist_attr.set_input_dims_mapping(
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implemenetations from the distributed operator
op_dist_impl = find_best_compatible_distributed_operator_impl(
@@ -374,8 +227,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
if op_dist_attr.is_annotated_output_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
tensor_dims_mapping = tensor_dist_attr.dims_mapping
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
@@ -401,185 +255,67 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
op_dist_attr.impl_idx = op_dist_impl.idx
return changed
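Both dims-mapping passes above and the process-mesh pass below share one driver idea: sweep the ordered nodes forward and backward, apply a local update that reports whether it changed anything, and stop at a fixed point. A generic, self-contained sketch of that pattern (run_to_fixed_point and lift_to_neighbor_max are illustrative names, not part of this commit):

def run_to_fixed_point(nodes, update_fn):
    """Apply update_fn over nodes forward and backward until a full sweep changes nothing."""
    reach_fix_point = False
    while not reach_fix_point:
        changed = False
        for is_fwd in [True, False]:
            ordered = nodes if is_fwd else list(reversed(nodes))
            for node in ordered:
                if update_fn(node, fwd=is_fwd):
                    changed = True
        reach_fix_point = not changed

# Toy local rule: every position adopts a larger value from its sweep-direction neighbour.
values = [1, 5, 2, 4, 3]

def lift_to_neighbor_max(i, fwd=True):
    j = i - 1 if fwd else i + 1
    if 0 <= j < len(values) and values[j] > values[i]:
        values[i] = values[j]
        return True
    return False

run_to_fixed_point(list(range(len(values))), lift_to_neighbor_max)
assert values == [5, 5, 5, 5, 5]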
def _update_process_mesh(self):
def _find_nearset_node(nodes, idx):
for node in reversed(nodes[:idx]):
node_dist_attr = self._dist_context.get_dist_attr_for_graph(
node)
if node_dist_attr.process_mesh is not None:
return node
total_reach_fix_point = False
while not total_reach_fix_point:
total_changed = False
for is_fwd in [True, False]:
all_nodes = self._dist_context.serial_ordered_nodes \
if is_fwd else reversed(self._dist_context.serial_ordered_nodes)
def complete_annotation(program, dist_context=None):
""" Complete annotation for the partial annotated program.
Arguments:
program: partial annotated program.
dist_context: the distributed context is used to store distributed attributes for program.
If not provided, the default one will be used.
Returns:
program: completed annotated program.
"""
# Use the default distribted context for completeion if there is no one
if dist_context is None:
dist_context = get_default_distributed_context()
dist_context.serial_program = program
else:
dist_context.serial_program = program
# print_program_with_dist_attr(program, dist_context)
# Initialize distributed attributes for all var and op node in program
dist_context.init_dist_attr_for_program()
# Initialize distributed attributes for all var and op node in graph
dist_context.init_dist_attr_for_graph()
# Complete process mesh for each node
all_nodes = list(dist_context.serial_graph.all_nodes())
def sort_key_fun(node):
first = -1
if node.is_op():
first = 0
else:
first = 1
second = -1
if node.is_op() and node.op() is not None:
second = node.op().id()
if node.is_var() and node.var() is not None:
second = node.var().id()
return (first, second)
all_nodes.sort(key=sort_key_fun)
reach_fix_point = False
while not reach_fix_point:
total_changed = False
reach_fwd_fix_point = False
reach_bwd_fix_point = False
while not reach_fwd_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=True)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=True)
if op_changed:
changed = True
if changed:
reach_fwd_fix_point = False
total_changed = True
else:
reach_fwd_fix_point = True
while not reach_bwd_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=False)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=False)
if op_changed:
for idx, node in enumerate(all_nodes):
nearest_node = _find_nearset_node(
self._dist_context.serial_ordered_nodes, idx)
if nearest_node is None:
continue
nearest_node_dis_attr = self._dist_context.get_dist_attr_for_graph(
nearest_node)
nearest_process_mesh = nearest_node_dis_attr.process_mesh
cur_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
node)
cur_process_mesh = cur_node_dist_attr.process_mesh
compatible_process_mesh = compute_compatible_process_mesh(
[cur_process_mesh, nearest_process_mesh])
if compatible_process_mesh is not None \
and cur_process_mesh != compatible_process_mesh:
cur_node_dist_attr.process_mesh = compatible_process_mesh
changed = True
if changed:
reach_bwd_fix_point = False
total_changed = True
else:
reach_bwd_fix_point = True
if total_changed:
reach_fix_point = False
total_changed = True
else:
reach_fix_point = True
if total_changed:
total_reach_fix_point = False
# Validation the completion of process meshes and should be moved to a proper location
is_wrong = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
node)
if tensor_dist_attr.process_mesh is None:
msg_str = ""
for op_node in node.inputs:
if op_node.op() is not None:
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_node)
msg_str += "{} [{}], ".format(
op_node.op().type(),
op_dist_attr.process_mesh)
else:
msg_str += "{} [{}], ".format(op_node.name(),
None)
for op_node in node.outputs:
if op_node.op() is not None:
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_node)
msg_str += "{} [{}], ".format(
op_node.op().type(),
op_dist_attr.process_mesh)
else:
msg_str += "{} [{}], ".format(op_node.name(),
None)
msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_tensor api explicitly to annotate it".format(
node.var().name(), msg_str[:-2])
is_wrong = True
print(msg_str)
if node.is_op() and node.op() is not None:
op_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
if op_dist_attr.process_mesh is None:
msg_str = ""
for tensor_node in node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
msg_str += "{} [{}], ".format(
tensor_node.var().name(),
tensor_dist_attr.process_mesh)
else:
msg_str += "{} [{}], ".format(
tensor_node.name(), None)
for tensor_node in node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
msg_str += "{} [{}], ".format(
tensor_node.var().name(),
tensor_dist_attr.process_mesh)
else:
msg_str += "{} [{}], ".format(
tensor_node.name(), None)
else:
total_reach_fix_point = True
msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_op api explicitly to annotate it".format(
node.op().type(), msg_str[:-2])
is_wrong = True
print(msg_str)
if node.is_op() and node.op() is None:
print("op op is None", node.name())
if is_wrong:
assert False, "Cannot complete process_meshes of the program."
def _update_dims_mapping(self):
# Complete dims_mapping for each node
reach_fix_point = False
while not reach_fix_point:
changed = False
for is_fwd in [True, False]:
all_nodes = self._dist_context.serial_ordered_nodes \
if is_fwd else reversed(self._dist_context.serial_ordered_nodes)
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_dims_mapping(
dist_context, node, fwd=True)
tensor_changed = self._update_tensor_node_dims_mapping(
node, fwd=is_fwd)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_dims_mapping(
dist_context, node, fwd=True)
if op_changed:
changed = True
for node in reversed(all_nodes):
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_dims_mapping(
dist_context, node, fwd=False)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_dims_mapping(
dist_context, node, fwd=False)
op_changed = self._update_op_node_dims_mapping(
node, fwd=is_fwd)
if op_changed:
changed = True
if changed:
@@ -587,20 +323,44 @@ def complete_annotation(program, dist_context=None):
else:
reach_fix_point = True
# Copy the corresponding distributed attribute from graph to program
dist_context.copy_dist_attr_from_graph_to_program()
dist_context.clear_dist_info_for_graph()
# Do the validation check and amend some completion
dist_context.amend_dist_attr_for_program()
# print_program_with_dist_attr(program, dist_context)
dist_context.validate_dist_attr_for_program()
return program
def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
def complete_forward_annotation(self, serial_main_program):
""" Complete annotation for the partial annotated serial_main_program.
Arguments:
serial_main_program: partial annotated serial_main_program.
Returns:
serial_main_program: completed annotated serial_main_program.
"""
# Use the default distribted context for completeion if there is no one
self._dist_context.serial_program = serial_main_program
# Initialize distributed attributes for all var and op node in serial_main_program
self._dist_context.init_dist_attr_for_program()
# Initialize distributed attributes for all var and op node in graph
self._dist_context.init_dist_attr_for_graph()
self._update_process_mesh()
# Complete dims_mapping for each node
self._update_dims_mapping()
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
self._dist_context.clear_dist_info_for_graph()
# print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context)
# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()
# print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context)
self._dist_context.validate_dist_attr_for_program()
return serial_main_program
def complete_backward_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the backward phase for parallel program.""" """Complete the annotation of vars and ops in the backward phase for parallel program."""
def _is_grad_var_name(name):
@@ -610,7 +370,8 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
def _get_forward_varname_from_grad_varname(grad_var_name):
assert _is_grad_var_name(
grad_var_name), "[{}] is not a grad varnme.".format(grad_var_name)
grad_var_name), "[{}] is not a grad varnme.".format(
grad_var_name)
return grad_var_name[:grad_var_name.find("@GRAD")]
def _get_op_by_id(ops, id):
@@ -619,11 +380,8 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
return op
return None
if dist_context is None:
dist_context = get_default_distributed_context()
first_backward_op_idx = -1
for idx, op in enumerate(auto_parallel_main_prog.global_block().ops):
for idx, op in enumerate(serial_main_program.global_block().ops):
if int(op.attr('op_role')) == int(
int(core.op_proto_and_checker_maker.OpRole.Backward) | int(
core.op_proto_and_checker_maker.OpRole.Loss)):
@@ -633,9 +391,9 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
assert first_backward_op_idx >= 0, "No backward procedure found in this program."
ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
dist_op_context = dist_context.dist_op_context
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
dist_op_context = self._dist_context.dist_op_context
for idx in range(first_backward_op_idx, len(ops)):
@@ -658,19 +416,21 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
# TODO complete other attribte for grad var
tensor_dist_attr = TensorDistributedAttribute()
process_mesh = dist_context.get_tensor_dist_attr_for_program(
process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).process_mesh
dims_mapping = dist_context.get_tensor_dist_attr_for_program(
dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).dims_mapping
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.process_mesh = process_mesh
dist_context.set_tensor_dist_attr_for_program(grad_var,
tensor_dist_attr)
self._dist_context.set_tensor_dist_attr_for_program(
grad_var, tensor_dist_attr)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = process_mesh
op_dist_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
dist_context.set_op_dist_attr_for_program(ops[idx], op_dist_attr)
op_dist_attr.set_output_dims_mapping(grad_var.name,
dims_mapping)
self._dist_context.set_op_dist_attr_for_program(ops[idx],
op_dist_attr)
continue
# complete the annotation of grad op (xxx_grad op or sum op)
@@ -684,7 +444,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
assert forward_op is not None
# op dist attr
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
forward_op)
forward_op_process_mesh = forward_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute()
@@ -700,7 +460,8 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
forward_name)
else:
if forward_op_dist_attr.get_input_dims_mapping(input_name):
if forward_op_dist_attr.get_input_dims_mapping(
input_name):
ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
input_name)
else:
@@ -736,14 +497,14 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
output_var_dist_attr = TensorDistributedAttribute()
output_var_dist_attr.dims_mapping = ref_dims_mapping
output_var_dist_attr.process_mesh = forward_op_process_mesh
dist_context.set_tensor_dist_attr_for_program(
self._dist_context.set_tensor_dist_attr_for_program(
output_var, output_var_dist_attr)
grad_op_dist_attr.set_output_dims_mapping(output_var.name,
ref_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(
output_var.name, ref_dims_mapping)
dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr)
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
# only sum op for merge mutiple version grad has no a corresponding mapping in grad_op_id_to_op_id
else:
@@ -755,16 +516,16 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
ref_forward_var_name = _get_forward_varname_from_grad_varname(
grad_op.output_arg_names[0])
forward_var = vars[ref_forward_var_name]
ref_forward_var_dims_mapping = dist_context.get_tensor_dist_attr_for_program(
ref_forward_var_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).dims_mapping
ref_forward_var_process_mesh = dist_context.get_tensor_dist_attr_for_program(
ref_forward_var_process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
forward_var).process_mesh
# output
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping
tensor_dist_attr.process_mesh = ref_forward_var_process_mesh
dist_context.set_tensor_dist_attr_for_program(
self._dist_context.set_tensor_dist_attr_for_program(
vars[grad_op.output_arg_names[0]], tensor_dist_attr)
# op
@@ -778,18 +539,13 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
grad_op_dist_attr.set_output_dims_mapping(
grad_op.output_arg_names[0], ref_forward_var_dims_mapping)
dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr)
self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr)
def complete_update_annotation(auto_parallel_main_prog, dist_context):
def complete_update_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
if dist_context is None:
dist_context = get_default_distributed_context()
ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars
learning_rate_completed = False
for idx in range(len(ops)):
@@ -798,28 +554,6 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
# TODO to add attribute for moment var
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if op.type == "clip_by_norm":
param_grad = vars[op.input("X")[0]]
param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param_grad)
assert param_grad_dist_attr is not None
ref_process_mesh = param_grad_dist_attr.process_mesh
ref_dims_mapping = param_grad_dist_attr.dims_mapping
out = vars[op.output("Out")[0]]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = ref_process_mesh
out_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(out,
out_dist_attr)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dist_attr(param_grad.name,
param_grad_dist_attr)
op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
if "Grad" in op.input_names and "Param" in ops[idx].input_names: if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert len(op.input( assert len(op.input(
...@@ -829,13 +563,13 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): ...@@ -829,13 +563,13 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
param = vars[op.input("Param")[0]] param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]] grad_var = vars[op.input("Grad")[0]]
param_dist_attr = dist_context.get_tensor_dist_attr_for_program( param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
param) param)
assert param_dist_attr is not None assert param_dist_attr is not None
ref_process_mesh = dist_context.get_tensor_dist_attr_for_program( ref_process_mesh = self._dist_context.get_tensor_dist_attr_for_program(
param).process_mesh param).process_mesh
assert ref_process_mesh is not None assert ref_process_mesh is not None
ref_dims_mapping = dist_context.get_tensor_dist_attr_for_program( ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program(
param).dims_mapping param).dims_mapping
assert ref_dims_mapping is not None assert ref_dims_mapping is not None
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
...@@ -848,15 +582,16 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): ...@@ -848,15 +582,16 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
ref_dims_mapping) ref_dims_mapping)
learning_var = vars[op.input("LearningRate")[0]] learning_var = vars[op.input("LearningRate")[0]]
op_dist_attr.set_input_dims_mapping(learning_var.name, [-1]) op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
op_dist_attr.set_output_dims_mapping(learning_var.name, [-1]) op_dist_attr.set_output_dims_mapping(learning_var.name,
[-1])
if not learning_rate_completed: if not learning_rate_completed:
learning_rate_completed = True learning_rate_completed = True
var_dist_attr = TensorDistributedAttribute() var_dist_attr = TensorDistributedAttribute()
var_dist_attr.process_mesh = ref_process_mesh var_dist_attr.process_mesh = ref_process_mesh
var_dist_attr.dims_mapping = [-1] var_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(learning_var, self._dist_context.set_tensor_dist_attr_for_program(
var_dist_attr) learning_var, var_dist_attr)
for input_name in op.desc.input_names(): for input_name in op.desc.input_names():
...@@ -880,14 +615,15 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): ...@@ -880,14 +615,15 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
else: else:
assert "Moment" in input_name assert "Moment" in input_name
input_var_attr.dims_mapping = ref_dims_mapping input_var_attr.dims_mapping = ref_dims_mapping
op_dist_attr.set_input_dims_mapping(input_var.name, op_dist_attr.set_input_dims_mapping(
ref_dims_mapping) input_var.name, ref_dims_mapping)
op_dist_attr.set_output_dims_mapping(input_var.name, op_dist_attr.set_output_dims_mapping(
ref_dims_mapping) input_var.name, ref_dims_mapping)
input_var_attr.process_mesh = ref_process_mesh input_var_attr.process_mesh = ref_process_mesh
dist_context.set_tensor_dist_attr_for_program( self._dist_context.set_tensor_dist_attr_for_program(
input_var, input_var_attr) input_var, input_var_attr)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr) self._dist_context.set_op_dist_attr_for_program(
op, op_dist_attr)
continue continue
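The update-phase rule above boils down to: per-parameter inputs and outputs of an optimizer op (Param, Grad, the Moment accumulators) reuse the parameter's dims mapping and process mesh, while LearningRate is treated as replicated ([-1]). A toy restatement of just the dims-mapping part (plain data standing in for the dist-attr objects; not the actual API):

param_dims_mapping = [0, -1]  # parameter sharded along mesh dimension 0

def optimizer_input_dims_mapping(input_name):
    # LearningRate is a scalar-like tensor and stays replicated.
    if input_name == "LearningRate":
        return [-1]
    # Param, Grad and Moment* follow the parameter's sharding.
    return param_dims_mapping

assert optimizer_input_dims_mapping("Grad") == [0, -1]
assert optimizer_input_dims_mapping("Moment1") == [0, -1]
assert optimizer_input_dims_mapping("LearningRate") == [-1]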
@@ -247,23 +247,23 @@ class DistributedContext:
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None
def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = serial_node.id()
dist_tensor = self._dist_tensors_for_graph.get(
serial_tensor_node_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
if serial_node.is_op() and serial_node.op() is not None:
serial_op_node_id = serial_node.id()
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
return None
def init_dist_attr_for_program(self):
assert self._serial_program, \
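With get_dist_attr_for_graph enabled, a completion pass can look up the distributed attribute of any graph node, tensor or op, through one call. A hedged sketch of how Completer-style code might consume it (unannotated_graph_nodes is an illustrative helper, not part of this commit):

def unannotated_graph_nodes(dist_context):
    """Collect graph nodes whose process mesh is still undecided (illustration only)."""
    pending = []
    for node in dist_context.serial_ordered_nodes:
        dist_attr = dist_context.get_dist_attr_for_graph(node)
        if dist_attr is not None and dist_attr.process_mesh is None:
            pending.append(node)
    return pending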
@@ -32,7 +32,7 @@ from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
from .completion import Completer
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
@@ -130,8 +130,8 @@ class AutoParallelizer:
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
self._completer = Completer(self._dist_context)
self._completer.complete_backward_annotation(main_program)
return params_grads
@@ -142,8 +142,8 @@ class AutoParallelizer:
params_grads)
# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
return optimize_ops
@@ -179,8 +179,9 @@ class AutoParallelizer:
# Annotation completion
self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(serial_main_program,
self._dist_context)
self._completer = Completer(self._dist_context)
completed_main_program = self._completer.complete_forward_annotation(
serial_main_program)
else:
completed_main_program = serial_main_program
self._dist_context = copy.deepcopy(dist_context)
...@@ -27,6 +27,7 @@ import paddle.tensor as tensor ...@@ -27,6 +27,7 @@ import paddle.tensor as tensor
from paddle.fluid import layers from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
...@@ -154,10 +155,9 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -154,10 +155,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program, train_program, start_program = mlp_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_mp(self): def test_mlp_mp(self):
...@@ -171,10 +171,9 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -171,10 +171,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program, train_program, start_program = mlp_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_dp_mp(self): def test_mlp_dp_mp(self):
...@@ -189,10 +188,9 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -189,10 +188,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program, train_program, start_program = mlp_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
# print_program_with_dist_attr(complete_train_program, train_program)
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
# def test_mlp_misc(self): # def test_mlp_misc(self):
...@@ -212,8 +210,8 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -212,8 +210,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
# train_program, start_program = mlp_pretrain_forward(train_program, # train_program, start_program = mlp_pretrain_forward(train_program,
# start_program) # start_program)
# # pdb.set_trace() # # pdb.set_trace()
# complete_train_program = auto.complete_annotation(train_program, # completer = Completer(dist_context)
# dist_context) # complete_train_program = auto.completer.complete_forward_annotation(train_program)
# # print_program_with_dist_attr(complete_train_program, # # print_program_with_dist_attr(complete_train_program,
# # dist_context) # # dist_context)
# dist_context.finalize_distributed_attr_for_program( # dist_context.finalize_distributed_attr_for_program(
...@@ -423,8 +421,9 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -423,8 +421,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program, train_program, start_program = attn_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, completer = Completer(dist_context)
dist_context) complete_train_program = completer.complete_forward_annotation(
train_program)
# print_program_with_dist_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program()) self.assertTrue(dist_context.validate_dist_attr_for_program())
...@@ -440,10 +439,9 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -440,10 +439,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext() dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program, train_program, start_program = attn_pretrain_forward(train_program,
start_program) start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_attn_dp_mp(self):
@@ -458,10 +456,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                               start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
@@ -747,10 +744,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                  start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_decoder_mp(self):
@@ -764,10 +760,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                  start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_decoder_dp_mp(self):
@@ -782,10 +777,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                  start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
......
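Every hunk above (and most of those below) makes the same mechanical change: the old module-level auto.complete_annotation(train_program, dist_context) call is replaced by the new Completer class bound to a DistributedContext. As an illustration only, not part of this commit, here is a minimal sketch of the new call pattern; the tiny static Linear network, the 2-process mesh, and the shapes are placeholder assumptions standing in for the attn/decoder/gpt forward functions the real tests build.

# Illustrative sketch only (not from this diff): the new Completer-based flow.
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

paddle.enable_static()
train_program, start_program = static.Program(), static.Program()
with static.program_guard(train_program, start_program):
    x = static.data(name="x", shape=[4, 64], dtype="float32")
    # Annotate a single tensor; the completer propagates dist attrs to the rest.
    auto.shard_tensor(
        x,
        dist_attr={
            "process_mesh": auto.ProcessMesh([0, 1]),
            "dims_mapping": [0, -1]
        })
    out = nn.Linear(64, 64)(x)

dist_context = DistributedContext()
completer = Completer(dist_context)  # the completer is bound to a context at construction
complete_train_program = completer.complete_forward_annotation(train_program)
assert dist_context.validate_dist_attr_for_program()

The practical difference for callers is that the distributed context is supplied once, when the Completer is created, and the completion entry point then takes only the serial program, exactly as the updated tests do.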
@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed.fleet import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -817,10 +818,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_gpt_mp(self):
@@ -834,10 +834,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_gpt_dp_mp(self):
@@ -852,10 +851,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
......
@@ -23,6 +23,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
......
@@ -18,6 +18,7 @@ import unittest
 import paddle
 from paddle.fluid import core
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -42,8 +43,9 @@ def get_dist_prog(train_program,
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(
-        train_program, dist_context
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program
     ) if complete_train_program is None else complete_train_program
     # parallelizer._apply_serial_forward_pass(complete_train_program,
......
@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed import fleet
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -433,6 +434,12 @@ class MLPLayer(nn.Layer):
         out = F.gelu(out, approximate=True)
         out = self.linear1(out)
+        auto.shard_tensor(
+            out,
+            dist_attr={
+                "process_mesh": _global_process_mesh[1],
+                "dims_mapping": [0, -1]
+            })
         out = self.linear2(out)
         out = F.gelu(out, approximate=True)
         out = self.linear3(out)
@@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # auto completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
......
@@ -28,6 +28,7 @@ import paddle.tensor as tensor
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -49,8 +50,9 @@ def get_programs(annotated_func):
     global _global_process_mesh
     dist_context.process_mesh = _global_process_mesh
     train_program, start_program = annotated_func(train_program, start_program)
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     rank_id = 3
     dist_strategy = fleet.DistributedStrategy()
......
@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -881,8 +882,9 @@ class TestGPTPartitioner(unittest.TestCase):
         dist_context.process_mesh = _global_process_mesh
         train_program, startup_program, loss = gpt_pretrain_forward(
             train_program, startup_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         # serial backward pass
         params_grads = parallelizer._generate_backward(
@@ -913,8 +915,9 @@
                   "w") as fw:
             fw.write(str(auto_parallel_startup_prog))
         # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
-        #     from paddle.distributed.auto_parallel.completion import complete_backward_annotation
-        #     complete_backward_annotation(auto_parallel_main_prog)
+        #     from paddle.distributed.auto_parallel.completion import Completer
+        #     completer = Completer()
+        #     completer.complete_forward_annotation(auto_parallel_main_prog)
         #     fw.write(str(auto_parallel_main_prog))
         nrank = 4
         # col parallel
......
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -152,8 +153,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
@@ -299,7 +301,6 @@ class TestMLPReshard(unittest.TestCase):
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
         reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
-        # print_program_with_dist_attr(dist_main_prog, dist_context)
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -116,8 +117,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
......
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -132,8 +133,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
@@ -263,8 +265,9 @@ class TestMLPReshard(unittest.TestCase):
         dist_context = DistributedContext()
         dist_strategy = fleet.DistributedStrategy()
         partitioner = Partitioner(dist_context, rank_id)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
             complete_train_program, startup_program, [])
         reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
......
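The reshard tests above all build their per-rank programs through the same three-stage pipeline that this commit touches only at the first stage. The following condensed sketch is illustrative rather than copied from the tests (the function name and the empty params_grads list are assumptions; the real helpers also run the serial backward pass and an optimizer step between completion and partitioning).

# Illustrative only: complete -> partition -> reshard, as exercised by the tests.
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard

def build_rank_program(train_program, startup_program, rank_id):
    dist_context = DistributedContext()
    # 1. Fill in the missing distributed attributes of the serial program.
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(train_program)
    # 2. Cut the completed program down to this rank's shard.
    partitioner = Partitioner(dist_context, rank_id)
    dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, [])
    # 3. Insert the communication needed to re-shard tensors between ops.
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
    return dist_main_prog, dist_startup_prog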
@@ -154,7 +154,7 @@ class TestMLPSearcher(unittest.TestCase):
         ops = train_program.global_block().ops
         vars = train_program.global_block().vars
         from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
-        from paddle.distributed.auto_parallel.completion import is_elementwise_like_op
+        from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
         from paddle.distributed.auto_parallel.dist_op import DistributedOperator
         for op in ops:
@@ -163,7 +163,7 @@ class TestMLPSearcher(unittest.TestCase):
             if dist_op_impl_container is None:
                 op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                 dist_op = DistributedOperator(op, op_dist_attr)
-                if is_elementwise_like_op(op.type):
+                if is_elementwise_op(op.type):
                     changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                         dist_op)
                     self.assertFalse(changed)
......
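The last two hunks only track a helper that moved with the new completion code: is_elementwise_like_op, previously defined in the completion module, is replaced by is_elementwise_op from operators.common. A trivial hedged sketch of the updated call site follows; the op-type strings are arbitrary examples, and the printed results depend on the helper's internal op list.

# Illustrative only: the renamed helper takes an op-type string, as in the test.
from paddle.distributed.auto_parallel.operators.common import is_elementwise_op

for op_type in ["elementwise_add", "matmul_v2"]:
    print(op_type, "elementwise-like:", is_elementwise_op(op_type))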