From c4fdb057dc8267cf30798b2fd73b10360c1ed93f Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 24 Dec 2021 19:10:57 +0800 Subject: [PATCH] [Auto Paralle] partitioner refactor (#37853) --- .../distributed/auto_parallel/dist_context.py | 26 +- .../auto_parallel/operators/common.py | 45 + .../auto_parallel/operators/dist_default.py | 26 +- .../auto_parallel/operators/dist_embedding.py | 85 +- .../auto_parallel/operators/dist_matmul.py | 152 +++- .../auto_parallel/operators/dist_reshape.py | 4 +- .../auto_parallel/operators/dist_softmax.py | 2 +- .../auto_parallel/operators/dist_transpose.py | 2 +- .../distributed/auto_parallel/parallelizer.py | 130 ++- .../distributed/auto_parallel/partitioner.py | 781 ++++-------------- .../distributed/auto_parallel/planner.py | 21 +- .../paddle/distributed/auto_parallel/utils.py | 82 +- .../test_auto_parallel_cost_model.py | 37 +- .../unittests/test_auto_parallel_mapper.py | 33 +- .../test_auto_parallel_partitioner.py | 6 +- .../test_auto_parallel_partitioner_gpt.py | 68 +- .../unittests/test_auto_parallel_reshard.py | 37 +- .../test_auto_parallel_reshard_dpmppp.py | 37 +- .../test_auto_parallel_reshard_mppp.py | 46 +- 19 files changed, 805 insertions(+), 815 deletions(-) mode change 100755 => 100644 python/paddle/distributed/auto_parallel/dist_context.py mode change 100755 => 100644 python/paddle/distributed/auto_parallel/operators/dist_default.py mode change 100755 => 100644 python/paddle/distributed/auto_parallel/operators/dist_embedding.py mode change 100755 => 100644 python/paddle/distributed/auto_parallel/partitioner.py mode change 100644 => 100755 python/paddle/distributed/auto_parallel/planner.py mode change 100755 => 100644 python/paddle/distributed/auto_parallel/utils.py mode change 100755 => 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py mode change 100755 => 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py old mode 100755 new mode 100644 index 347d02dacf4..3ec63fa116c --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -404,7 +404,7 @@ class DistributedOperatorContext: def get_cur_src_op(self): return self._cur_src_op - def prepare_forward_context(self, src_op): + def prepare_context(self, src_op): self.set_cur_src_op(src_op) @@ -413,6 +413,7 @@ class DistributedOperatorContext: for input_name in src_op.desc.input_names(): varnames = [] for varname in src_op.desc.input(input_name): + assert varname in self._varname_mapping varnames.append(self._varname_mapping[varname]) kinputs[input_name] = varnames @@ -421,29 +422,8 @@ class DistributedOperatorContext: for output_name in src_op.desc.output_names(): varnames = [] for varname in src_op.desc.output(output_name): + assert varname in self._varname_mapping varnames.append(self._varname_mapping[varname]) koutputs[output_name] = varnames return kinputs, koutputs - - def prepare_backward_context(self, backward_op): - - self.set_cur_src_op(backward_op) - - # build input varname mapping - kinputs = {} - for input_name in backward_op.desc.input_names(): - varnames = [] - for varname in backward_op.desc.input(input_name): - varnames.append(varname) - kinputs[input_name] = varnames - - # build output varname mapping - koutputs = {} - for output_name in backward_op.desc.output_names(): - varnames = [] - for varname in 
backward_op.desc.output(output_name): - varnames.append(varname) - koutputs[output_name] = varnames - - return kinputs, koutputs diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 3ebda4694c6..0e0b2eae9c7 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License +from ..dist_attribute import OperatorDistributedAttribute + _g_distributed_operator_impl_registries = {} @@ -138,3 +140,46 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): exact_shape.append(new_shape) return exact_shape + + +def set_comm_op_dist_attr_for_program(new_op, process_mesh, tensor_dist_attr, + ctx): + assert process_mesh is not None + assert tensor_dist_attr is not None + + new_op_dist_attr = OperatorDistributedAttribute() + new_op_dist_attr.process_mesh = process_mesh + for input_varname in new_op.desc.input_arg_names(): + new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr) + for output_varname in new_op.desc.output_arg_names(): + new_op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr) + ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) + + +def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx): + + ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op) + new_op_dist_attr = OperatorDistributedAttribute() + new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh + + for input_name in ref_op.input_names: + assert input_name in new_op.input_names + assert len(ref_op.input(input_name)) == 1 + assert len(new_op.input(input_name)) == 1 + + ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr( + ref_op.input(input_name)[0]) + new_op_dist_attr.set_input_dist_attr( + new_op.input(input_name)[0], ref_tensor_dist_attr) + + for output_name in ref_op.output_names: + assert output_name in new_op.output_names + assert len(ref_op.output(output_name)) == 1 + assert len(new_op.output(output_name)) == 1 + + ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr( + ref_op.output(output_name)[0]) + new_op_dist_attr.set_output_dist_attr( + new_op.output(output_name)[0], ref_tensor_dist_attr) + + ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py old mode 100755 new mode 100644 index 05af1b402b4..72e750e5a43 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -66,7 +66,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): main_block = dist_op_context.get_dst_main_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block() src_op = dist_op_context.get_cur_src_op() - varname_mapping = dist_op_context.get_varname_mapping() rank_id = dist_op_context.get_rank_id() # check validation of inputs / outputs @@ -153,6 +152,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): str(backward_op)) rank_id = dist_op_context.get_rank_id() + # check validation of inputs / outputs + for input_name in backward_op.desc.input_names(): + assert input_name in kwargs, "input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + backward_op.desc.input(input_name) + ), 
"number of tensor for input [{}] is not match".format(input_name) + for output_name in backward_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not given".format( + output_name) + assert len(kwargs[output_name]) == len( + backward_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + # replicate op in dist program + dist_op_desc = main_block.desc.append_op() + dist_op_desc.copy_from(backward_op.desc) + for input_name in backward_op.desc.input_names(): + dist_op_desc.set_input(input_name, kwargs[input_name]) + for output_name in backward_op.desc.output_names(): + dist_op_desc.set_output(output_name, kwargs[output_name]) + + main_block._sync_with_cpp() + # check if need gradient allreduce # if there is a non-gradient & non-parameter input and its batch dimension is splited, # we need insert gradient allreduce for the gradient of parameter in its output diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py old mode 100755 new mode 100644 index 20722cdf605..18d976e965a --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -16,14 +16,14 @@ from .common import infer_shape from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container -from .common import register_distributed_operator_impl +from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping -from ..dist_attribute import OperatorDistributedAttribute +from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.fluid import core, unique_name from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard @@ -329,9 +329,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id) - # check if need gradient allreduce - need_gradient_allreduce = False - assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out') @@ -355,6 +352,84 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): kwargs['W@GRAD']) Ids_var = main_block.var(kwargs['Ids'][0]) + Weight_var = main_block.var(kwargs['W'][0]) + Out_grad = main_block.var(kwargs['Out@GRAD'][0]) + Weight_grad = main_block.var(kwargs['W@GRAD'][0]) + + embedding_row_dim_mapping = dist_attr.get_input_dims_mapping( + Weight_var.name)[0] + assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( + embedding_row_dim_mapping) + process_mesh_shape = dist_attr.process_mesh.topology + process_mesh_group = dist_attr.process_mesh.processes + + # A generalized method to caculate embedding offset using cartisian product + relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape, + embedding_row_dim_mapping, rank_id) + 
per_part_size = Weight_var.shape[0] + relative_idx = relative_idx * per_part_size + + check_variable_and_dtype( + Out_grad, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') + + intermediate_var_0 = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_embedding", '@tmp_0@GRAD'])), + dtype=Out_grad.dtype, + shape=Out_grad.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=Out_grad.stop_gradient) + + # copy X_var's dist_attr to intermediate_var_0's dist_attr + out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name) + assert out_grad_dist_attr is not None + ctx.set_tensor_dist_attr_for_program(intermediate_var_0, + out_grad_dist_attr) + + group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, + embedding_row_dim_mapping, rank_id) + group = new_process_group(group_ranks) + + c_identity_op = main_block.append_op( + type='c_identity', + inputs={'X': [Out_grad]}, + outputs={'Out': intermediate_var_0}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + OP_ROLE_KEY: OpRole.Backward, + }) + check_variable_and_dtype(intermediate_var_0, 'x', + ['float16', 'float32', 'float64'], 'linear') + check_dtype(intermediate_var_0.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + + set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh, + out_grad_dist_attr, ctx) + + main_block._sync_with_cpp() + c_embedding_grad_op_desc = main_block.desc.append_op() + c_embedding_grad_op_desc.set_type("c_embedding_grad") + c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name]) + c_embedding_grad_op_desc.set_input('W', [Weight_var.name]) + c_embedding_grad_op_desc.set_input('Out@GRAD', + [intermediate_var_0.name]) + c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name]) + c_embedding_grad_op_desc._set_attr('start_index', relative_idx) + c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward) + main_block._sync_with_cpp() + + c_embedding_grad_op = main_block.ops[-1] + assert c_embedding_grad_op.type == "c_embedding_grad" + naive_copy_op_dist_attr_for_program(c_embedding_grad_op, backward_op, + ctx) + + # check if need gradient allreduce + need_gradient_allreduce = False + process_mesh = dist_attr.process_mesh var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name) mesh_shape = process_mesh.topology diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 3a4d8412bf8..aeaf9eb76b1 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License +import copy from .common import infer_shape from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl +from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -33,6 +35,20 @@ from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank +def copy_op_with_new_input_output(block, src_op, **kwargs): + dist_op_desc = 
block.desc.append_op() + dist_op_desc.copy_from(src_op.desc) + for input_name in src_op.desc.input_names(): + assert input_name in kwargs + dist_op_desc.set_input(input_name, kwargs[input_name]) + for output_name in src_op.desc.output_names(): + assert input_name in kwargs + dist_op_desc.set_output(output_name, kwargs[output_name]) + + block._sync_with_cpp() + return dist_op_desc + + def _update_dims_mapping_for_matmul(dist_op): changed = False op_desc = dist_op.serial_op.desc @@ -141,15 +157,11 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): if rank_id not in dist_attr.process_mesh.processes: rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id) - # check if need gradient allreduce - need_gradient_allreduce = False - assert 'Y' in kwargs, "input [{}] is not given".format('Y') assert 'X' in kwargs, "input [{}] is not given".format('X') assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD') assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD') assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD') - assert len( kwargs['Y'] ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( @@ -166,15 +178,138 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): kwargs['Y@GRAD'] ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( kwargs['Y@GRAD']) - assert len( - kwargs['X@GRAD'] - ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( - kwargs['X@GRAD']) X_var = main_block.var(kwargs['X'][0]) + Y_var = main_block.var(kwargs['Y'][0]) + Out_grad = main_block.var(kwargs['Out@GRAD'][0]) + Y_grad = main_block.var(kwargs['Y@GRAD'][0]) + assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format( X_var.name) + Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) + process_mesh_shape = dist_attr.process_mesh.topology + process_mesh_group = dist_attr.process_mesh.processes + assert len( + Y_var_dim_mapping + ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format( + Y_var.name, Y_var_dim_mapping) + Y_var_partitioned = False + for dim in Y_var_dim_mapping: + if dim >= 0 and process_mesh_shape[dim] > 0: + Y_var_partitioned = True + break + + if Y_var.is_parameter and Y_var_partitioned: + + if Y_var_dim_mapping[0] >= 0: + # row parallel: c_identity + matmul + assert Y_var_dim_mapping[1] < 0 + parallel_axis = Y_var_dim_mapping[0] + + check_variable_and_dtype( + Out_grad, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_identity') + + intermediate_var_0 = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_identity", 'tmp'])) + "@GRAD", + dtype=Out_grad.dtype, + shape=Out_grad.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=Out_grad.stop_gradient) + + # copy X_var's dist_attr to intermediate_var_0's dist_attr + out_grad_dist_attr = dist_attr.get_input_dist_attr(Out_grad.name) + assert out_grad_dist_attr is not None + ctx.set_tensor_dist_attr_for_program(intermediate_var_0, + out_grad_dist_attr) + + group_ranks = _get_comm_group( + process_mesh_group, process_mesh_shape, parallel_axis, rank_id) + group = new_process_group(group_ranks) + c_identity_op = main_block.append_op( + type='c_identity', + inputs={'X': [Out_grad]}, + outputs={'Out': intermediate_var_0}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + 
OP_ROLE_KEY: OpRole.Backward, + }) + check_variable_and_dtype(intermediate_var_0, 'x', + ['float16', 'float32', 'float64'], + 'linear') + check_dtype(intermediate_var_0.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + set_comm_op_dist_attr_for_program( + c_identity_op, dist_attr.process_mesh, out_grad_dist_attr, ctx) + + new_kwargs = copy.deepcopy(kwargs) + new_kwargs['Out@GRAD'] = [intermediate_var_0.name] + matmul_op_desc = copy_op_with_new_input_output( + main_block, backward_op, **new_kwargs) + else: + # col parallel: matmul + allreduce + assert Y_var_dim_mapping[0] < 0 + parallel_axis = Y_var_dim_mapping[1] + new_kwargs = copy.deepcopy(kwargs) + + # NOTE (JZ-LIANG) should allow left operand be empty for matmul grad + has_x_grad = len(kwargs['X@GRAD']) > 0 + if has_x_grad: + assert len(kwargs['X@GRAD']) == 1 + X_grad = main_block.var(kwargs['X@GRAD'][0]) + intermediate_var_0 = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_identity", 'tmp'])) + "@GRAD", + dtype=X_grad.dtype, + shape=X_grad.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=X_grad.stop_gradient) + + X_grad_dist_attr = dist_attr.get_output_dist_attr(X_grad.name) + assert X_grad_dist_attr is not None + ctx.set_tensor_dist_attr_for_program(intermediate_var_0, + X_grad_dist_attr) + new_kwargs['X@GRAD'] = [intermediate_var_0.name] + + matmul_op_desc = copy_op_with_new_input_output( + main_block, backward_op, **new_kwargs) + + # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad + if has_x_grad: + group_ranks = _get_comm_group(process_mesh_group, + process_mesh_shape, parallel_axis, + rank_id) + group = new_process_group(group_ranks) + c_allreduce_sum_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': [intermediate_var_0.name]}, + outputs={'Out': kwargs['X@GRAD']}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + OP_ROLE_KEY: OpRole.Backward + }) + set_comm_op_dist_attr_for_program(c_allreduce_sum_op, + dist_attr.process_mesh, + X_grad_dist_attr, ctx) + else: + # replicate + matmul_op_desc = copy_op_with_new_input_output(main_block, backward_op, + **kwargs) + + main_block._sync_with_cpp() + + # check if need gradient allreduce + need_gradient_allreduce = False + process_mesh = dist_attr.process_mesh var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name) mesh_shape = process_mesh.topology @@ -187,7 +322,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): dp_degree = len(group_ranks) dp_group = new_process_group(group_ranks) - Y_var = main_block.var(kwargs['Y'][0]) if need_gradient_allreduce and Y_var.is_parameter: Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0]) allreduce_op = main_block.append_op( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index d72d13803ff..aba9704ad54 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -43,7 +43,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): super(DistributedReshapeImpl0, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = True + self._backward_implemented = False def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc @@ -200,7 +200,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): 
super(DistributedReshapeImpl1, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = True + self._backward_implemented = False def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index de2d0ba62e6..e4624b51222 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -39,7 +39,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): super(DistributedSoftmaxImpl, self).__init__() self._name = name self._forward_implemented = False - self._backward_implemented = True + self._backward_implemented = False def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index 98c46810518..8b40524e473 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -39,7 +39,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): super(DistributedTranspose2Impl, self).__init__() self._name = name self._forward_implemented = False - self._backward_implemented = True + self._backward_implemented = False def is_input_compatible(self, dist_op): return True diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index f6ddf2b9b73..1b4fcd69830 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -22,15 +22,15 @@ import subprocess import logging import pickle import time - import paddle from paddle.distributed.utils import get_logger from paddle.distributed.fleet import cloud_utils import paddle.fluid.core as core +from paddle.fluid import program_guard from .dist_context import DistributedContext from .dist_context import get_default_distributed_context from .dist_context import set_default_distributed_context -from .completion import complete_annotation, complete_backward_annotation +from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation from .partitioner import Partitioner from .process_group import get_all_process_groups from .process_group import get_process_group @@ -79,6 +79,7 @@ class AutoParallelizer: self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING") self._need_rank_mapping = True if self._need_rank_mapping and \ self._need_rank_mapping.lower() == 'true' else False + self._pass_context = None def _remove_distributed_attrs(self, main_program): suffix = core.kAutoParallelSuffix() @@ -90,28 +91,112 @@ class AutoParallelizer: if suffix in attr_name: op._remove_attr(attr_name) + def _apply_serial_forward_pass(self, main_program, startup_program): + + # apply amp forward pass + if self._dist_strategy.amp: + auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass", + self._dist_strategy.amp_configs) + auto_parallel_amp_pass.apply_forward(main_program, startup_program, + self._pass_context) + + # apply recompute forward pass + if self._dist_strategy.recompute: + auto_parallel_recompute_pass = new_pass( + "auto_parallel_recompute_pass", + self._dist_strategy.recompute_configs) + auto_parallel_recompute_pass.apply_forward( + 
main_program, startup_program, self._pass_context) + + def _generate_backward(self, main_program, startup_program, loss, + parameter_list, no_grad_set, callbacks): + + # apply recompute backward pass + if self._dist_strategy.recompute: + assert auto_parallel_recompute_pass + auto_parallel_recompute_pass.apply_forward( + main_program, startup_program, parameter_list, no_grad_set, + self._pass_context) + else: + from paddle.fluid.backward import append_backward + with program_guard(main_program, startup_program): + params_grads = append_backward( + loss, + parameter_list, + no_grad_set, + callbacks, + distop_context=self._dist_context.dist_op_context) + complete_backward_annotation( + main_program, dist_context=self._dist_context) + + # apply amp forward pass + if self._dist_strategy.amp: + assert auto_parallel_amp_pass + auto_parallel_amp_pass.apply_backward(main_program, startup_program, + self._pass_context) + + return params_grads + + def _apply_optimize(self, main_program, startup_program, params_grads): + + if self._dist_strategy.sharding: + auto_parallel_sharding_pass = new_pass( + "auto_parallel_sharding_pass", self._dist_strategy) + params_grads = auto_parallel_sharding_pass.apply( + main_program, startup_program, params_grads, self._pass_context) + + if self._dist_strategy.gradient_merge: + auto_parallel_gradient_merge_pass = new_pass( + "auto_parallel_gradient_merge_pass", + self._dist_strategy.gradient_merge_configs) + auto_parallel_gradient_merge_pass.apply( + main_program, startup_program, params_grads, self._pass_context) + + else: + with program_guard(main_program, startup_program): + optimizer = copy.deepcopy(self._optimizer) + optimize_ops = optimizer.apply_gradients(params_grads) + + # update completion + complete_update_annotation( + main_program, dist_context=self._dist_context) + + return optimize_ops + def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): completed_main_program = None + serial_main_program = self._main_program.clone() + serial_startup_program = self._startup_program.clone() + serial_loss = serial_main_program.global_block().var(self._loss.name) + # generating serial if dist_context is None: # Annotation completion self._dist_context = DistributedContext() _logger.info("Start annotation dist attr.") - completed_main_program = complete_annotation(self._main_program, + completed_main_program = complete_annotation(serial_main_program, self._dist_context) else: - completed_main_program = self._main_program + completed_main_program = serial_main_program self._dist_context = copy.deepcopy(dist_context) - # Logical partition - partitioner = Partitioner(self._dist_strategy, self._dist_context, rank) - dist_main_prog, dist_startup_prog = partitioner.transpile_forward( - completed_main_program, self._startup_program) - dist_params_grads = partitioner.apply_backward( - self._loss, completed_main_program, self._startup_program, - dist_main_prog, dist_startup_prog) - dist_optimize_ops = partitioner.apply_optimize( - copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog, - dist_startup_prog) + # serial forward pass + self._apply_serial_forward_pass(completed_main_program, + serial_startup_program) + # serial backward pass + params_grads = self._generate_backward( + completed_main_program, serial_startup_program, serial_loss, + self._parameter_list, self._no_grad_set, self._callbacks) + + # Logical partition + rank = paddle.distributed.get_rank() + partitioner = Partitioner(self._dist_context, rank) + dist_main_prog, 
dist_startup_prog, dist_params_grads = partitioner.partition( + completed_main_program, serial_startup_program, params_grads) + + # TODO refactor the placement of optimizer + # generate optimize program + dist_optimize_ops = self._apply_optimize( + dist_main_prog, dist_startup_prog, dist_params_grads) set_grad_var_shape(dist_main_prog, self._dist_context) @@ -133,13 +218,15 @@ class AutoParallelizer: loss, startup_program, parameter_list=None, - no_grad_set=None): + no_grad_set=None, + callbacks=None): assert startup_program is not None self._loss = loss self._startup_program = startup_program self._main_program = loss.block.program self._parameter_list = parameter_list self._no_grad_set = no_grad_set + self._callbacks = callbacks if self._enable_auto_mapping and self._need_rank_mapping: # Do the mapping pass before parallelization @@ -156,6 +243,7 @@ class AutoParallelizer: self._optimizer, self._cluster) planner = Planner( serial_program_info, + self, algorithm_config={"name": "mcmc", "max_search_times": 5}) dist_context, _ = planner.search() @@ -262,6 +350,7 @@ class AutoParallelizer: cluster=self._cluster) planner = Planner( serial_program_info, + self, algorithm_config={ "name": "mcmc", "max_search_times": 5 @@ -303,3 +392,14 @@ class AutoParallelizer: self._remove_distributed_attrs(dist_main_prog) return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "_main_program" or k == "_startup_program" or k == "_dist_context" or k == "_fleet" or k == "_loss": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py old mode 100755 new mode 100644 index 9af194e810f..e4d913cb9cd --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -20,18 +20,11 @@ from paddle.fluid import core from paddle.fluid import framework as framework from paddle.fluid import core, unique_name from paddle.fluid.framework import Program, Parameter, Variable, program_guard -from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype -from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_ from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container -from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm -from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext -from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op -from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from .dist_attribute import OperatorDistributedAttribute from .process_group import new_process_group -from .utils import print_program_with_dist_attr -from paddle.distributed.auto_parallel.completion import complete_backward_annotation, complete_update_annotation +from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] @@ -48,331 +41,152 @@ class Partitioner(object): 2. 
partition var: if a var is sharded, modify the shape of var according to its shard annotation Partitioner is supposed to be call by the auto parallel framework, and not supposed to be directly called by user. - - Example: - .... - import paddle.distributed.auto_parallel as auto - from paddle.fluid.distributed_attribute import get_default_distributed_context - from paddle.distributed import fleet - from paddle.distributed.auto_parallel.partitioner import Partitioner - - # create serial program with forward only - with static.program_guard(serial_main_program, serial_start_program): - model = create_model(config) - tokens = static.data(name="tokens", shape=[batch_size, sequence_len], dtype='int64') - labels = static.data(name="labels", shape=[batch_size, sequence_len], dtype='int64') - loss_mask = static.data(name="loss_mask", shape=[batch_size, sequence_len], dtype='int64') - preds = model(tokens) - loss = criterion(preds, labels, loss_mask) - - # auto completion - auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7]) - annotated_main_program = auto.complete_annotation(serial_main_program) - dist_context = get_default_distributed_context() - - # distributed strategy & rank info - rank_id = paddle.distributed.get_rank() - dist_strategy = fleet.DistributedStrategy() - - # create partitioner - Partitioner = Partitioner(dist_strategy, dist_context, rank_id) - - # create dist program with forward only - # for distributed inference, using partitioned_main_prog from here - partitioned_main_prog, partitioned_startup_prog = Partitioner.transpile_forward(complete_train_program, start_program) - - # create dist program with forward/backward/update - # for distributed training, using partitioned_main_prog from here - dist_params_grads = Partitioner.apply_backward(loss, complete_train_program, start_program, partitioned_main_prog, partitioned_startup_prog) - optimizer = paddle.fluid.optimizer.AdamOptimizer( - learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) - opt_ops = Partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog) """ - def __init__(self, dist_strategy, dist_context, rank_id=0): + def __init__(self, dist_context, rank_id=0): """ Args: - dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy. dist_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario. rank_id (int): global rank id to which the partitioned distributed program belong. 
""" - - if not isinstance(dist_strategy, DistributedStrategy): - raise TypeError( - "dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here" - % type(dist_strategy)) - if not isinstance(dist_context, DistributedContext): raise TypeError( "dist_context be paddle.fluid.DistributedContext, got %s here" % type(dist_context)) - self._dist_strategy = dist_strategy self._dist_context = dist_context self._rank_id = rank_id self._serial2dist_varname_mapping = {} self._dist_varname_suffix = "" - # TODO if there is some dist op that is not compatible - # with auto_backward in forward, the following flag - # should be set to False - self._compatible_with_auto_backward = True - - def transpile_forward(self, serial_main_program, serial_startup_program): - """ - take serial forward programs with shard annotation, create a new distributed forward programs based on the serial ones. - instead of modify the input programs inplace, this function will preserve the inputs and create new program for output. - - beside replace the serial op with its dist op, if user has defined other strategy in fleet.distributed_strategy, and if - those strategy need to transpile (modify) the forward network program, those forward program modification should also be done within this - function in auto parallel scenario, in order to facilitate distributed inference/evaluation which need to DECOUPLE strategy specific forward transpilation with fleet.distributed_optimizer.minimize(). - - by now the fleet.distributed_strategy that need transpile forward program are following: - 1. (optimizer) sharding - - Args: - main_program (paddle.fluid.framework.program): serial main program with forward network only - startup_program (paddle.fluid.framework.program): serial startup program with forward network only - - return: - main_program (paddle.fluid.framework.program): distributed main program with forward network only - startup_program (paddle.fluid.framework.program): distributed startup program with forward network only - """ - - dist_main_program, dist_startup_program = self.transpile_forward_impl( - serial_main_program, serial_startup_program) - return dist_main_program, dist_startup_program - - def apply_backward(self, - serial_loss, - serial_main_program, - serial_startup_program, - dist_main_program, - dist_startup_program, - parameter_list=None, - no_grad_set=None, - callbacks=None): - """ - A complete training neural network is made up of forward and backward propagation. - This function is to generate the dist backward program for the distributed forward program. - - By now, the current automatical backward mechanism in paddle framework might NOT handle the backward generation for - some dist ops correctly, some so we now have two ways to genenate the backward program: - 1. dist_forward_program --> auto_backward --> dist_backward_program (if auto_backward could handle all dist op) - 2. serial_forward_program --> auto_backward --> serial_backward_program --> dist_op_backward_transpile --> dist_backward_program (if auto_backward could not handle all dist op) - - the backprogram is append the input dist program inplaced. 
- - Args: - serial_loss (Variable) the loss in serial program that to be minimized - serial_main_program (paddle.fluid.framework.program): serial main program with forward network only - serial_startup_program (paddle.fluid.framework.program): serial startup program with forward network only - dist_main_program (paddle.fluid.framework.program): dist main program with forward network only - dist_startup_program (paddle.fluid.framework.program): dist startup program with forward network only - parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update - to minimize ``loss``. The default value is None, at this time all parameters - will be updated. - no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need - to be updated. The default value is None. - callbacks (list, optional): list of callable objects to run when appending backward - operator for one parameter. The default value is None. - - return: - params_grads (list) list of tuple that contain param and its grad variable - """ - params_grads = self.apply_backward_impl( - serial_loss, serial_main_program, serial_startup_program, - dist_main_program, dist_startup_program) - return params_grads - - def apply_optimize(self, user_define_optimizer, params_grads, - dist_main_program, dist_startup_program): - """ - append update related ops to the program: clip, weight decay, ops - filter optimize op if sharding is enable - naive gradient synchronization before update - - Args: - user_define_optimizer (paddle.fluid.optimizer): - params_grads (list) list of tuple that contain param and its grad variable - dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network - dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network - """ - - optimize_ops = self.apply_optimize_impl(user_define_optimizer, - params_grads, dist_main_program, - dist_startup_program) + def partition(self, serial_main_program, serial_startup_program, + params_grads): - return optimize_ops - - def transpile_forward_impl(self, main_program, startup_program): - - if not isinstance(main_program, (Program)): + if not isinstance(serial_main_program, (Program)): raise TypeError( - "dist_strategy be paddle.fluid.framework.program, got %s here" % - type(main_program)) - - if not isinstance(startup_program, (Program)): - raise TypeError( - "dist_context be paddle.fluid.framework.program, got %s here" % - type(startup_program)) + "main_program be paddle.fluid.framework.program, got %s here" % + type(serial_main_program)) # check if shard annotated serial program valid - if not self._is_valid_annotated_program(main_program): + if not self._is_valid_annotated_program(serial_main_program): raise RuntimeError( "Not all vars or ops are annotated in main program !") - # dist op & partition vars - new_main_prog, new_startup_program = self._dist_var_op_forward_transpile( - main_program, startup_program) - - # Sharding - if self._dist_strategy.sharding: - new_main_prog, new_startup_program = self._sharding_forward_transpile( - new_main_prog, new_startup_program) - - return new_main_prog, new_startup_program - - def apply_backward_impl(self, - serial_loss, - serial_main_program, - serial_startup_program, - dist_main_program, - dist_startup_program, - parameter_list=None, - no_grad_set=None, - callbacks=None): - """ - """ + # init distop helper + dist_op_context = self._dist_context.dist_op_context + 
dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping) + dist_op_context.set_rank_id(self._rank_id) - params_grads = self._dist_var_op_backward_transpile( - serial_loss, serial_main_program, serial_startup_program, - dist_main_program, dist_startup_program) - # Sharding - if self._dist_strategy.sharding: - self._sharding_backward_transpile(new_main_prog, - new_startup_program) + # partition startup program + if serial_startup_program == None: + partitioned_startup_prog = None + else: + partitioned_startup_prog = self.partition_startup_program( + serial_main_program, serial_startup_program) + dist_op_context.set_dst_startup_program(partitioned_startup_prog) - return params_grads + # partition main program + partitioned_main_prog, partitioned_params_grads = self.partition_main_program( + serial_main_program, params_grads) - def apply_optimize_impl(self, user_define_optimizer, params_grads, - dist_main_program, dist_startup_program): - """ - append update related ops to the program: clip, weight decay, ops - filter optimize op if sharding is enable - naive gradient synchronization before update + return partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads - Args: - user_define_optimizer (paddle.fluid.optimizer): - params_grads (list) list of tuple that contain param and its grad variable - dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network - dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network - """ + def partition_startup_program(self, serial_main_program, + serial_startup_program): - if self._dist_strategy.sharding: - params_grads = sharding_optimize_transpile( - params_grads, dist_main_program, dist_startup_program) - - optimize_ops = self._optimize_transpile(user_define_optimizer, - params_grads, dist_main_program, - dist_startup_program) + if not isinstance(serial_startup_program, (Program)): + raise TypeError( + "dist_context be paddle.fluid.framework.program, got %s here" % + type(serial_startup_program)) - return optimize_ops + partitioned_startup_prog = fluid.Program() + ref_block = serial_main_program.global_block() + target_block = partitioned_startup_prog.global_block() + param2shape = {} + temp_varname_map = {} - def _dist_var_op_forward_transpile(self, - serial_main_program, - serial_startup_program=None): + # tensors + for var in serial_startup_program.list_vars(): + if isinstance(var, Parameter): + # TODO if var not belong to this rank, should be filtered + serial_main_var = ref_block.var(var.name) + dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + serial_main_var) + target_shape = _get_dist_shape(serial_main_var, dist_attr) + new_name = var.name + self._dist_varname_suffix + temp_varname_map[var.name] = new_name + _partition_parameter(self._dist_context, serial_main_var, + target_block, new_name, target_shape) + param2shape[new_name] = target_shape + + # ops + for op in serial_startup_program.global_block().ops: + # TODO if var not belong to this rank, should be filtered + output_vars = op.desc.output_arg_names() + assert len( + output_vars + ) == 1, "initializer should output only ONE variable, but got [{}]".format( + str(op.desc)) + assert temp_varname_map[output_vars[ + 0]] in param2shape, "try to initialize [{}] which is not a Parameter".format( + output_vars[0]) + new_op_desc = target_block.desc.append_op() + new_op_desc.copy_from(op.desc) + new_op_desc._rename_output(output_vars[0], + 
temp_varname_map[output_vars[0]]) + new_op_desc._set_attr("shape", + param2shape[temp_varname_map[output_vars[0]]]) + target_block._sync_with_cpp() + + # set distribute atrribute + new_op = target_block.ops[-1] + assert new_op.type == new_op_desc.type() + assert new_op.desc == new_op_desc + output_var = target_block.var(output_vars[0]) + output_var_attr = self._dist_context.get_tensor_dist_attr_for_program( + output_var) + op_attr = OperatorDistributedAttribute() + op_attr.process_mesh = output_var_attr.process_mesh + op_attr.set_output_dims_mapping(output_var.name, + output_var_attr.dims_mapping) + op_attr.set_input_dims_mapping(output_var.name, + output_var_attr.dims_mapping) + self._dist_context.set_op_dist_attr_for_program(new_op, op_attr) + + return partitioned_startup_prog + + def partition_main_program(self, serial_main_program, params_and_grads): """ 1. partition variables 2. replace local op with corresponding dist op """ + dist_op_context = self._dist_context.dist_op_context partitioned_main_prog = fluid.Program() - partitioned_global_block = partitioned_main_prog.global_block() - serial_main_block = serial_main_program.global_block() + dist_op_context.set_dst_main_program(partitioned_main_prog) + target_block = partitioned_main_prog.global_block() + ref_block = serial_main_program.global_block() serial_ops = serial_main_program.global_block().ops - # transpile startup program - if serial_startup_program == None: - partitioned_startup_prog = None - else: - partitioned_startup_prog = fluid.Program() - # create parameter - partitioned_startup_global_block = partitioned_startup_prog.global_block( - ) - param2shape = {} - temp_varname_map = {} - for var in serial_startup_program.list_vars(): - if isinstance(var, Parameter): - # TODO if var not belong to this rank, should be filtered - serial_main_var = serial_main_block.var(var.name) - dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - serial_main_var) - target_shape = _get_dist_shape(serial_main_var, dist_attr) - new_name = var.name + self._dist_varname_suffix - temp_varname_map[var.name] = new_name - _partition_parameter(self._dist_context, serial_main_var, - partitioned_startup_global_block, - new_name, target_shape) - param2shape[new_name] = target_shape - - # copy initializer - for op in serial_startup_program.global_block().ops: - # TODO if var not belong to this rank, should be filtered - output_vars = op.desc.output_arg_names() - assert len( - output_vars - ) == 1, "initializer should output only ONE variable, but got [{}]".format( - str(op.desc)) - assert temp_varname_map[output_vars[ - 0]] in param2shape, "try to initialize [{}] which is not a Parameter".format( - output_vars[0]) - new_op_desc = partitioned_startup_global_block.desc.append_op() - new_op_desc.copy_from(op.desc) - new_op_desc._rename_output(output_vars[0], - temp_varname_map[output_vars[0]]) - new_op_desc._set_attr( - "shape", param2shape[temp_varname_map[output_vars[0]]]) - partitioned_startup_global_block._sync_with_cpp() - - # set distribute atrribute - new_op = partitioned_startup_global_block.ops[-1] - assert new_op.type == new_op_desc.type() - assert new_op.desc == new_op_desc - output_var = partitioned_startup_global_block.var(output_vars[ - 0]) - output_var_attr = self._dist_context.get_tensor_dist_attr_for_program( - output_var) - op_attr = OperatorDistributedAttribute() - op_attr.process_mesh = output_var_attr.process_mesh - op_attr.set_output_dims_mapping(output_var.name, - output_var_attr.dims_mapping) - 
op_attr.set_input_dims_mapping(output_var.name, - output_var_attr.dims_mapping) - self._dist_context.set_op_dist_attr_for_program(new_op, op_attr) - - # TODO move helper init to a comm place - dist_op_context = self._dist_context.dist_op_context - dist_op_context.set_dst_main_program(partitioned_main_prog) - dist_op_context.set_dst_startup_program(partitioned_startup_prog) - dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping) - dist_op_context.set_rank_id(self._rank_id) + # init mapping + first_backward_op_idx = -1 + forward_op_id2forward_op = {} + for idx in range(len(serial_ops)): + if is_forward_op(serial_ops[idx]): + forward_op_id2forward_op[serial_ops[idx].desc.id( + )] = serial_ops[idx] - # transpile main program + # partiiton for op in serial_ops: # partititon input variables for serial_input_varname in op.desc.input_arg_names(): if serial_input_varname not in self._serial2dist_varname_mapping: new_varname = serial_input_varname + self._dist_varname_suffix - if serial_main_block.has_var(serial_input_varname): - _partition_var(self._dist_context, serial_main_block, - partitioned_global_block, - serial_input_varname, new_varname) + if ref_block.has_var(serial_input_varname): + _partition_var(self._dist_context, ref_block, + target_block, serial_input_varname, + new_varname) else: assert serial_input_varname in __varname_not_in_block__ @@ -383,145 +197,47 @@ class Partitioner(object): for serial_output_varname in op.desc.output_arg_names(): if serial_output_varname not in self._serial2dist_varname_mapping: new_varname = serial_output_varname + self._dist_varname_suffix - _partition_var(self._dist_context, serial_main_block, - partitioned_global_block, + _partition_var(self._dist_context, ref_block, target_block, serial_output_varname, new_varname) self._serial2dist_varname_mapping[ serial_output_varname] = new_varname # partition op - kinputs, koutputs = dist_op_context.prepare_forward_context(op) - dist_attr = self._dist_context.get_op_dist_attr_for_program(op) - if _is_dist_op_forward_implement(self._dist_context, op): - dist_ops = get_distributed_operator_impl_container(op.type) - dist_op_impl = dist_ops.get_impl(dist_attr.impl_idx) - dist_op_impl.forward(self._dist_context, **kinputs, **koutputs) - + if is_forward_op(op): + kinputs, koutputs = dist_op_context.prepare_context(op) + dist_op_forward_impl = _get_dist_op_forward_implement( + op, self._dist_context) + dist_op_forward_impl.forward(self._dist_context, **kinputs, + **koutputs) + + elif is_backward_op(op): + print(str(op)) + kinputs, koutputs = dist_op_context.prepare_context(op) + dist_op_backward_impl = _get_dist_op_backward_implement( + op, self._dist_context, forward_op_id2forward_op) + dist_op_backward_impl.backward(self._dist_context, **kinputs, + **koutputs) else: - # replicate op - dist_ops = get_distributed_operator_impl_container("default") - dist_op_impl = dist_ops.get_impl(0) - dist_op_impl.forward(self._dist_context, **kinputs, **koutputs) - - return partitioned_main_prog, partitioned_startup_prog - - def _dist_var_op_backward_transpile(self, - serial_loss, - serial_main_program, - serial_startup_program, - dist_main_program, - dist_startup_program, - parameter_list=None, - no_grad_set=None, - callbacks=None): - """ - so far, the auto_backward case only guarantee the correcotness of backward ops for curtain Dist ops: - 1. NV-Megatron-like parallel embedding - 2. NV-Megatron-like row parallel linear - 3. 
NV-Megatron-like col parallel linear - """ - - if self._compatible_with_auto_backward: - assert isinstance( - serial_loss, Variable), "The target loss should be an Variable." - dist_loss = self._serial_varname2dist_var(serial_loss.name, - dist_main_program) - - assert len(dist_loss.shape) == 1 and dist_loss.shape[0] == 1, \ - "The dist loss.shape should be (1L,), but the current dist loss.shape is {}. " \ - "Maybe that you should call fluid.layers.mean to process the current loss.".format( - dist_loss.shape) - - # update parameter list - if parameter_list: - parameter_list = [ - self._serial_varname2dist_var(param.name, dist_main_program) - for param in parameter_list - ] - - # update parameter no_grad_set - if no_grad_set: - no_grad_set = [ - self._serial_varname2dist_var(param.name, dist_main_program) - for param in no_grad_set - ] - - dist_op_context = self._dist_context.dist_op_context - params_and_grads = _auto_backward( - dist_loss, - dist_startup_program, - parameter_list=parameter_list, - no_grad_set=no_grad_set, - callbacks=callbacks, - distop_context=dist_op_context) - - # backward completion - complete_backward_annotation( - dist_main_program, dist_context=self._dist_context) - - # transpiler backward for dist op - # get backward ops - ops = dist_main_program.global_block().ops - first_backward_op_idx = -1 - forward_op_id2forward_op = {} - for idx in range(len(ops)): - if is_forward_op(ops[idx]): - forward_op_id2forward_op[ops[idx].desc.id()] = ops[idx] - - if int(ops[idx].attr('op_role')) == int(OpRole.Backward): - first_backward_op_idx = idx - break - assert first_backward_op_idx >= 0, "not found backward ops in program" - assert len(forward_op_id2forward_op - ) > 0, "not found forward ops in program" - - backward_ops = ops[first_backward_op_idx:] - for backward_op in backward_ops: - # if the backward op has a corresponding forward op - if backward_op.desc.id() in dist_op_context.gradopidx2opidx: - forward_op_id = dist_op_context.gradopidx2opidx[ - backward_op.desc.id()] - forward_op = forward_op_id2forward_op[forward_op_id] - # TODO backward attr should has _impl_idx - forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( - forward_op) - # TODO use the backward op itself to find the dist op - dist_ops = get_distributed_operator_impl_container( - forward_op.type) - kinputs, koutputs = dist_op_context.prepare_backward_context( - backward_op) - - # TODO use backward op itself to determine impl idx - if _is_dist_op_backward_implement(self._dist_context, - forward_op): - dist_op_impl = dist_ops.get_impl( - forward_op_dist_attr.impl_idx) - dist_op_impl.backward(self._dist_context, **kinputs, - **koutputs) - else: - # replicate op - dist_ops = get_distributed_operator_impl_container( - "default") - dist_op_impl = dist_ops.get_impl(0) - dist_op_impl.backward(self._dist_context, **kinputs, - **koutputs) - - return params_and_grads - # replace dist grad ops - else: - raise RuntimeError("transpile NOT implemented !") - - def _optimize_transpile(self, user_define_optimizer, params_grads, - main_program, startup_program): - - with program_guard(main_program, startup_program): - optimize_ops = user_define_optimizer.apply_gradients(params_grads) - - # update completion - complete_update_annotation( - main_program, dist_context=self._dist_context) + raise NotImplementedError( + "partitioner only support forward op and backward op, but got {}". 
+ format(str(op))) + + partitioned_params_and_grads = [] + for p, g in params_and_grads: + assert p.name in self._serial2dist_varname_mapping + dist_p_name = self._serial2dist_varname_mapping[p.name] + assert target_block.has_var(dist_p_name) + dist_p = target_block.var(dist_p_name) + if g is None: + dist_g = None + else: + assert g.name in self._serial2dist_varname_mapping + dist_g_name = self._serial2dist_varname_mapping[g.name] + assert target_block.has_var(dist_g_name) + dist_g = target_block.var(dist_g_name) + partitioned_params_and_grads.append((dist_p, dist_g)) - return optimize_ops + return partitioned_main_prog, partitioned_params_and_grads def _is_valid_annotated_program(self, program): @@ -543,154 +259,6 @@ class Partitioner(object): return all_ops_annotated and all_vars_annotated - def _serial_varname2dist_var(self, serial_varname, dist_program): - assert serial_varname in self._serial2dist_varname_mapping, "The serial var [{}] is not found in var name mapping".format( - serial_varname) - dist_varname = self._serial2dist_varname_mapping[serial_varname] - - assert dist_program.global_block().has_var( - dist_varname - ), "The dist var [{}] is not found in dist program".format(dist_varname) - dist_var = dist_program.global_block().var(dist_varname) - - return dist_var - - def _is_var_distributed(self, var): - - dist_attr = self._dist_context.get_tensor_dist_attr_for_program(var) - assert dist_attr is not None, "dist_attr of var [{}] is None".format( - var.name) - return _is_distributed(dist_attr) - - def _sharding_forward_transpile(self, main_prog, startup_program): - """ - this transpile conduct the modification in forward program need by sharding strategy - which majorly include: - 1. partition the parameter - 2. insert broadcast op - 3. insert sync op - - NOTE the transpile modification is inplace on the input program - """ - - raise NotImplementedError( - "Sharding is NOT support in AutoParallel yet!") - - def _sharding_backward_transpile(self, main_prog, startup_program): - """ - this transpile conduct the modification in backward program need by sharding strategy - which majorly include: - 1. partition the gradient - 2. insert broadcast op - 3. insert sync op - - NOTE the transpile modification is inplace on the input program - """ - - raise NotImplementedError( - "Sharding is NOT support in AutoParallel yet!") - - def _sharding_optimize_transpile(self, params_grads, dist_main_program, - dist_startup_program): - """ - shard params_grads - append the broadcast to sync parameters - """ - raise RuntimeError("sharding transpile is NOT implemented !") - - -def _get_no_grad_set_name(no_grad_set): - no_grad_set_name = set() - if no_grad_set is not None: - if isinstance(no_grad_set, (set, list, tuple)): - for i, no_grad_var in enumerate(no_grad_set): - if isinstance(no_grad_var, framework.Variable): - no_grad_set_name.add(no_grad_var.name) - elif isinstance(no_grad_var, six.string_types): - no_grad_set_name.add(no_grad_var) - else: - raise TypeError( - "The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s." - % (type(no_grad_var))) - else: - raise TypeError( - "The type of no_grad_set should be set or list or tuple, but received {}". 
- format(type(no_grad_set))) - return no_grad_set_name - - -def _get_no_grad_set(loss, no_grad_set=None): - no_grad_set = _get_no_grad_set_name(no_grad_set) - parameters = loss.block.program.global_block().all_parameters() - param_no_trainable = set( - [param.name for param in parameters if param.trainable is False]) - # If the parameter is no trainable, it should not have a gradient. - no_grad_set.update(param_no_trainable) - - return no_grad_set - - -def _is_dist_op_forward_implement(dist_context, op): - dist_attr = dist_context.get_op_dist_attr_for_program(op) - dist_ops = get_distributed_operator_impl_container(op.type) - - return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \ - dist_attr.impl_idx)._forward_implemented - - -def _is_dist_op_backward_implement(dist_context, op): - dist_attr = dist_context.get_op_dist_attr_for_program(op) - dist_ops = get_distributed_operator_impl_container(op.type) - - return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \ - dist_attr.impl_idx)._backward_implemented - - -def _auto_backward(loss, - startup_program=None, - parameter_list=None, - no_grad_set=None, - callbacks=None, - distop_context=None): - """ - modification is inplaced - """ - act_no_grad_set = _get_no_grad_set(loss, no_grad_set) - assert isinstance(loss, Variable), "The target loss should be an Variable." - - if callbacks is None: - callbacks = [error_clip_callback] - else: - assert (isinstance(callbacks, list)) - - assert len(loss.shape) == 1 and loss.shape[0] == 1, \ - "The loss.shape should be (1L,), but the current loss.shape is {}. " \ - "Maybe that you should call fluid.layers.mean to process the current loss.".format( - loss.shape) - - program = loss.block.program - - with program_guard(program, startup_program): - params_grads = append_backward( - loss, - parameter_list, - act_no_grad_set, - callbacks, - distop_context=distop_context) - - return params_grads - - -def _is_distributed(dist_attr): - - mapping = dist_attr.dims_mapping - mesh = dist_attr.process_mesh.topology - for idx in range(len(mapping)): - if mapping[idx] >= 0 and mesh[mapping[idx]] > 1: - return True - - return False - def _get_dist_shape(var, dist_attr): @@ -795,52 +363,33 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, dst_varname, target_shape) -def _insert_src_op(src_op, dst_block, varname_mapping): - - new_op_desc = dst_block.desc.append_op() - new_op_desc.copy_from(src_op.desc) - for local_varname in src_op.desc.input_arg_names(): - new_op_desc._rename_input(local_varname, varname_mapping[local_varname]) - for local_varname in src_op.desc.output_arg_names(): - new_op_desc._rename_output(local_varname, - varname_mapping[local_varname]) - dst_block._sync_with_cpp() - - -def _insert_dist_op(src_op, dst_block, varname_mapping, dist_context, rank_id): - - # build input varname mapping - input_mapping = {} - for input_name in src_op.desc.input_names(): - varnames = [] - for varname in src_op.desc.input(input_name): - varnames.append(varname_mapping[varname]) - input_mapping[input_name] = varnames - - # build output varname mapping - output_mapping = {} - for output_name in src_op.desc.output_names(): - varnames = [] - for varname in src_op.desc.output(output_name): - varnames.append(varname_mapping[varname]) - output_mapping[output_name] = varnames - - # append dist op - dist_attr = dist_context.get_op_dist_attr_for_program(src_op) - dist_ops = get_distributed_operator_impl_container(src_op.type) - append_op_handle = 
dist_ops.get_impl(dist_attr.impl_idx).forward(src_op) - append_op_handle( - dst_block, - src_op, - dist_attr, - input_mapping, - output_mapping, - rank_id=rank_id) - - -def is_forward_op(op): - role1 = int(core.op_proto_and_checker_maker.OpRole.Forward) | int( - core.op_proto_and_checker_maker.OpRole.Loss) - role2 = int(core.op_proto_and_checker_maker.OpRole.Forward) - op_role = int(op.attr('op_role')) - return op_role == role2 or op_role == role1 +def _get_dist_op_backward_implement(backward_op, dist_context, + forward_op_id2forward_op): + dist_op_context = dist_context.dist_op_context + if backward_op.desc.id() in dist_op_context.gradopidx2opidx: + forward_op_id = dist_op_context.gradopidx2opidx[backward_op.desc.id()] + forward_op = forward_op_id2forward_op[forward_op_id] + forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( + forward_op) + dist_ops = get_distributed_operator_impl_container(forward_op.type) + + # TODO backward should have its own impl_idx + if dist_ops and forward_op_dist_attr.impl_idx >= 0 and dist_ops.get_impl( \ + forward_op_dist_attr.impl_idx)._backward_implemented: + return dist_ops.get_impl(forward_op_dist_attr.impl_idx) + + dist_ops = get_distributed_operator_impl_container("default") + return dist_ops.get_impl(0) + + +def _get_dist_op_forward_implement(forward_op, dist_context): + dist_attr = dist_context.get_op_dist_attr_for_program(forward_op) + dist_ops = get_distributed_operator_impl_container(forward_op.type) + + if dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( + dist_attr.impl_idx)._forward_implemented: + return dist_ops.get_impl(dist_attr.impl_idx) + + else: + dist_ops = get_distributed_operator_impl_container("default") + return dist_ops.get_impl(0) diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py old mode 100644 new mode 100755 index 7c4ce0b2435..1dfefb41c80 --- a/python/paddle/distributed/auto_parallel/planner.py +++ b/python/paddle/distributed/auto_parallel/planner.py @@ -386,15 +386,20 @@ class SearchAlgorithm: class MCMC(SearchAlgorithm): - def __init__(self, serial_program_info, max_search_times=5): + def __init__(self, serial_program_info, parallelizer, max_search_times=5): super(MCMC, self).__init__("mcmc") self._serial_program_info = serial_program_info self._max_search_times = max_search_times + self._parallelizer = parallelizer @property def serial_program_info(self): return self._serial_program_info + @property + def parallelizer(self): + return self._parallelizer + @property def max_search_times(self): return self._max_search_times @@ -483,7 +488,7 @@ class MCMC(SearchAlgorithm): cost = None # get all distributed programs all_dist_main_program = get_all_distributed_main_program( - self.serial_program_info, dist_context) + self.serial_program_info, dist_context, self.parallelizer) pipeline_config = [ process_mesh.processes for process_mesh in pipeline_process_meshes ] if pipeline_process_meshes is not None else None @@ -829,8 +834,10 @@ class MCMC(SearchAlgorithm): class Planner: - def __init__(self, serial_program_info, algorithm_config=None): + def __init__(self, serial_program_info, parallelizer, + algorithm_config=None): self._serial_program_info = serial_program_info + self._parallelizer = parallelizer self._algorithm_config = algorithm_config self._algorithm_searcher = self.create_algorithm_searcher( algorithm_config) @@ -847,6 +854,10 @@ class Planner: def algorithm_searcher(self): return self._algorithm_searcher + @property + def 
parallelizer(self): + return self._parallelizer + def create_algorithm_searcher(self, algorithm_config): name = algorithm_config.get("name", None) assert name is not None, "Invalid algorithm config." @@ -856,9 +867,9 @@ class Planner: # NOTE: Only GPU clusters are supported now. max_search_times = algorithm_config.get("max_search_times", None) algorithm_searcher = MCMC( - self.serial_program_info, + self.serial_program_info, self.parallelizer, max_search_times) if max_search_times is not None else MCMC( - self.serial_program_info) + self.serial_program_info, self.parallelizer) else: raise NotImplementedError( "Other search algorithms have not been supported now.") diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py old mode 100755 new mode 100644 index 3b392d4e088..8f7ac360401 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -993,7 +993,9 @@ def set_grad_var_shape(program, dist_context): block = program.global_block() vars = block.vars for op in block.ops: - if op.type == "sum": + if op.type in [ + "sum", "check_finite_and_unscale", "update_loss_scaling" + ]: continue if int(op.attr('op_role')) == int(OpRole.Backward): op_dist_attr = dist_context.get_op_dist_attr_for_program(op) @@ -1004,15 +1006,24 @@ def set_grad_var_shape(program, dist_context): forward_var_name = var_name[:var_name.find("@GRAD")] if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale": forward_var_name = op.input_arg_names[0] + elif op.type == "matmul_v2_grad": + forward_var_name = None + for output_name in op.output_names: + if var_name in op.output(output_name): + assert "@GRAD" in output_name + input_name = output_name[:output_name.find("@GRAD")] + assert len(op.input(input_name)) == 1 + forward_var_name = op.input(input_name)[0] + assert forward_var_name is not None need_set_shape_list = [ "reshape2_grad", "softmax_with_cross_entropy_grad", "transpose2_grad", "softmax_grad", "cross_entropy_grad2", - "dropout_grad", "unsqueeze2_grad" + "dropout_grad" ] forward_list = [ "reshape2", "softmax_with_cross_entropy", "transpose2", - "softmax", "cross_entropy2", "dropout", "unsqueeze2" + "softmax", "cross_entropy2", "dropout" ] if op.type in need_set_shape_list: for forward_op in block.ops: @@ -1041,6 +1052,23 @@ def set_grad_var_shape(program, dist_context): grad_var.desc.set_shape(ref_shape) +OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() +OpRole = core.op_proto_and_checker_maker.OpRole + + +def is_forward_op(op): + ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward) + ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss) + op_role = int(op.attr('op_role')) + return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1 or + op_role == ref_role2) + + +def is_backward_op(op): + return OP_ROLE_KEY in op.attr_names and \ + int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) + + def update_op_dims_mapping_by_default_dist_impl(dist_op): changed = False op_dist_attr = dist_op.dist_attr @@ -1177,57 +1205,25 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): return changed -def get_all_distributed_main_program(serial_program_info, dist_context): +def get_all_distributed_main_program(serial_program_info, dist_context, + parallelizer): "Get all distributed main programs by dist_context." 
- from .dist_context import DistributedOperatorContext + from .dist_context import DistributedOperatorContext, DistributedContext cluster = serial_program_info.cluster + copied_parallelizer = copy.deepcopy(parallelizer) all_dist_main_program = [] ranks = paddle.distributed.get_world_size() if cluster is None else len( cluster.get_all_devices("GPU")) for rank_id in range(ranks): used_dist_context = copy.deepcopy(dist_context) used_dist_context._dist_op_context = DistributedOperatorContext() - dist_main_program, dist_startup_program = get_specified_distributed_main_program( - serial_program_info, used_dist_context, rank_id) + _, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program( + rank_id, used_dist_context) all_dist_main_program.append(dist_main_program) return all_dist_main_program -def get_specified_distributed_main_program(serial_program_info, dist_context, - rank_id): - "Get distributed main program by the given dist_context and rank_id." - from .partitioner import Partitioner - from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER - from .process_group import _g_process_group_map, ProcessGroup - - dist_strategy = paddle.distributed.fleet.DistributedStrategy() - train_program = serial_program_info.train_program - startup_program = serial_program_info.startup_program - loss = serial_program_info.loss - optimizer = serial_program_info.optimizer - - partitioner = Partitioner(dist_strategy, dist_context, rank_id) - dist_main_program, dist_startup_program = partitioner.transpile_forward( - train_program, startup_program) - dist_params_grads = partitioner.apply_backward( - loss, train_program, startup_program, dist_main_program, - dist_startup_program) - opt_ops = partitioner.apply_optimize( - copy.deepcopy(optimizer), dist_params_grads, dist_main_program, - dist_startup_program) - set_grad_var_shape(dist_main_program, dist_context) - make_data_unshard(dist_main_program, dist_startup_program, dist_context) - reshard(dist_main_program, dist_startup_program, rank_id, dist_context) - HAS_SENT.clear() - HAS_RECV.clear() - HAS_ALLGATHER.clear() - - _g_process_group_map.clear() - _g_process_group_map[0] = ProcessGroup(0, []) - return dist_main_program, dist_startup_program - - class SerialProgramInfo: def __init__(self, train_program, @@ -1286,7 +1282,6 @@ def get_standalone_cost_data(distributed_programs): shape = list(map(lambda x: int(x.strip()), shape)) dtype_factor = 1 total_static_input_size += reduce(lambda x, y: x * y, shape) - # print(arg_name_lower) if op.type == "c_embedding": arg_name_lower = "w" if arg_name_lower == "weight" else "ids" for arg_name in op.input_names: @@ -1301,7 +1296,8 @@ def get_standalone_cost_data(distributed_programs): actual_runtime = total_actual_input_size / total_static_input_size * runtime return actual_runtime - cost_model = paddle.cost_model.CostModel() + import paddle.cost_model as cm + cost_model = cm.CostModel() cost_model.static_cost_data() DEFAULT_MULTIPLE = 2 OP_NAME_MAPPING = { diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py index d58c79dd72c..ab91c3fe7c4 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py @@ -26,7 +26,7 @@ import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from 
paddle.distributed.auto_parallel.partitioner import Partitioner -from paddle.distributed.auto_parallel.completion import complete_backward_annotation +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.cost_model import estimate_cost import paddle.fluid.core as core @@ -148,22 +148,33 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): loss, train_program, startup_program = mlp_forward(train_program, startup_program) - dist_strategy = fleet.DistributedStrategy() + fleet._user_defined_strategy = fleet.DistributedStrategy() + fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() + parallelizer = AutoParallelizer(fleet) + parallelizer._dist_context = dist_context - # auto completion + # serial forward & backward completion complete_train_program = auto.complete_annotation(train_program, dist_context) - partitioner = Partitioner(dist_strategy, dist_context, rank_id) + + parallelizer._apply_serial_forward_pass(complete_train_program, + startup_program) + + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) + # logical partition - auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( - complete_train_program, startup_program) - dist_params_grads = partitioner.apply_backward( - loss, complete_train_program, startup_program, auto_parallel_main_prog, - auto_parallel_startup_prog) - optimizer = paddle.fluid.optimizer.AdamOptimizer() - opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, - auto_parallel_main_prog, - auto_parallel_startup_prog) + partitioner = Partitioner(dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition( + complete_train_program, startup_program, params_grads) + + partitioned_optimize_ops = parallelizer._apply_optimize( + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) return auto_parallel_main_prog, auto_parallel_startup_prog diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py index 4fd64dc252b..9fe5a52cf08 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py @@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer from paddle.distributed import fleet import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import reshard @@ -469,21 +470,31 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): loss, train_program, startup_program = mlp_forward(train_program, startup_program) - dist_strategy = fleet.DistributedStrategy() + fleet._user_defined_strategy = fleet.DistributedStrategy() + fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() + parallelizer = AutoParallelizer(fleet) + parallelizer._dist_context = dist_context # auto completion complete_train_program = auto.complete_annotation(train_program, dist_context) - partitioner = Partitioner(dist_strategy, 
dist_context, rank_id) - # logical partition - dist_train_program, dist_startup_prog = partitioner.transpile_forward( - complete_train_program, startup_program) - dist_params_grads = partitioner.apply_backward( - loss, complete_train_program, startup_program, dist_train_program, - dist_startup_prog) - optimizer = paddle.fluid.optimizer.AdamOptimizer() - opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, - dist_train_program, dist_startup_prog) + parallelizer._apply_serial_forward_pass(complete_train_program, + startup_program) + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) + + partitioner = Partitioner(dist_context, rank_id) + dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition( + complete_train_program, startup_program, params_grads) + + partitioned_optimize_ops = parallelizer._apply_optimize( + dist_train_program, dist_startup_prog, dist_params_grads) + reshard(dist_train_program, dist_startup_prog, rank_id, dist_context) return dist_train_program, dist_startup_prog diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py old mode 100755 new mode 100644 index 3a23f9b2611..21cf8a904b6 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -54,9 +54,9 @@ def get_programs(annotated_func): rank_id = 3 dist_strategy = fleet.DistributedStrategy() - partitioner = Partitioner(dist_strategy, dist_context, rank_id) - test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog = partitioner.transpile_forward( - complete_train_program, start_program) + partitioner = Partitioner(dist_context, rank_id) + test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, _ = partitioner.partition( + complete_train_program, start_program, []) return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py old mode 100755 new mode 100644 index 7fcb18db128..3270cfc3c8a --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -35,6 +35,7 @@ from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_pr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.utils import _get_comm_group from paddle.distributed.auto_parallel.process_group import new_process_group @@ -790,9 +791,9 @@ class GPTPretrainingCriterion(nn.Layer): return loss -def gpt_pretrain_forward(train_program, start_program): +def gpt_pretrain_forward(train_program, startup_program): with static.program_guard(train_program, - start_program), utils.unique_name.guard(): + startup_program), utils.unique_name.guard(): batch_size = 16 sequence_len = 512 input_ids = static.data( @@ -848,7 +849,19 @@ def gpt_pretrain_forward(train_program, 
start_program): loss = criterion(preds, labels, loss_mask) - return train_program, start_program, loss + return train_program, startup_program, loss + + +class FakeStrategy(object): + def __init__(self): + self.amp = False + self.recompute = False + + +class FakeFleet(object): + def __init__(self): + self.user_defined_optimizer = None + self._user_defined_strategy = FakeStrategy() class TestGPTPartitioner(unittest.TestCase): @@ -861,38 +874,41 @@ class TestGPTPartitioner(unittest.TestCase): mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) train_program = static.Program() - start_program = static.Program() - dist_context = DistributedContext() + startup_program = static.Program() + parallelizer = AutoParallelizer(FakeFleet()) + dist_context = parallelizer._dist_context + dist_context.process_mesh = _global_process_mesh - train_program, start_program, loss = gpt_pretrain_forward(train_program, - start_program) + train_program, startup_program, loss = gpt_pretrain_forward( + train_program, startup_program) complete_train_program = auto.complete_annotation(train_program, dist_context) + + # serial forward pass + parallelizer._apply_serial_forward_pass(complete_train_program, + startup_program) + + # serial backward pass + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) + rank_id = 3 - dist_strategy = fleet.DistributedStrategy() - partitioner = Partitioner(dist_strategy, dist_context, rank_id) - auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( - complete_train_program, start_program) - dist_params_grads = partitioner.apply_backward( - loss, complete_train_program, start_program, - auto_parallel_main_prog, auto_parallel_startup_prog) + partitioner = Partitioner(dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition( + complete_train_program, startup_program, params_grads) with open("./test_auto_parallel_partitioner_serial_main_new.txt", "w") as fw: fw.write(str(train_program)) with open("./test_auto_parallel_partitioner_serial_startup_new.txt", "w") as fw: - fw.write(str(start_program)) - - optimizer = paddle.fluid.optimizer.AdamOptimizer( - learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) - opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, - auto_parallel_main_prog, - auto_parallel_startup_prog) + fw.write(str(startup_program)) + from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context set_default_distributed_context(dist_context) with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw: @@ -927,7 +943,7 @@ class TestGPTPartitioner(unittest.TestCase): complete_train_program, weights, 0, 1)) all_params = sorted( - [param.name for param in start_program.all_parameters()]) + [param.name for param in startup_program.all_parameters()]) allreduce_grads = [ 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2', diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index 0439b9a287c..0631cc74a32 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -24,6 +24,7 @@ import paddle.utils as utils import paddle.distributed.auto_parallel as auto from 
paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.process_group import _g_process_group_map @@ -145,22 +146,34 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): loss, train_program, startup_program = mlp_forward(train_program, startup_program) - # auto completion + fleet._user_defined_strategy = fleet.DistributedStrategy() + fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() + parallelizer = AutoParallelizer(fleet) + parallelizer._dist_context = dist_context + + # serial forward & backward completion complete_train_program = auto.complete_annotation(train_program, dist_context) - dist_strategy = fleet.DistributedStrategy() - partitioner = Partitioner(dist_strategy, dist_context, rank_id) + parallelizer._apply_serial_forward_pass(complete_train_program, + startup_program) + + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) + # logical partition - auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( - complete_train_program, startup_program) - dist_params_grads = partitioner.apply_backward( - loss, complete_train_program, startup_program, auto_parallel_main_prog, - auto_parallel_startup_prog) - optimizer = paddle.fluid.optimizer.AdamOptimizer() - opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, - auto_parallel_main_prog, - auto_parallel_startup_prog) + partitioner = Partitioner(dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition( + complete_train_program, startup_program, params_grads) + + partitioned_optimize_ops = parallelizer._apply_optimize( + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) + return auto_parallel_main_prog, auto_parallel_startup_prog diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index 4bd03a3e1bd..0e098664f7e 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -24,6 +24,7 @@ import paddle.utils as utils import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr @@ -109,22 +110,34 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): loss, train_program, startup_program = mlp_forward(train_program, startup_program) - # auto completion + fleet._user_defined_strategy = fleet.DistributedStrategy() + fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() + parallelizer = AutoParallelizer(fleet) + parallelizer._dist_context = dist_context + + # serial forward & backward completion complete_train_program = 
auto.complete_annotation(train_program, dist_context) - dist_strategy = fleet.DistributedStrategy() - partitioner = Partitioner(dist_strategy, dist_context, rank_id) + parallelizer._apply_serial_forward_pass(complete_train_program, + startup_program) + + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) + # logical partition - auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( - complete_train_program, startup_program) - dist_params_grads = partitioner.apply_backward( - loss, complete_train_program, startup_program, auto_parallel_main_prog, - auto_parallel_startup_prog) - optimizer = paddle.fluid.optimizer.AdamOptimizer() - opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, - auto_parallel_main_prog, - auto_parallel_startup_prog) + partitioner = Partitioner(dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition( + complete_train_program, startup_program, params_grads) + + partitioned_optimize_ops = parallelizer._apply_optimize( + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) + return auto_parallel_main_prog, auto_parallel_startup_prog diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index ae79712dc79..c6b1be65207 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -24,6 +24,7 @@ import paddle.utils as utils import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr @@ -125,22 +126,32 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): loss, train_program, startup_program = mlp_forward(train_program, startup_program) - # auto completion + fleet._user_defined_strategy = fleet.DistributedStrategy() + fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() + parallelizer = AutoParallelizer(fleet) + parallelizer._dist_context = dist_context + + # serial forward & backward completion complete_train_program = auto.complete_annotation(train_program, dist_context) + parallelizer._apply_serial_forward_pass(complete_train_program, + startup_program) + + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) - dist_strategy = fleet.DistributedStrategy() - partitioner = Partitioner(dist_strategy, dist_context, rank_id) # logical partition - auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( - complete_train_program, startup_program) - dist_params_grads = partitioner.apply_backward( - loss, complete_train_program, startup_program, auto_parallel_main_prog, - auto_parallel_startup_prog) - optimizer = paddle.fluid.optimizer.AdamOptimizer() - opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, - auto_parallel_main_prog, - auto_parallel_startup_prog) + partitioner = 
Partitioner(dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition( + complete_train_program, startup_program, params_grads) + + partitioned_optimize_ops = parallelizer._apply_optimize( + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) return auto_parallel_main_prog, auto_parallel_startup_prog @@ -253,14 +264,15 @@ class TestMLPReshard(unittest.TestCase): rank_id = 0 dist_context = DistributedContext() dist_strategy = fleet.DistributedStrategy() - partitioner = Partitioner(dist_strategy, dist_context, rank_id) + partitioner = Partitioner(dist_context, rank_id) complete_train_program = auto.complete_annotation(train_program, dist_context) - auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( - complete_train_program, startup_program) - reshard(auto_parallel_main_prog, startup_program, rank_id, dist_context) + partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition( + complete_train_program, startup_program, []) + reshard(partitioned_main_prog, partitioned_startup_prog, rank_id, + dist_context) # the x should not be slice - self.assertTrue(check_allgather(auto_parallel_main_prog)) + self.assertTrue(check_allgather(partitioned_main_prog)) if __name__ == "__main__": -- GitLab
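
For readers tracking the API change end to end, the sketch below condenses the flow that the updated unit tests in this patch exercise after the refactor: serial completion and backward generation now live in AutoParallelizer, and Partitioner is constructed from only a DistributedContext and a rank id, returning the partitioned params/grads alongside both programs. This is an illustrative sketch, not part of the patch; the fleet strategy/optimizer wiring and the model-building helper (analogous to the tests' mlp_forward) are assumptions taken from the test code above.

    import paddle
    import paddle.distributed.auto_parallel as auto
    from paddle.distributed import fleet
    from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.reshard import reshard


    def build_dist_program(train_program, startup_program, dist_context, rank_id):
        # Build the serial forward program first; `mlp_forward` stands in for any
        # model-building helper like the one used in the unit tests above.
        loss, train_program, startup_program = mlp_forward(train_program,
                                                           startup_program)

        # Test-style wiring of a parallelizer around a fleet instance (assumption
        # copied from the updated tests, not a public setup path).
        fleet._user_defined_strategy = fleet.DistributedStrategy()
        fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
        parallelizer = AutoParallelizer(fleet)
        parallelizer._dist_context = dist_context

        # Serial forward completion and backward generation now happen on the
        # serial program, before any partitioning.
        complete_train_program = auto.complete_annotation(train_program,
                                                          dist_context)
        parallelizer._apply_serial_forward_pass(complete_train_program,
                                                startup_program)
        params_grads = parallelizer._generate_backward(
            complete_train_program,
            startup_program,
            loss,
            parameter_list=None,
            no_grad_set=None,
            callbacks=None)

        # Logical partition: the Partitioner no longer takes a DistributedStrategy
        # and returns the partitioned params/grads together with both programs.
        partitioner = Partitioner(dist_context, rank_id)
        dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
            complete_train_program, startup_program, params_grads)

        # Optimizer pass and reshard are applied on the partitioned programs.
        parallelizer._apply_optimize(dist_main_prog, dist_startup_prog,
                                     dist_params_grads)
        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
        return dist_main_prog, dist_startup_prog

As a design note, moving backward/optimizer generation into AutoParallelizer lets the planner reuse one code path per rank (via `_get_dist_program`), instead of the old transpile_forward / apply_backward / apply_optimize sequence on the Partitioner.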