未验证 提交 010aba33 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Add miscellaneous improvements (#43108)

* [Auto Parallel] Add the parallel tuner

* [Auto Parallel] Improve the parallel tuner and fix some bugs

* upodate cost model

* update import Resharder by dist op

* update cost model

* fix comp cost bug

* update cost model

* [Auto Parallel] Amend the dist attr for #processses=1

* update cost model and tuner

* update cost model and tuner

* update cost model and tuner

* update cluster

* update reshard

* [Auto Parallel] Add the estimation from the cost model

* [Auto Parallel] Reimplement the backup and restore functions

* [Auto Parallel] Fix the bugs of the parallel tuner

* [Auto Parallel] Update the engine api and dist context

* [Auto Parallel] Work around the high order grad problem

* [Auto Parallel] Add some miscellaneous improvements

* [Auto Parallel] Add a unittest for DistributedContext
Co-authored-by: Ncaozhou <caozhou@radi.ac.cn>
上级 5f2c251c
...@@ -20,7 +20,7 @@ from paddle.fluid import core ...@@ -20,7 +20,7 @@ from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from .utils import print_program_with_dist_attr from .utils import print_program_with_dist_attr
from .operators import find_best_compatible_distributed_operator_impl from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
...@@ -238,13 +238,17 @@ class Completer: ...@@ -238,13 +238,17 @@ class Completer:
tensor_desc.name()) tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping( compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping]) [op_dims_mapping, tensor_dims_mapping])
if not _validate_dims_mapping(
compatible_dims_mapping,
op_dist_attr.process_mesh):
continue
if (compatible_dims_mapping is not None) and \ if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping): (compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
tensor_desc.name(), compatible_dims_mapping) tensor_desc.name(), compatible_dims_mapping)
changed = True changed = True
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl( op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True) dist_op, fwd=True)
if op_dist_impls is not None: if op_dist_impls is not None:
not_compatible = True not_compatible = True
...@@ -254,7 +258,8 @@ class Completer: ...@@ -254,7 +258,8 @@ class Completer:
dim_changed = op_dist_impl.update_dims_mapping(dist_op) dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
if op_dist_impl.is_auto_compatible(dist_op): if op_dist_impl.is_auto_compatible(dist_op) \
and dist_op.validate_dist_attr():
if op_dist_impl.type == "elementwise": if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default" op_dist_attr.impl_type = "default"
else: else:
...@@ -289,13 +294,17 @@ class Completer: ...@@ -289,13 +294,17 @@ class Completer:
tensor_desc.name()) tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping( compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping]) [op_dims_mapping, tensor_dims_mapping])
if not _validate_dims_mapping(
compatible_dims_mapping,
op_dist_attr.process_mesh):
continue
if (compatible_dims_mapping is not None) and \ if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping): (compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
tensor_desc.name(), compatible_dims_mapping) tensor_desc.name(), compatible_dims_mapping)
changed = True changed = True
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl( op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=False) dist_op, fwd=False)
if op_dist_impls is not None: if op_dist_impls is not None:
not_compatible = True not_compatible = True
...@@ -305,8 +314,8 @@ class Completer: ...@@ -305,8 +314,8 @@ class Completer:
dim_changed = op_dist_impl.update_dims_mapping(dist_op) dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
if op_dist_impl.is_auto_compatible(dist_op): if op_dist_impl.is_auto_compatible(dist_op) \
not_compatible = False and dist_op.validate_dist_attr():
if op_dist_impl.type == "elementwise": if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default" op_dist_attr.impl_type = "default"
else: else:
...@@ -352,6 +361,23 @@ class Completer: ...@@ -352,6 +361,23 @@ class Completer:
changed = True changed = True
return changed return changed
def _update_dims_mapping_for_special(self):
# Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it
op_nodes = self._dist_context._serial_ordered_op_nodes
for op_node in op_nodes:
op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER:
continue
tensor_desc = tensor_node.var()
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
tensor_dist_attr.dims_mapping = op_dims_mapping
def _update_dims_mapping(self): def _update_dims_mapping(self):
# Complete dims_mapping for each node # Complete dims_mapping for each node
reach_fix_point = False reach_fix_point = False
...@@ -378,6 +404,7 @@ class Completer: ...@@ -378,6 +404,7 @@ class Completer:
reach_fix_point = False reach_fix_point = False
else: else:
reach_fix_point = True reach_fix_point = True
self._update_dims_mapping_for_special()
def _update_process_mesh_by_nearest(self, op_node, nearest_op_node): def _update_process_mesh_by_nearest(self, op_node, nearest_op_node):
op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node) op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
...@@ -685,7 +712,7 @@ class Completer: ...@@ -685,7 +712,7 @@ class Completer:
# Step 3: adjust the process meshes for special ops # Step 3: adjust the process meshes for special ops
self._update_process_mesh_for_specials() self._update_process_mesh_for_specials()
# Step 4: adjust the process meshes between graphs # Step 4: adjust the process meshes between graphs
self._update_process_mesh_between_graphs() self._update_process_mesh_between_graphs()
def _prepare(self): def _prepare(self):
...@@ -727,14 +754,14 @@ class Completer: ...@@ -727,14 +754,14 @@ class Completer:
""" Complete annotation for the partial annotated serial_main_program. """ Complete annotation for the partial annotated serial_main_program.
Arguments: Arguments:
serial_main_program: partial annotated serial_main_program. serial_main_program: partial annotated serial_main_program.
Returns: Returns:e
serial_main_program: completed annotated serial_main_program. serial_main_program: completed annotated serial_main_program.
""" """
if serial_main_program is None: if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program serial_main_program = self._dist_context.serial_main_program
else: else:
self._dist_context.serial_main_program = serial_main_program self._dist_context._serial_main_program = serial_main_program
self._dist_context.initialize() self._dist_context.initialize()
...@@ -757,13 +784,18 @@ class Completer: ...@@ -757,13 +784,18 @@ class Completer:
return serial_main_program return serial_main_program
def _complete_high_order_grad_annotation(self, serial_main_program): def _complete_high_order_grad_annotation(self, serial_main_program=None):
""" """
NOTE: NOTE:
[HighOrderGrad] Complete the annotation of vars and ops only for high order gradient. [HighOrderGrad] Complete the annotation of vars and ops only for high order gradient.
This function is temporary to support high order gradient, and will be removed in the future. This function is temporary to support high order gradient, and will be removed in the future.
""" """
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context._serial_main_program = serial_main_program
def _is_grad_var_name(name): def _is_grad_var_name(name):
if "@GRAD" in name: if "@GRAD" in name:
return True return True
...@@ -917,12 +949,13 @@ class Completer: ...@@ -917,12 +949,13 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
def complete_backward_annotation(self, serial_main_program): def complete_backward_annotation(self, serial_main_program=None):
"""Complete the annotation of vars and ops in the backward phase for parallel program.""" """Complete the annotation of vars and ops in the backward phase for parallel program."""
if serial_main_program is None: if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program serial_main_program = self._dist_context.serial_main_program
else: else:
self._dist_context.serial_main_program = serial_main_program self._dist_context._serial_main_program = serial_main_program
def _is_grad_var_name(name): def _is_grad_var_name(name):
if "@GRAD" in name: if "@GRAD" in name:
...@@ -1032,6 +1065,9 @@ class Completer: ...@@ -1032,6 +1065,9 @@ class Completer:
grad_op_dist_attr.process_mesh = ref_mesh grad_op_dist_attr.process_mesh = ref_mesh
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
continue continue
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
...@@ -1078,6 +1114,8 @@ class Completer: ...@@ -1078,6 +1114,8 @@ class Completer:
grad_op_dist_attr.set_output_dims_mapping(output_name, grad_op_dist_attr.set_output_dims_mapping(output_name,
ref_dims_mapping) ref_dims_mapping)
grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
...@@ -1111,6 +1149,8 @@ class Completer: ...@@ -1111,6 +1149,8 @@ class Completer:
var_name, ref_fwd_dims_mapping) var_name, ref_fwd_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping( grad_op_dist_attr.set_output_dims_mapping(
output_name, ref_fwd_dims_mapping) output_name, ref_fwd_dims_mapping)
grad_op_dist_attr.impl_type = "default"
grad_op_dist_attr.impl_idx = 0
elif grad_op.type == 'fill_zeros_like': elif grad_op.type == 'fill_zeros_like':
ref_var_name = grad_op.input_arg_names[0] ref_var_name = grad_op.input_arg_names[0]
...@@ -1142,12 +1182,13 @@ class Completer: ...@@ -1142,12 +1182,13 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
def complete_update_annotation(self, serial_main_program=None): def complete_update_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the update phase for parallel program.""" """Complete the annotation of vars and ops in the update phase for parallel program."""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program # Notice: serial_main_program is actually a dist_main_program of current rank,
else: # and must be passed into this function.
self._dist_context.serial_main_program = serial_main_program # TODO: We should fix this behavior.
ops = list(serial_main_program.global_block().ops) ops = list(serial_main_program.global_block().ops)
vars = serial_main_program.global_block().vars vars = serial_main_program.global_block().vars
learning_rate_completed = False learning_rate_completed = False
...@@ -1304,7 +1345,7 @@ class Completer: ...@@ -1304,7 +1345,7 @@ class Completer:
dist_op.dist_attr.process_mesh = world_ranks dist_op.dist_attr.process_mesh = world_ranks
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl( op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True) dist_op, fwd=True)
if op_dist_impls is not None: if op_dist_impls is not None:
backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr) backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
......
...@@ -132,15 +132,17 @@ class TensorDistributedAttribute: ...@@ -132,15 +132,17 @@ class TensorDistributedAttribute:
key, dist_attr) key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated) self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
# def reset(self, skip_dist_attr_field_names): def reset(self, skip_dist_attr_field_names=None):
# if skip_dist_attr_field_names is not None \ if skip_dist_attr_field_names is None or \
# and "process_mesh" not in skip_dist_attr_field_names: (skip_dist_attr_field_names is not None \
# self._process_mesh = None and "process_mesh" not in skip_dist_attr_field_names):
# if skip_dist_attr_field_names is not None \ self._process_mesh = None
# and "dims_mapping" not in skip_dist_attr_field_names: if skip_dist_attr_field_names is None or \
# for i in enumerate(self._dims_mapping): (skip_dist_attr_field_names is not None \
# self._dims_mapping[i] = -1 and "dims_mapping" not in skip_dist_attr_field_names):
# self._is_annotated = {} for i, _ in enumerate(self._dims_mapping):
self._dims_mapping[i] = -1
self._is_annotated = {}
def is_annotated(self, dist_attr_field_name): def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False) return self._is_annotated.get(dist_attr_field_name, False)
...@@ -272,6 +274,9 @@ class OperatorDistributedAttribute: ...@@ -272,6 +274,9 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr) dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object self._inputs_dist_attrs[name] = dist_attr_object
# def del_input_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def get_output_dist_attr(self, name): def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None) return self._outputs_dist_attrs.get(name, None)
...@@ -280,6 +285,9 @@ class OperatorDistributedAttribute: ...@@ -280,6 +285,9 @@ class OperatorDistributedAttribute:
dist_attr_object.init(dist_attr) dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object self._outputs_dist_attrs[name] = dist_attr_object
# def del_output_dist_attr(self, name):
# del self._inputs_dist_attrs[name]
def get_input_dims_mapping(self, name): def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name) input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr: if input_dist_attr:
...@@ -374,17 +382,18 @@ class OperatorDistributedAttribute: ...@@ -374,17 +382,18 @@ class OperatorDistributedAttribute:
"ProcessMeshes in DistributedOperator must be the same." "ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh self.process_mesh = shared_process_mesh
# def reset(self, skip_dist_attr_field_names): def reset(self, skip_dist_attr_field_names=None):
# for tensor_dist_attr in self.inputs_dist_attrs.values(): for tensor_dist_attr in self.inputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names) tensor_dist_attr.reset(skip_dist_attr_field_names)
# for tensor_dist_attr in self.outputs_dist_attrs.values(): for tensor_dist_attr in self.outputs_dist_attrs.values():
# tensor_dist_attr.reset(skip_dist_attr_field_names) tensor_dist_attr.reset(skip_dist_attr_field_names)
# if skip_dist_attr_field_names is not None \ if skip_dist_attr_field_names is None or \
# and "process_mesh" not in skip_dist_attr_field_names: (skip_dist_attr_field_names is not None \
# self.process_mesh = None and "process_mesh" not in skip_dist_attr_field_names):
# self.impl_type = "default" self._process_mesh = None
# self.impl_idx = 0 self.impl_type = "default"
# self._is_annotated = {} self.impl_idx = 0
self._is_annotated = {}
def is_annotated(self, attr_name): def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False) return self._is_annotated.get(attr_name, False)
......
...@@ -41,7 +41,7 @@ class DistributedTensor: ...@@ -41,7 +41,7 @@ class DistributedTensor:
rank=None, rank=None,
shard_sizes=None): shard_sizes=None):
if not (isinstance(sizes, (list, tuple)) and if not (isinstance(sizes, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x > 0, sizes))): all(map(lambda x: isinstance(x, int) and x >= 0, sizes))):
raise ValueError( raise ValueError(
"The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}". "The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}".
format(sizes)) format(sizes))
...@@ -79,8 +79,11 @@ class DistributedTensor: ...@@ -79,8 +79,11 @@ class DistributedTensor:
local_sizes = [] local_sizes = []
# for even sharding, the local sizes of every rank are equal # for even sharding, the local sizes of every rank are equal
for idx, item in enumerate(global_sizes): for idx, item in enumerate(global_sizes):
if dims_mapping[idx] == -1: # This is a trick to avoid dims_mapping is []
val = dims_mapping[idx] if idx < len(dims_mapping) else -1
if val == -1:
local_sizes.append(item) local_sizes.append(item)
else: else:
local_sizes.append(item // topology[dims_mapping[idx]]) local_sizes.append(item // topology[dims_mapping[idx]])
......
...@@ -31,10 +31,11 @@ from paddle.fluid.backward import append_backward ...@@ -31,10 +31,11 @@ from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator from paddle.fluid.framework import Operator
from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
from paddle.distributed.utils import get_logger from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext from paddle.distributed.passes import new_pass, PassContext
from .cluster import Cluster # from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
...@@ -57,7 +58,11 @@ class Engine: ...@@ -57,7 +58,11 @@ class Engine:
self.inputs_spec = self._validate_spec(inputs_spec) self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec) self.labels_spec = self._validate_spec(labels_spec)
self.cluster = cluster self.cluster = cluster
# if self.cluster is None:
# self.cluster = get_default_cluster()
self.strategy = strategy self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
self._executor = None self._executor = None
self._cur_rank = paddle.distributed.get_rank() self._cur_rank = paddle.distributed.get_rank()
...@@ -69,11 +74,11 @@ class Engine: ...@@ -69,11 +74,11 @@ class Engine:
self._orig_main_prog = fluid.default_main_program() self._orig_main_prog = fluid.default_main_program()
self._orig_startup_prog = fluid.default_startup_program() self._orig_startup_prog = fluid.default_startup_program()
self._orig_dist_context = get_default_distributed_context() self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {}
self._serial_main_progs = {} self._serial_main_progs = {}
self._serial_startup_progs = {} self._serial_startup_progs = {}
self._dist_main_progs = defaultdict(dict) # dist main programs self._dist_main_progs = defaultdict(dict) # dist main programs
self._dist_startup_progs = defaultdict(dict) # dist startup programs self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._dist_contexts = {}
self._feed_vars = {} self._feed_vars = {}
self._fetch_vars = {} self._fetch_vars = {}
...@@ -104,11 +109,17 @@ class Engine: ...@@ -104,11 +109,17 @@ class Engine:
parallelizer.parallel(self._cur_rank) parallelizer.parallel(self._cur_rank)
else: else:
parallelizer.parallel_all() parallelizer.parallel_all()
# Get the distributed main programs and startup programs # Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[
mode].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[
mode].serial_startup_program
self._dist_main_progs[mode] = self._dist_contexts[ self._dist_main_progs[mode] = self._dist_contexts[
mode].dist_main_programs mode].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[ self._dist_startup_progs[mode] = self._dist_contexts[
mode].dist_startup_programs mode].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
# Init comm and startup program # Init comm and startup program
self._initialize(mode) self._initialize(mode)
...@@ -135,20 +146,23 @@ class Engine: ...@@ -135,20 +146,23 @@ class Engine:
inputs = [self._set_data_parallel(var) for var in inputs] inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels] labels = [self._set_data_parallel(var) for var in labels]
self._feed_vars[mode] = {"inputs": inputs, "labels": labels} # self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": inputs, "labels": labels}
self._fetch_vars[mode] = { # self._fetch_vars[mode] = {
# "outputs": flatten(outputs),
# "loss": losses,
# "metrics": metrics
# }
fetch_vars = {
"outputs": flatten(outputs), "outputs": flatten(outputs),
"loss": losses, "loss": losses,
"metrics": metrics "metrics": metrics
} }
self._serial_main_progs[mode] = serial_main_prog
self._serial_startup_progs[mode] = serial_startup_prog
self._dist_contexts[mode] = DistributedContext( self._dist_contexts[mode] = DistributedContext(
self._serial_main_progs[mode], self._serial_startup_progs[mode], serial_main_prog, serial_startup_prog, self._optimizer, losses,
self._optimizer, losses, self._feed_vars[mode], feed_vars, fetch_vars, self.cluster, self.strategy)
self._fetch_vars[mode], self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _initialize(self, mode): def _initialize(self, mode):
......
...@@ -16,7 +16,7 @@ from .common import DistributedOperatorImplContainer ...@@ -16,7 +16,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl from .common import find_compatible_distributed_operator_impls
from . import dist_embedding from . import dist_embedding
from . import dist_matmul from . import dist_matmul
from . import dist_reshape from . import dist_reshape
......
...@@ -157,9 +157,7 @@ def register_distributed_operator_impl(op_type, dist_impl): ...@@ -157,9 +157,7 @@ def register_distributed_operator_impl(op_type, dist_impl):
assert False, "Must register distributed operator registry first." assert False, "Must register distributed operator registry first."
def find_best_compatible_distributed_operator_impl(dist_op, def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
fwd=True,
partial=True):
""" """
Here just return the first compatible implemention. Here just return the first compatible implemention.
This will be improved by cost model in the future. This will be improved by cost model in the future.
......
...@@ -187,7 +187,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -187,7 +187,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name) serial_tensor = dist_op.get_serial_input(arg_name)
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if serial_tensor.is_parameter: if serial_tensor is not None and serial_tensor.is_parameter:
for mapping in dims_mapping: for mapping in dims_mapping:
if mapping != -1: if mapping != -1:
return False return False
...@@ -217,7 +217,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -217,7 +217,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if serial_tensor.is_parameter: if serial_tensor is not None and serial_tensor.is_parameter:
for mapping in dims_mapping: for mapping in dims_mapping:
if mapping != -1: if mapping != -1:
return False return False
......
...@@ -22,7 +22,6 @@ from .common import register_distributed_operator_impl_container ...@@ -22,7 +22,6 @@ from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program from .common import set_comm_op_dist_attr_for_program
from .dist_default import DistributedDefaultImpl0 from .dist_default import DistributedDefaultImpl0
from ..reshard import Resharder
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import is_dim_shard, is_dim_replicate, _get_corresponding_rank from ..utils import is_dim_shard, is_dim_replicate, _get_corresponding_rank
from ..utils import compute_compatible_dim_mapping, set_dist_op_desc_original_id, _get_comm_group from ..utils import compute_compatible_dim_mapping, set_dist_op_desc_original_id, _get_comm_group
...@@ -324,6 +323,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -324,6 +323,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.processes
dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)] dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)]
from ..reshard import Resharder
partition_idx = Resharder.compute_partition_index( partition_idx = Resharder.compute_partition_index(
rank_id, new_X_grad.shape, dims_mapping, process_mesh_shape, rank_id, new_X_grad.shape, dims_mapping, process_mesh_shape,
process_mesh_group) process_mesh_group)
......
...@@ -35,7 +35,7 @@ class Parallelizer: ...@@ -35,7 +35,7 @@ class Parallelizer:
self._mode = mode self._mode = mode
self._completer = completer self._completer = completer
self._dist_context = dist_context self._dist_context = dist_context
self._dist_context.initialize() assert self._dist_context._is_initialized
self._pass_context = self._dist_context.pass_context self._pass_context = self._dist_context.pass_context
self._strategy = self._dist_context.strategy self._strategy = self._dist_context.strategy
...@@ -43,7 +43,9 @@ class Parallelizer: ...@@ -43,7 +43,9 @@ class Parallelizer:
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
all_ranks = world_process_group.ranks all_ranks = world_process_group.ranks
for rank in all_ranks: for rank in all_ranks:
# self._dist_context._backup(serial=True, dist=True)
self.parallel(rank) self.parallel(rank)
# self._dist_context._restore(serial=True, dist=True)
def parallel(self, rank): def parallel(self, rank):
serial_main_program = self._dist_context.serial_main_program serial_main_program = self._dist_context.serial_main_program
...@@ -58,6 +60,7 @@ class Parallelizer: ...@@ -58,6 +60,7 @@ class Parallelizer:
self._apply_pre_optimization(serial_main_program, self._apply_pre_optimization(serial_main_program,
serial_startup_program, serial_loss, serial_startup_program, serial_loss,
serial_optimizer, params_grads) serial_optimizer, params_grads)
# Do logical partition # Do logical partition
partitioner = Partitioner(self._dist_context, rank) partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
...@@ -85,7 +88,6 @@ class Parallelizer: ...@@ -85,7 +88,6 @@ class Parallelizer:
resharder = Resharder(dist_main_prog, dist_startup_prog, rank, resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1) self._dist_context, [], 1)
resharder.reshard() resharder.reshard()
# Clone program for test # Clone program for test
if self._mode != 'train': if self._mode != 'train':
dist_main_prog = dist_main_prog.clone(for_test=True) dist_main_prog = dist_main_prog.clone(for_test=True)
......
...@@ -16,6 +16,8 @@ from .completion import Completer ...@@ -16,6 +16,8 @@ from .completion import Completer
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .utils import print_program_with_dist_attr from .utils import print_program_with_dist_attr
# from .tuner.parallel_tuner import ParallelTuner
class Planner: class Planner:
def __init__(self, mode, dist_context): def __init__(self, mode, dist_context):
...@@ -24,19 +26,28 @@ class Planner: ...@@ -24,19 +26,28 @@ class Planner:
# NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need # NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need
# dependency of backward-forward ops in forward completion. # dependency of backward-forward ops in forward completion.
# TODO: The id mapping will be lost if we clone the original program.
default_ctx = get_default_distributed_context() default_ctx = get_default_distributed_context()
self._dist_context._dist_op_context = default_ctx.dist_op_context self._dist_context._dist_op_context = default_ctx.dist_op_context
self._dist_context.initialize() self._dist_context.initialize()
self._completer = Completer(self._dist_context) self._completer = Completer(self._dist_context)
self._strategy = dist_context.strategy
# if self._strategy.auto_search:
# self._parallel_tuner = ParallelTuner(
# self._dist_context, mode=self._mode)
@property @property
def completer(self): def completer(self):
return self._completer return self._completer
def plan(self): def plan(self):
self._completer.complete_forward_annotation() self._completer.complete_forward_annotation()
# if self._strategy.auto_search:
# self._parallel_tuner.tune()
# else:
# self._completer.complete_forward_annotation()
# parse forward sub block # parse forward sub block
self._dist_context.block_state.parse_forward_blocks( self._dist_context.block_state.parse_forward_blocks(
self._dist_context.serial_main_program) self._dist_context.serial_main_program)
# TODO: add the auto searcher
...@@ -324,10 +324,13 @@ def _get_corresponding_rank(dist_context, target_mesh, rank): ...@@ -324,10 +324,13 @@ def _get_corresponding_rank(dist_context, target_mesh, rank):
mesh.processes.index(rank)) mesh.processes.index(rank))
break break
assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( # assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
rank) # rank)
return target_mesh.processes[_coordinate2linear_idx(mesh.topology, if coordinate is not None:
coordinate)] return target_mesh.processes[_coordinate2linear_idx(mesh.topology,
coordinate)]
else:
return target_mesh.processes[0]
def _get_unshard_dist_shape(var, dist_attr): def _get_unshard_dist_shape(var, dist_attr):
......
...@@ -31,4 +31,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -31,4 +31,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS}) py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS}) py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS}) py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS})
py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
endif() endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import json
import paddle
import numpy as np
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
batch_size = 4
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]]
def get_random_inputs_and_labels(input_shape, label_shape):
input = np.random.random(size=input_shape).astype('float32')
label = np.random.random(size=label_shape).astype('float32')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_size):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, sequence_len, hidden_size],
[batch_size, sequence_len, 1])
yield batch_input, batch_label
return __reader__
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(
mean=0.0, std=initializer_range)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(
self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear1(out)
return out
def get_program():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
# input
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
feed_list=data_holder, capacity=4 * batch_size, iterable=False)
dataloader.set_batch_generator(
batch_generator_creator(), places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_start(input)
mlp_mid = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_mid(pred)
mlp_end = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_end(pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
loss = paddle.mean(error_cost)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
feed_vars = {"inputs": [input], "labels": [label]}
fetch_vars = {"loss": [loss]}
return train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars
class TestDistributedContext(unittest.TestCase):
def test_backup_restore(self):
train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program(
)
dist_context = DistributedContext(train_program, start_program,
optimizer, loss, feed_vars,
fetch_vars)
dist_context.initialize()
dist_context._backup(serial=True, dist=True)
dist_context._restore(
serial=True,
serial_mode="to_backup",
dist=True,
dist_mode="to_backup")
dist_context._backup(serial=True, dist=True)
dist_context._restore(
serial=True,
serial_mode="to_original",
dist=True,
dist_mode="to_original")
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, dist=True, dist_mode="to_default")
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, dist=True, dist_mode="to_nothing")
if __name__ == "__main__":
unittest.main()
...@@ -94,7 +94,8 @@ class TestDistSlice(unittest.TestCase): ...@@ -94,7 +94,8 @@ class TestDistSlice(unittest.TestCase):
ops = dist_main_prog.global_block().ops ops = dist_main_prog.global_block().ops
for op in ops: for op in ops:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "slice" # We amend this impl_type after completion
assert op_dist_attr.impl_type == "default"
for out in op.output_arg_names: for out in op.output_arg_names:
var_dims_mapping = op_dist_attr.get_output_dims_mapping(out) var_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
ref_dims_mapping = [-1 for i in range(len(var_dims_mapping))] ref_dims_mapping = [-1 for i in range(len(var_dims_mapping))]
......
...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner ...@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl from paddle.distributed.auto_parallel.operators import find_compatible_distributed_operator_impls
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
......
...@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner ...@@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl from paddle.distributed.auto_parallel.operators import find_compatible_distributed_operator_impls
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册