Unverified commit 010aba33, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Add miscellaneous improvements (#43108)

* [Auto Parallel] Add the parallel tuner

* [Auto Parallel] Improve the parallel tuner and fix some bugs

* update cost model

* update import Resharder by dist op

* update cost model

* fix comp cost bug

* update cost model

* [Auto Parallel] Amend the dist attr for #processes=1

* update cost model and tuner

* update cost model and tuner

* update cost model and tuner

* update cluster

* update reshard

* [Auto Parallel] Add the estimation from the cost model

* [Auto Parallel] Reimplement the backup and restore functions (a usage sketch follows the commit message)

* [Auto Parallel] Fix the bugs of the parallel tuner

* [Auto Parallel] Update the engine api and dist context

* [Auto Parallel] Work around the high order grad problem

* [Auto Parallel] Add some miscellaneous improvements

* [Auto Parallel] Add a unittest for DistributedContext
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Parent 5f2c251c
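A central piece of this commit is the reimplemented backup/restore mechanism on DistributedContext, which the parallel tuner uses to try a parallelization and then roll the context back. The sketch below mirrors the new `test_dist_context` unittest included at the end of this diff; it is a minimal illustration under assumptions, not the actual test: the trivial program and variable names are made up, while the `_backup`/`_restore` calls and their mode arguments are taken from the diff (note that both are private helpers and may change).

```python
import paddle
import paddle.static as static
from paddle.distributed.auto_parallel.dist_context import DistributedContext

paddle.enable_static()

# Build a trivial serial program; the unittest below uses a full MLP instead.
main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name="x", shape=[4, 8], dtype="float32")
    out = paddle.mean(x)

dist_context = DistributedContext(main_prog, startup_prog)
dist_context.initialize()

# Snapshot the serial programs and the per-tensor/per-op dist_attrs.
dist_context._backup(serial=True, dist=True)

# ... a tuning trial may now mutate the distributed attributes ...

# Roll back to the snapshot taken above.
dist_context._restore(serial=True, serial_mode="to_backup",
                      dist=True, dist_mode="to_backup")
```

Per the diff, `_restore` also accepts `serial_mode="to_original"` and `dist_mode="to_original"`/`"to_default"`, and any other dist mode (the unittest passes `"to_nothing"`) clears the distributed information entirely.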
...@@ -20,7 +20,7 @@ from paddle.fluid import core ...@@ -20,7 +20,7 @@ from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from .utils import print_program_with_dist_attr from .utils import print_program_with_dist_attr
from .operators import find_best_compatible_distributed_operator_impl from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
...@@ -238,13 +238,17 @@ class Completer: ...@@ -238,13 +238,17 @@ class Completer:
tensor_desc.name()) tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping( compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping]) [op_dims_mapping, tensor_dims_mapping])
if not _validate_dims_mapping(
compatible_dims_mapping,
op_dist_attr.process_mesh):
continue
if (compatible_dims_mapping is not None) and \ if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping): (compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
tensor_desc.name(), compatible_dims_mapping) tensor_desc.name(), compatible_dims_mapping)
changed = True changed = True
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl( op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True) dist_op, fwd=True)
if op_dist_impls is not None: if op_dist_impls is not None:
not_compatible = True not_compatible = True
...@@ -254,7 +258,8 @@ class Completer: ...@@ -254,7 +258,8 @@ class Completer:
dim_changed = op_dist_impl.update_dims_mapping(dist_op) dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
if op_dist_impl.is_auto_compatible(dist_op): if op_dist_impl.is_auto_compatible(dist_op) \
and dist_op.validate_dist_attr():
if op_dist_impl.type == "elementwise": if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default" op_dist_attr.impl_type = "default"
else: else:
...@@ -289,13 +294,17 @@ class Completer: ...@@ -289,13 +294,17 @@ class Completer:
tensor_desc.name()) tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping( compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping]) [op_dims_mapping, tensor_dims_mapping])
if not _validate_dims_mapping(
compatible_dims_mapping,
op_dist_attr.process_mesh):
continue
if (compatible_dims_mapping is not None) and \ if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping): (compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
tensor_desc.name(), compatible_dims_mapping) tensor_desc.name(), compatible_dims_mapping)
changed = True changed = True
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl( op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=False) dist_op, fwd=False)
if op_dist_impls is not None: if op_dist_impls is not None:
not_compatible = True not_compatible = True
...@@ -305,8 +314,8 @@ class Completer: ...@@ -305,8 +314,8 @@ class Completer:
dim_changed = op_dist_impl.update_dims_mapping(dist_op) dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
if op_dist_impl.is_auto_compatible(dist_op): if op_dist_impl.is_auto_compatible(dist_op) \
not_compatible = False and dist_op.validate_dist_attr():
if op_dist_impl.type == "elementwise": if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default" op_dist_attr.impl_type = "default"
else: else:
...@@ -352,6 +361,23 @@ class Completer: ...@@ -352,6 +361,23 @@ class Completer:
changed = True changed = True
return changed return changed
def _update_dims_mapping_for_special(self):
# Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it
op_nodes = self._dist_context._serial_ordered_op_nodes
for op_node in op_nodes:
op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER:
continue
tensor_desc = tensor_node.var()
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
tensor_dist_attr.dims_mapping = op_dims_mapping
def _update_dims_mapping(self): def _update_dims_mapping(self):
# Complete dims_mapping for each node # Complete dims_mapping for each node
reach_fix_point = False reach_fix_point = False
...@@ -378,6 +404,7 @@ class Completer: ...@@ -378,6 +404,7 @@ class Completer:
reach_fix_point = False reach_fix_point = False
else: else:
reach_fix_point = True reach_fix_point = True
self._update_dims_mapping_for_special()
def _update_process_mesh_by_nearest(self, op_node, nearest_op_node): def _update_process_mesh_by_nearest(self, op_node, nearest_op_node):
op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node) op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
...@@ -685,7 +712,7 @@ class Completer: ...@@ -685,7 +712,7 @@ class Completer:
# Step 3: adjust the process meshes for special ops # Step 3: adjust the process meshes for special ops
self._update_process_mesh_for_specials() self._update_process_mesh_for_specials()
# Step 4: adjust the process meshes between graphs # Step 4: adjust the process meshes between graphs
self._update_process_mesh_between_graphs() self._update_process_mesh_between_graphs()
def _prepare(self): def _prepare(self):
...@@ -727,14 +754,14 @@ class Completer: ...@@ -727,14 +754,14 @@ class Completer:
""" Complete annotation for the partial annotated serial_main_program. """ Complete annotation for the partial annotated serial_main_program.
Arguments: Arguments:
serial_main_program: partial annotated serial_main_program. serial_main_program: partial annotated serial_main_program.
Returns: Returns:
serial_main_program: completed annotated serial_main_program. serial_main_program: completed annotated serial_main_program.
""" """
if serial_main_program is None: if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program serial_main_program = self._dist_context.serial_main_program
else: else:
self._dist_context.serial_main_program = serial_main_program self._dist_context._serial_main_program = serial_main_program
self._dist_context.initialize() self._dist_context.initialize()
...@@ -757,13 +784,18 @@ class Completer: ...@@ -757,13 +784,18 @@ class Completer:
return serial_main_program return serial_main_program
def _complete_high_order_grad_annotation(self, serial_main_program): def _complete_high_order_grad_annotation(self, serial_main_program=None):
""" """
NOTE: NOTE:
[HighOrderGrad] Complete the annotation of vars and ops only for high order gradient. [HighOrderGrad] Complete the annotation of vars and ops only for high order gradient.
This function is temporary to support high order gradient, and will be removed in the future. This function is temporary to support high order gradient, and will be removed in the future.
""" """
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context._serial_main_program = serial_main_program
def _is_grad_var_name(name): def _is_grad_var_name(name):
if "@GRAD" in name: if "@GRAD" in name:
return True return True
...@@ -917,12 +949,13 @@ class Completer: ...@@ -917,12 +949,13 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
def complete_backward_annotation(self, serial_main_program): def complete_backward_annotation(self, serial_main_program=None):
"""Complete the annotation of vars and ops in the backward phase for parallel program.""" """Complete the annotation of vars and ops in the backward phase for parallel program."""
if serial_main_program is None: if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program serial_main_program = self._dist_context.serial_main_program
else: else:
self._dist_context.serial_main_program = serial_main_program self._dist_context._serial_main_program = serial_main_program
def _is_grad_var_name(name): def _is_grad_var_name(name):
if "@GRAD" in name: if "@GRAD" in name:
...@@ -1032,6 +1065,9 @@ class Completer: ...@@ -1032,6 +1065,9 @@ class Completer:
grad_op_dist_attr.process_mesh = ref_mesh grad_op_dist_attr.process_mesh = ref_mesh
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
continue continue
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
...@@ -1078,6 +1114,8 @@ class Completer: ...@@ -1078,6 +1114,8 @@ class Completer:
grad_op_dist_attr.set_output_dims_mapping(output_name, grad_op_dist_attr.set_output_dims_mapping(output_name,
ref_dims_mapping) ref_dims_mapping)
grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
...@@ -1111,6 +1149,8 @@ class Completer: ...@@ -1111,6 +1149,8 @@ class Completer:
var_name, ref_fwd_dims_mapping) var_name, ref_fwd_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping( grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping) output_name, ref_fwd_dims_mapping)
grad_op_dist_attr.impl_type = "default"
grad_op_dist_attr.impl_idx = 0
elif grad_op.type == 'fill_zeros_like': elif grad_op.type == 'fill_zeros_like':
ref_var_name = grad_op.input_arg_names[0] ref_var_name = grad_op.input_arg_names[0]
...@@ -1142,12 +1182,13 @@ class Completer: ...@@ -1142,12 +1182,13 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
def complete_update_annotation(self, serial_main_program=None): def complete_update_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the update phase for parallel program.""" """Complete the annotation of vars and ops in the update phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program # Notice: serial_main_program is actually a dist_main_program of current rank,
else: # and must be passed into this function.
self._dist_context.serial_main_program = serial_main_program # TODO: We should fix this behavior.
ops = list(serial_main_program.global_block().ops) ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars vars = serial_main_program.global_block().vars
learning_rate_completed = False learning_rate_completed = False
...@@ -1304,7 +1345,7 @@ class Completer: ...@@ -1304,7 +1345,7 @@ class Completer:
dist_op.dist_attr.process_mesh = world_ranks dist_op.dist_attr.process_mesh = world_ranks
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl( op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True) dist_op, fwd=True)
if op_dist_impls is not None: if op_dist_impls is not None:
backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr) backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
......
...@@ -132,15 +132,17 @@ class TensorDistributedAttribute: ...@@ -132,15 +132,17 @@ class TensorDistributedAttribute:
key, dist_attr) key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated) self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
# def reset(self, skip_dist_attr_field_names): def reset(self, skip_dist_attr_field_names=None):
# if skip_dist_attr_field_names is not None \ if skip_dist_attr_field_names is None or \
# and "process_mesh" not in skip_dist_attr_field_names: (skip_dist_attr_field_names is not None \
# self._process_mesh = None and "process_mesh" not in skip_dist_attr_field_names):
# if skip_dist_attr_field_names is not None \ self._process_mesh = None
# and "dims_mapping" not in skip_dist_attr_field_names: if skip_dist_attr_field_names is None or \
# for i in enumerate(self._dims_mapping): (skip_dist_attr_field_names is not None \
# self._dims_mapping[i] = -1 and "dims_mapping" not in skip_dist_attr_field_names):
# self._is_annotated = {} for i, _ in enumerate(self._dims_mapping):
self._dims_mapping[i] = -1
self._is_annotated = {}
def is_annotated(self, dist_attr_field_name): def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False) return self._is_annotated.get(dist_attr_field_name, False)
...@@ -272,6 +274,9 @@ class OperatorDistributedAttribute: ...@@ -272,6 +274,9 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr) dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object self._inputs_dist_attrs[name] = dist_attr_object
# def del_input_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def get_output_dist_attr(self, name): def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None) return self._outputs_dist_attrs.get(name, None)
...@@ -280,6 +285,9 @@ class OperatorDistributedAttribute: ...@@ -280,6 +285,9 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr) dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object self._outputs_dist_attrs[name] = dist_attr_object
# def del_output_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def get_input_dims_mapping(self, name): def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name) input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr: if input_dist_attr:
...@@ -374,17 +382,18 @@ class OperatorDistributedAttribute: ...@@ -374,17 +382,18 @@ class OperatorDistributedAttribute:
"ProcessMeshes in DistributedOperator must be the same." "ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh self.process_mesh = shared_process_mesh
# def reset(self, skip_dist_attr_field_names): def reset(self, skip_dist_attr_field_names=None):
# for tensor_dist_attr in self.inputs_dist_attrs.values(): for tensor_dist_attr in self.inputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names) tensor_dist_attr.reset(skip_dist_attr_field_names)
# for tensor_dist_attr in self.outputs_dist_attrs.values(): for tensor_dist_attr in self.outputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names) tensor_dist_attr.reset(skip_dist_attr_field_names)
# if skip_dist_attr_field_names is not None \ if skip_dist_attr_field_names is None or \
# and "process_mesh" not in skip_dist_attr_field_names: (skip_dist_attr_field_names is not None \
# self.process_mesh = None and "process_mesh" not in skip_dist_attr_field_names):
# self.impl_type = "default" self._process_mesh = None
# self.impl_idx = 0 self.impl_type = "default"
# self._is_annotated = {} self.impl_idx = 0
self._is_annotated = {}
def is_annotated(self, attr_name): def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False) return self._is_annotated.get(attr_name, False)
......
...@@ -57,33 +57,30 @@ class DistributedContext: ...@@ -57,33 +57,30 @@ class DistributedContext:
serial_startup_prog=None, serial_startup_prog=None,
serial_optimizer=None, serial_optimizer=None,
serial_loss=None, serial_loss=None,
feed_vars=None, feed_vars={},
fetch_vars=None, fetch_vars={},
cluster=None,
strategy=None): strategy=None):
# Data members related to original programs (unchanged) # Data members related to original programs (unchanged)
self._original_serial_main_program = serial_main_prog self._original_serial_main_program = serial_main_prog
self._original_serial_startup_program = serial_startup_prog self._original_serial_startup_program = serial_startup_prog
self._original_serial_optimizer = serial_optimizer
self._original_serial_loss = serial_loss self._original_serial_loss = serial_loss
self._original_serial_feed_vars = feed_vars
self._original_serial_fetch_vars = fetch_vars
self._original_serial_optimizer = serial_optimizer self._original_serial_optimizer = serial_optimizer
if self._original_serial_main_program is None:
self._original_serial_main_program = paddle.fluid.default_main_program(
)
if self._original_serial_startup_program is None:
self._original_serial_startup_program = paddle.fluid.default_startup_program(
)
# Data members related to programs (changed) # Data members related to programs (changed)
self._serial_main_program = None self._serial_main_program = None
self._serial_startup_program = None self._serial_startup_program = None
self._serial_loss = serial_loss self._serial_loss = None
self._serial_optimizer = serial_optimizer self._serial_optimizer = None
self._serial_feed_vars = feed_vars self._serial_feed_vars = {}
self._serial_fetch_vars = fetch_vars self._serial_fetch_vars = {}
# Data members related to the program # Data members related to the program
self._dist_tensors_for_program = {} self._dist_tensors_for_program = {}
self._dist_ops_for_program = {} self._dist_ops_for_program = {}
self._block_state = BlockState()
# Data members related to the graph # Data members related to the graph
self._serial_graph = None self._serial_graph = None
...@@ -96,24 +93,30 @@ class DistributedContext: ...@@ -96,24 +93,30 @@ class DistributedContext:
# Distributed programs # Distributed programs
self._dist_main_programs = {} self._dist_main_programs = {}
self._dist_startup_programs = {} self._dist_startup_programs = {}
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
# Distributed Strategy self._cluster = cluster
self._strategy = strategy self._strategy = strategy
# Pass Context # Pass Context
self._pass_context = PassContext() self._pass_context = PassContext()
self._block_state = BlockState()
# Distributed Operator Context
self._dist_op_context = DistributedOperatorContext()
# Other data members # Other data members
self._process_meshes = []
self._serial_ordered_tensor_nodes = [] self._serial_ordered_tensor_nodes = []
self._serial_ordered_op_nodes = [] self._serial_ordered_op_nodes = []
self._serial_ordered_nodes = [] self._serial_ordered_nodes = []
# self._tensor_id_to_tensor_node_ids = {} # self._tensor_id_to_tensor_node_ids = {}
self._is_initialized = False self._is_initialized = False
self._need_copy_dist_attr_to_graph = False
self._backup_pass_context_stack = []
self._backup_block_state_stack = []
self._backup_dist_tensors_for_program_stack = []
self._backup_dist_ops_for_program_stack = []
self._backup_serial_main_program_stack = []
self._backup_serial_startup_program_stack = []
# flag whether scale gradient with dp size # flag whether scale gradient with dp size
self._gradient_scale = True self._gradient_scale = True
...@@ -122,13 +125,6 @@ class DistributedContext: ...@@ -122,13 +125,6 @@ class DistributedContext:
def serial_main_program(self): def serial_main_program(self):
return self._serial_main_program return self._serial_main_program
@serial_main_program.setter
def serial_main_program(self, program):
# if self._serial_main_program:
# print("WARNING: The program attached to this distributed context will be replaced by the new one.")
self._original_serial_main_program = program
self._serial_main_program = program
@property @property
def serial_startup_program(self): def serial_startup_program(self):
return self._serial_startup_program return self._serial_startup_program
...@@ -149,6 +145,18 @@ class DistributedContext: ...@@ -149,6 +145,18 @@ class DistributedContext:
def serial_fetch_vars(self): def serial_fetch_vars(self):
return self._serial_fetch_vars return self._serial_fetch_vars
@property
def dist_main_programs(self):
return self._dist_main_programs
@property
def dist_startup_programs(self):
return self._dist_startup_programs
@property
def cluster(self):
return self._cluster
@property @property
def strategy(self): def strategy(self):
return self._strategy return self._strategy
...@@ -177,14 +185,6 @@ class DistributedContext: ...@@ -177,14 +185,6 @@ class DistributedContext:
def block_state(self): def block_state(self):
return self._block_state return self._block_state
@property
def dist_main_programs(self):
return self._dist_main_programs
@property
def dist_startup_programs(self):
return self._dist_startup_programs
@property @property
def has_annotation(self): def has_annotation(self):
return len(self._dist_tensors_for_program) or len( return len(self._dist_tensors_for_program) or len(
...@@ -198,21 +198,168 @@ class DistributedContext: ...@@ -198,21 +198,168 @@ class DistributedContext:
def gradient_scale(self, gs): def gradient_scale(self, gs):
self._gradient_scale = gs self._gradient_scale = gs
def initialize(self): def _backup_serial_info(self, mode):
if not self._is_initialized: self._backup_serial_main_program_stack.append(
self._serial_main_program.clone())
self._backup_serial_startup_program_stack.append(
self._serial_startup_program.clone())
self._backup_pass_context_stack.append(
copy.deepcopy(self._pass_context))
self._backup_block_state_stack.append(copy.deepcopy(self._block_state))
def _backup_dist_info(self, mode):
self._backup_dist_tensors_for_program_stack.append(
copy.deepcopy(self._dist_tensors_for_program))
self._backup_dist_ops_for_program_stack.append(
copy.deepcopy(self._dist_ops_for_program))
def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None):
# Use this function carefully
if serial:
self._backup_serial_info(serial_mode)
if dist:
self._backup_dist_info(dist_mode)
def _restore_serial_info(self, mode="to_backup"):
if mode == "to_backup":
self._serial_main_program = self._backup_serial_main_program_stack.pop(
)
self._serial_startup_program = self._backup_serial_startup_program_stack.pop(
)
elif mode == "to_original":
assert self._original_serial_main_program is not None
assert self._original_serial_startup_program is not None
self._serial_main_program = self._original_serial_main_program.clone( self._serial_main_program = self._original_serial_main_program.clone(
) )
self._serial_startup_program = self._original_serial_startup_program.clone( self._serial_startup_program = self._original_serial_startup_program.clone(
) )
# self._serial_main_program = self._original_serial_main_program
# self._serial_startup_program = self._original_serial_startup_program self._serial_optimizer = self._original_serial_optimizer
if self._original_serial_loss:
self._serial_loss = self._serial_main_program.global_block( if self._original_serial_loss:
).vars[self._original_serial_loss[0].name] if isinstance(self._original_serial_loss, list):
assert len(self._original_serial_loss) == 1
loss = self._original_serial_loss[0]
block_idx = loss.block.idx
var_name = loss.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
self._serial_loss = var
else: else:
self._serial_loss = self._original_serial_loss block_idx = self._original_serial_loss.block.idx
self._serial_optimizer = self._original_serial_optimizer var_name = self._original_serial_loss.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
self._serial_loss = var
for key, var_list in self._original_serial_feed_vars.items():
new_var_list = []
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
self._serial_feed_vars[key] = new_var_list
for key, var_list in self._original_serial_fetch_vars.items():
new_var_list = []
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
self._serial_fetch_vars[key] = new_var_list
self._pass_context = self._backup_pass_context_stack.pop()
self._block_state = self._backup_block_state_stack.pop()
def _restore_dist_info(self, mode="to_backup"):
if mode == "to_backup":
self._dist_tensors_for_program = self._backup_dist_tensors_for_program_stack.pop(
)
self._dist_ops_for_program = self._backup_dist_ops_for_program_stack.pop(
)
elif mode == "to_original":
assert self._original_dist_tensors_for_program
assert self._original_dist_ops_for_program
self._dist_tensors_for_program = copy.deepcopy(
self._original_dist_tensors_for_program)
self._dist_ops_for_program = copy.deepcopy(
self._original_dist_ops_for_program)
elif mode == "to_default":
new_tensors_ids = []
for tensor_id, dist_tensor in self._dist_tensors_for_program.items(
):
if tensor_id in self._tensors_ids:
dist_tensor.dist_attr.reset()
else:
new_tensors_ids.append(tensor_id)
for tensor_id in new_tensors_ids:
self._dist_tensors_for_program.pop(tensor_id)
new_ops_ids = []
for op_id, dist_op in self._dist_ops_for_program.items():
if op_id in self._ops_ids:
dist_op.dist_attr.reset()
else:
new_ops_ids.append(op_id)
for op_id in new_ops_ids:
self._dist_ops_for_program.pop(op_id)
else:
new_tensors_ids = []
for tensor_id, dist_tensor in self._dist_tensors_for_program.items(
):
new_tensors_ids.append(tensor_id)
for tensor_id in new_tensors_ids:
self._dist_tensors_for_program.pop(tensor_id)
new_ops_ids = []
for op_id, dist_op in self._dist_ops_for_program.items():
new_ops_ids.append(op_id)
for op_id in new_ops_ids:
self._dist_ops_for_program.pop(op_id)
self._dist_main_programs = {}
self._dist_startup_programs = {}
self._dist_op_context = DistributedOperatorContext()
self._need_copy_dist_attr_to_graph = True
self._process_meshes = []
def _restore(self,
serial=True,
serial_mode="to_backup",
dist=True,
dist_mode="to_backup"):
# Use this function carefully
if serial:
self._restore_serial_info(serial_mode)
if dist:
self._restore_dist_info(dist_mode)
def initialize(self):
if not self._is_initialized:
if not self._serial_main_program:
self._serial_main_program = self._original_serial_main_program
if not self._serial_startup_program:
self._serial_startup_program = self._original_serial_startup_program
if not self._serial_loss:
if isinstance(self._original_serial_loss, list):
assert len(self._original_serial_loss) == 1
self._serial_loss = self._original_serial_loss[0]
else:
self._serial_loss = self._original_serial_loss
if not self._serial_optimizer:
self._serial_optimizer = self._original_serial_optimizer
if not self._serial_feed_vars:
self._serial_feed_vars = self._original_serial_feed_vars
if not self._serial_fetch_vars:
self._serial_fetch_vars = self._original_serial_fetch_vars
self._init_dist_attr_for_program() self._init_dist_attr_for_program()
# Backup the original distributed information for later restore
self._original_dist_tensors_for_program = copy.deepcopy(
self._dist_tensors_for_program)
self._original_dist_ops_for_program = copy.deepcopy(
self._dist_ops_for_program)
self._tensors_ids = list(self._dist_tensors_for_program.keys()) self._tensors_ids = list(self._dist_tensors_for_program.keys())
self._ops_ids = list(self._dist_ops_for_program.keys()) self._ops_ids = list(self._dist_ops_for_program.keys())
set_flags({"FLAGS_convert_all_blocks": True}) set_flags({"FLAGS_convert_all_blocks": True})
...@@ -220,41 +367,9 @@ class DistributedContext: ...@@ -220,41 +367,9 @@ class DistributedContext:
core.Graph(self._serial_main_program.desc)) core.Graph(self._serial_main_program.desc))
self._init_dist_attr_for_graph() self._init_dist_attr_for_graph()
self._is_initialized = True self._is_initialized = True
self._need_copy_dist_attr_to_graph = False
# def reset(self, if self._need_copy_dist_attr_to_graph:
# skip_dist_tensors=None, self.copy_dist_attr_from_program_to_graph()
# skip_dist_ops=None,
# skip_tensor_dist_attr_fields=None,
# skip_op_dist_attr_fields=None):
# self._serial_main_program = self._original_serial_main_program.clone()
# self._serial_startup_program = self._original_serial_startup_program.clone()
# new_tensors_ids = []
# for tensor_id, dist_tensor in self._dist_tensors_for_program.items():
# if tensor_id in self._tensors_ids:
# dist_tensor.dist_attr.reset(skip_tensor_dist_attr_fields)
# else:
# new_tensors_ids.append(tensor_id)
# for tensor_id in new_tensors_ids:
# self._dist_tensors_for_program.pop(tensor_id)
# new_ops_ids = []
# for op_id, dist_op in self._dist_ops_for_program.items():
# if op_id in self._ops_ids:
# dist_op.dist_attr.reset(skip_op_dist_attr_fields)
# else:
# new_ops_ids.append(op_id)
# for op_id in new_ops_ids:
# self._dist_ops_for_program.pop(op_id)
# self.copy_dist_attr_from_program_to_graph()
# self._dist_main_programs = {}
# self._dist_startup_programs = {}
# self._pass_context = PassContext()
# self._dist_op_context = DistributedOperatorContext()
# self._process_meshes = []
def add_process_mesh(self, process_mesh): def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \ assert isinstance(process_mesh, ProcessMesh), \
...@@ -423,6 +538,10 @@ class DistributedContext: ...@@ -423,6 +538,10 @@ class DistributedContext:
if current_dist_op is None: if current_dist_op is None:
dist_op = DistributedOperator(op) dist_op = DistributedOperator(op)
self.add_dist_op_for_program(dist_op) self.add_dist_op_for_program(dist_op)
self._original_dist_tensors_for_program = copy.deepcopy(
self._dist_tensors_for_program)
self._original_dist_ops_for_program = copy.deepcopy(
self._dist_ops_for_program)
def _order_nodes_by_program_order(self): def _order_nodes_by_program_order(self):
def _contains(nodes, target_node): def _contains(nodes, target_node):
...@@ -592,7 +711,7 @@ class DistributedContext: ...@@ -592,7 +711,7 @@ class DistributedContext:
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node) op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id] dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph dist_op_for_program.dist_attr = op_dist_attr_for_graph
# TODO: the completion algorithm will skip orphan tensors, # TODO: the completion algorithm will skipped orphan tensors,
# here we just set there process_mesh to the first one. # here we just set there process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes: for orphan_node in self._serial_orphan_tensor_nodes:
serial_tensor_id = orphan_node.var().id() serial_tensor_id = orphan_node.var().id()
...@@ -618,16 +737,21 @@ class DistributedContext: ...@@ -618,16 +737,21 @@ class DistributedContext:
tensor_shape = serial_tensor.shape tensor_shape = serial_tensor.shape
dims_mapping = dist_attr.dims_mapping dims_mapping = dist_attr.dims_mapping
process_mesh_shape = dist_attr.process_mesh.topology process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_processes = dist_attr.process_mesh.processes
# If the dimension of tensor is less than the sharding dimension of process mesh, # If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?) # we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)): for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1 dims_mapping[i] = -1
if dims_mapping[i] != -1 and len(process_mesh_processes) == 1:
dims_mapping[i] = -1
for dist_op in self._dist_ops_for_program.values(): for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
dist_attr = dist_op.dist_attr dist_attr = dist_op.dist_attr
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_processes = dist_attr.process_mesh.processes
for arg_name in serial_op.input_arg_names: for arg_name in serial_op.input_arg_names:
if dist_op.get_serial_input(arg_name) is None: if dist_op.get_serial_input(arg_name) is None:
tensor_shape = [] tensor_shape = []
...@@ -639,13 +763,15 @@ class DistributedContext: ...@@ -639,13 +763,15 @@ class DistributedContext:
else: else:
tensor_shape = dist_op.get_serial_input(arg_name).shape tensor_shape = dist_op.get_serial_input(arg_name).shape
dims_mapping = dist_attr.get_input_dims_mapping(arg_name) dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the dimension of tensor is less than the sharding dimension of process mesh, # If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?) # we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)): for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1 dims_mapping[i] = -1
if dims_mapping[i] != -1 and len(
process_mesh_processes) == 1:
dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names: for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \ if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
...@@ -654,13 +780,18 @@ class DistributedContext: ...@@ -654,13 +780,18 @@ class DistributedContext:
else: else:
tensor_shape = dist_op.get_serial_output(arg_name).shape tensor_shape = dist_op.get_serial_output(arg_name).shape
dims_mapping = dist_attr.get_output_dims_mapping(arg_name) dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the dimension of tensor is less than the sharding dimension of process mesh, # If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?) # we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)): for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1 dims_mapping[i] = -1
if dims_mapping[i] != -1 and len(
process_mesh_processes) == 1:
dims_mapping[i] = -1
if len(process_mesh_processes) == 1:
dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0
def validate_dist_attr_for_program(self): def validate_dist_attr_for_program(self):
if not self._is_initialized: if not self._is_initialized:
...@@ -674,16 +805,20 @@ class DistributedContext: ...@@ -674,16 +805,20 @@ class DistributedContext:
dist_tensor.serial_tensor.name) dist_tensor.serial_tensor.name)
if (dist_tensor is not None) and ( if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr()): not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} has a wrong distributed attributes {}.".format( assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.dist_attr) dist_tensor.serial_tensor.name,
dist_tensor.desc.id(),
dist_tensor.desc.original_id(), dist_tensor.dist_attr)
for op in block.ops: for op in block.ops:
dist_op = self.get_dist_op_for_program(op) dist_op = self.get_dist_op_for_program(op)
assert dist_op is not None, \ assert dist_op is not None, \
"Operator {} does not have a distributed attribute.".format( "Operator {} does not have a distributed attribute.".format(
dist_op.serial_op.type) dist_op.serial_op.type)
if (dist_op is not None) and (not dist_op.validate_dist_attr()): if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format( assert False, "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format(
dist_op.serial_op.type, dist_op.dist_attr) dist_op.serial_op.type,
dist_op.serial_op.desc.id(),
dist_op.serial_op.desc.original_id(), dist_op.dist_attr)
return True return True
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
......
...@@ -41,7 +41,7 @@ class DistributedTensor: ...@@ -41,7 +41,7 @@ class DistributedTensor:
rank=None, rank=None,
shard_sizes=None): shard_sizes=None):
if not (isinstance(sizes, (list, tuple)) and if not (isinstance(sizes, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x > 0, sizes))): all(map(lambda x: isinstance(x, int) and x >= 0, sizes))):
raise ValueError( raise ValueError(
"The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}". "The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}".
format(sizes)) format(sizes))
...@@ -79,8 +79,11 @@ class DistributedTensor: ...@@ -79,8 +79,11 @@ class DistributedTensor:
local_sizes = [] local_sizes = []
# for even sharding, the local sizes of every rank are equal # for even sharding, the local sizes of every rank are equal
for idx, item in enumerate(global_sizes): for idx, item in enumerate(global_sizes):
if dims_mapping[idx] == -1: # This is a trick to avoid dims_mapping is []
val = dims_mapping[idx] if idx < len(dims_mapping) else -1
if val == -1:
local_sizes.append(item) local_sizes.append(item)
else: else:
local_sizes.append(item // topology[dims_mapping[idx]]) local_sizes.append(item // topology[dims_mapping[idx]])
......
...@@ -31,10 +31,11 @@ from paddle.fluid.backward import append_backward ...@@ -31,10 +31,11 @@ from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator from paddle.fluid.framework import Operator
from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
from paddle.distributed.utils import get_logger from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext from paddle.distributed.passes import new_pass, PassContext
from .cluster import Cluster # from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
...@@ -57,7 +58,11 @@ class Engine: ...@@ -57,7 +58,11 @@ class Engine:
self.inputs_spec = self._validate_spec(inputs_spec) self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec) self.labels_spec = self._validate_spec(labels_spec)
self.cluster = cluster self.cluster = cluster
# if self.cluster is None:
# self.cluster = get_default_cluster()
self.strategy = strategy self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
self._executor = None self._executor = None
self._cur_rank = paddle.distributed.get_rank() self._cur_rank = paddle.distributed.get_rank()
...@@ -69,11 +74,11 @@ class Engine: ...@@ -69,11 +74,11 @@ class Engine:
self._orig_main_prog = fluid.default_main_program() self._orig_main_prog = fluid.default_main_program()
self._orig_startup_prog = fluid.default_startup_program() self._orig_startup_prog = fluid.default_startup_program()
self._orig_dist_context = get_default_distributed_context() self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {}
self._serial_main_progs = {} self._serial_main_progs = {}
self._serial_startup_progs = {} self._serial_startup_progs = {}
self._dist_main_progs = defaultdict(dict) # dist main programs self._dist_main_progs = defaultdict(dict) # dist main programs
self._dist_startup_progs = defaultdict(dict) # dist startup programs self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._dist_contexts = {}
self._feed_vars = {} self._feed_vars = {}
self._fetch_vars = {} self._fetch_vars = {}
...@@ -104,11 +109,17 @@ class Engine: ...@@ -104,11 +109,17 @@ class Engine:
parallelizer.parallel(self._cur_rank) parallelizer.parallel(self._cur_rank)
else: else:
parallelizer.parallel_all() parallelizer.parallel_all()
# Get the distributed main programs and startup programs # Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[
mode].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[
mode].serial_startup_program
self._dist_main_progs[mode] = self._dist_contexts[ self._dist_main_progs[mode] = self._dist_contexts[
mode].dist_main_programs mode].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[ self._dist_startup_progs[mode] = self._dist_contexts[
mode].dist_startup_programs mode].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
# Init comm and startup program # Init comm and startup program
self._initialize(mode) self._initialize(mode)
...@@ -135,20 +146,23 @@ class Engine: ...@@ -135,20 +146,23 @@ class Engine:
inputs = [self._set_data_parallel(var) for var in inputs] inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels] labels = [self._set_data_parallel(var) for var in labels]
self._feed_vars[mode] = {"inputs": inputs, "labels": labels} # self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": inputs, "labels": labels}
self._fetch_vars[mode] = { # self._fetch_vars[mode] = {
# "outputs": flatten(outputs),
# "loss": losses,
# "metrics": metrics
# }
fetch_vars = {
"outputs": flatten(outputs), "outputs": flatten(outputs),
"loss": losses, "loss": losses,
"metrics": metrics "metrics": metrics
} }
self._serial_main_progs[mode] = serial_main_prog
self._serial_startup_progs[mode] = serial_startup_prog
self._dist_contexts[mode] = DistributedContext( self._dist_contexts[mode] = DistributedContext(
self._serial_main_progs[mode], self._serial_startup_progs[mode], serial_main_prog, serial_startup_prog, self._optimizer, losses,
self._optimizer, losses, self._feed_vars[mode], feed_vars, fetch_vars, self.cluster, self.strategy)
self._fetch_vars[mode], self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _initialize(self, mode): def _initialize(self, mode):
......
...@@ -16,7 +16,7 @@ from .common import DistributedOperatorImplContainer ...@@ -16,7 +16,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl from .common import find_compatible_distributed_operator_impls
from . import dist_embedding from . import dist_embedding
from . import dist_matmul from . import dist_matmul
from . import dist_reshape from . import dist_reshape
......
...@@ -157,9 +157,7 @@ def register_distributed_operator_impl(op_type, dist_impl): ...@@ -157,9 +157,7 @@ def register_distributed_operator_impl(op_type, dist_impl):
assert False, "Must register distributed operator registry first." assert False, "Must register distributed operator registry first."
def find_best_compatible_distributed_operator_impl(dist_op, def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
fwd=True,
partial=True):
""" """
Here just return the first compatible implemention. Here just return the first compatible implemention.
This will be improved by cost model in the future. This will be improved by cost model in the future.
......
...@@ -187,7 +187,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -187,7 +187,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name) serial_tensor = dist_op.get_serial_input(arg_name)
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if serial_tensor.is_parameter: if serial_tensor is not None and serial_tensor.is_parameter:
for mapping in dims_mapping: for mapping in dims_mapping:
if mapping != -1: if mapping != -1:
return False return False
...@@ -217,7 +217,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -217,7 +217,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if serial_tensor.is_parameter: if serial_tensor is not None and serial_tensor.is_parameter:
for mapping in dims_mapping: for mapping in dims_mapping:
if mapping != -1: if mapping != -1:
return False return False
......
...@@ -22,7 +22,6 @@ from .common import register_distributed_operator_impl_container ...@@ -22,7 +22,6 @@ from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program from .common import set_comm_op_dist_attr_for_program
from .dist_default import DistributedDefaultImpl0 from .dist_default import DistributedDefaultImpl0
from ..reshard import Resharder
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import is_dim_shard, is_dim_replicate, _get_corresponding_rank from ..utils import is_dim_shard, is_dim_replicate, _get_corresponding_rank
from ..utils import compute_compatible_dim_mapping, set_dist_op_desc_original_id, _get_comm_group from ..utils import compute_compatible_dim_mapping, set_dist_op_desc_original_id, _get_comm_group
...@@ -324,6 +323,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -324,6 +323,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)] dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)]
from ..reshard import Resharder
partition_idx = Resharder.compute_partition_index( partition_idx = Resharder.compute_partition_index(
rank_id, new_X_grad.shape, dims_mapping, process_mesh_shape, rank_id, new_X_grad.shape, dims_mapping, process_mesh_shape,
process_mesh_group) process_mesh_group)
......
...@@ -35,7 +35,7 @@ class Parallelizer: ...@@ -35,7 +35,7 @@ class Parallelizer:
self._mode = mode self._mode = mode
self._completer = completer self._completer = completer
self._dist_context = dist_context self._dist_context = dist_context
self._dist_context.initialize() assert self._dist_context._is_initialized
self._pass_context = self._dist_context.pass_context self._pass_context = self._dist_context.pass_context
self._strategy = self._dist_context.strategy self._strategy = self._dist_context.strategy
...@@ -43,7 +43,9 @@ class Parallelizer: ...@@ -43,7 +43,9 @@ class Parallelizer:
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
all_ranks = world_process_group.ranks all_ranks = world_process_group.ranks
for rank in all_ranks: for rank in all_ranks:
# self._dist_context._backup(serial=True, dist=True)
self.parallel(rank) self.parallel(rank)
# self._dist_context._restore(serial=True, dist=True)
def parallel(self, rank): def parallel(self, rank):
serial_main_program = self._dist_context.serial_main_program serial_main_program = self._dist_context.serial_main_program
...@@ -58,6 +60,7 @@ class Parallelizer: ...@@ -58,6 +60,7 @@ class Parallelizer:
self._apply_pre_optimization(serial_main_program, self._apply_pre_optimization(serial_main_program,
serial_startup_program, serial_loss, serial_startup_program, serial_loss,
serial_optimizer, params_grads) serial_optimizer, params_grads)
# Do logical partition # Do logical partition
partitioner = Partitioner(self._dist_context, rank) partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
...@@ -85,7 +88,6 @@ class Parallelizer: ...@@ -85,7 +88,6 @@ class Parallelizer:
resharder = Resharder(dist_main_prog, dist_startup_prog, rank, resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1) self._dist_context, [], 1)
resharder.reshard() resharder.reshard()
# Clone program for test # Clone program for test
if self._mode != 'train': if self._mode != 'train':
dist_main_prog = dist_main_prog.clone(for_test=True) dist_main_prog = dist_main_prog.clone(for_test=True)
......
...@@ -16,6 +16,8 @@ from .completion import Completer ...@@ -16,6 +16,8 @@ from .completion import Completer
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .utils import print_program_with_dist_attr from .utils import print_program_with_dist_attr
# from .tuner.parallel_tuner import ParallelTuner
class Planner: class Planner:
def __init__(self, mode, dist_context): def __init__(self, mode, dist_context):
...@@ -24,19 +26,28 @@ class Planner: ...@@ -24,19 +26,28 @@ class Planner:
# NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need # NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need
# dependency of backward-forward ops in forward completion. # dependency of backward-forward ops in forward completion.
# TODO: The id mapping will be lost if we clone the original program.
default_ctx = get_default_distributed_context() default_ctx = get_default_distributed_context()
self._dist_context._dist_op_context = default_ctx.dist_op_context self._dist_context._dist_op_context = default_ctx.dist_op_context
self._dist_context.initialize() self._dist_context.initialize()
self._completer = Completer(self._dist_context) self._completer = Completer(self._dist_context)
self._strategy = dist_context.strategy
# if self._strategy.auto_search:
# self._parallel_tuner = ParallelTuner(
# self._dist_context, mode=self._mode)
@property @property
def completer(self): def completer(self):
return self._completer return self._completer
def plan(self): def plan(self):
self._completer.complete_forward_annotation() self._completer.complete_forward_annotation()
# if self._strategy.auto_search:
# self._parallel_tuner.tune()
# else:
# self._completer.complete_forward_annotation()
# parse forward sub block # parse forward sub block
self._dist_context.block_state.parse_forward_blocks( self._dist_context.block_state.parse_forward_blocks(
self._dist_context.serial_main_program) self._dist_context.serial_main_program)
# TODO: add the auto searcher
...@@ -324,10 +324,13 @@ def _get_corresponding_rank(dist_context, target_mesh, rank): ...@@ -324,10 +324,13 @@ def _get_corresponding_rank(dist_context, target_mesh, rank):
mesh.processes.index(rank)) mesh.processes.index(rank))
break break
assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( # assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
rank) # rank)
return target_mesh.processes[_coordinate2linear_idx(mesh.topology, if coordinate is not None:
coordinate)] return target_mesh.processes[_coordinate2linear_idx(mesh.topology,
coordinate)]
else:
return target_mesh.processes[0]
def _get_unshard_dist_shape(var, dist_attr): def _get_unshard_dist_shape(var, dist_attr):
......
...@@ -31,4 +31,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -31,4 +31,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS}) py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS}) py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS}) py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS})
py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
endif() endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import json
import paddle
import numpy as np
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
batch_size = 4
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]]
def get_random_inputs_and_labels(input_shape, label_shape):
input = np.random.random(size=input_shape).astype('float32')
label = np.random.random(size=label_shape).astype('float32')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_size):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, sequence_len, hidden_size],
[batch_size, sequence_len, 1])
yield batch_input, batch_label
return __reader__
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(
mean=0.0, std=initializer_range)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(
self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear1(out)
return out
def get_program():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
# input
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
feed_list=data_holder, capacity=4 * batch_size, iterable=False)
dataloader.set_batch_generator(
batch_generator_creator(), places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_start(input)
mlp_mid = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_mid(pred)
mlp_end = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_end(pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
loss = paddle.mean(error_cost)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
feed_vars = {"inputs": [input], "labels": [label]}
fetch_vars = {"loss": [loss]}
return train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars
class TestDistributedContext(unittest.TestCase):
def test_backup_restore(self):
train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program(
)
dist_context = DistributedContext(train_program, start_program,
optimizer, loss, feed_vars,
fetch_vars)
dist_context.initialize()
dist_context._backup(serial=True, dist=True)
dist_context._restore(
serial=True,
serial_mode="to_backup",
dist=True,
dist_mode="to_backup")
dist_context._backup(serial=True, dist=True)
dist_context._restore(
serial=True,
serial_mode="to_original",
dist=True,
dist_mode="to_original")
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, dist=True, dist_mode="to_default")
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, dist=True, dist_mode="to_nothing")
if __name__ == "__main__":
unittest.main()
...@@ -94,7 +94,8 @@ class TestDistSlice(unittest.TestCase): ...@@ -94,7 +94,8 @@ class TestDistSlice(unittest.TestCase):
ops = dist_main_prog.global_block().ops ops = dist_main_prog.global_block().ops
for op in ops: for op in ops:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "slice" # We amend this impl_type after completion
assert op_dist_attr.impl_type == "default"
for out in op.output_arg_names: for out in op.output_arg_names:
var_dims_mapping = op_dist_attr.get_output_dims_mapping(out) var_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
ref_dims_mapping = [-1 for i in range(len(var_dims_mapping))] ref_dims_mapping = [-1 for i in range(len(var_dims_mapping))]
......
...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner ...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl from paddle.distributed.auto_parallel.operators import find_compatible_distributed_operator_impls
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
......
...@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner ...@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl from paddle.distributed.auto_parallel.operators import find_compatible_distributed_operator_impls
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
......