diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py
index caf220646bb6098d0577e89bb7173b24b5d4b40a..573f23fdca519ae1da10d62ef7eb2da6238805f3 100644
--- a/python/paddle/distributed/auto_parallel/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/dist_context.py
@@ -55,6 +55,7 @@ class DistributedContext:
         self._is_initialized_for_program = False
         self._dist_tensors_for_program = {}
         self._dist_ops_for_program = {}
+        self._block_state = BlockState()
         # Graph related data members
         self._is_initialized_for_graph = False
         self._serial_graph = None
@@ -102,6 +103,10 @@
     def dist_op_context(self):
         return self._dist_op_context
 
+    @property
+    def block_state(self):
+        return self._block_state
+
     @property
     def dist_main_programs(self):
         return self._dist_main_programs
@@ -512,66 +517,83 @@ class DistributedOperatorContext:
     def __init__(self):
         self._dst_main_program = None
+        self._main_block = None
         self._dst_startup_program = None
-        self._varname_mapping = None
-        self._rank_id = None
+        self._startup_block = None
         self._cur_src_op = None
         self._cur_dist_attr = None
         self.grad_op_id_to_op_id = {}
+        self._work_block = None
         self.already_init_sync_vars = set()
+        self.varname_mapping = None
+        self.rank_id = None
 
     def __deepcopy__(self, memo):
         cls = self.__class__
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in self.__dict__.items():
-            if k == "_dst_main_program" or k == "_dst_startup_program" or k == "_cur_src_op":
+            if k in [
+                    "_dst_main_program", "_dst_startup_program", "_cur_src_op",
+                    "_work_block", "_main_block", "_startup_block"
+            ]:
                 setattr(result, k, v)
             else:
                 setattr(result, k, copy.deepcopy(v, memo))
         return result
 
-    def set_dst_main_program(self, prog):
-        self._dst_main_program = prog
-
-    def get_dst_main_program(self):
+    @property
+    def dst_main_program(self):
         return self._dst_main_program
 
-    def set_dst_startup_program(self, prog):
-        self._dst_startup_program = prog
+    @dst_main_program.setter
+    def dst_main_program(self, prog):
+        self._dst_main_program = prog
+        self._main_block = prog.blocks[0]
 
-    def get_dst_startup_program(self):
-        return self._dst_startup_program
+    @property
+    def main_block(self):
+        return self._main_block
 
-    def set_varname_mapping(self, mapping):
-        self._varname_mapping = mapping
+    @property
+    def dst_startup_program(self):
+        return self._dst_startup_program
 
-    def get_varname_mapping(self):
-        return self._varname_mapping
+    @dst_startup_program.setter
+    def dst_startup_program(self, prog):
+        self._dst_startup_program = prog
+        self._startup_block = prog.blocks[0]
 
-    def set_rank_id(self, rank_id):
-        self._rank_id = rank_id
+    @property
+    def startup_block(self):
+        return self._startup_block
 
-    def get_rank_id(self):
-        return self._rank_id
+    @property
+    def work_block(self):
+        assert self._work_block is not None
+        return self._work_block
 
-    def set_cur_src_op(self, cur_src_op):
-        self._cur_src_op = cur_src_op
+    @work_block.setter
+    def work_block(self, block):
+        assert block is not None
+        self._work_block = block
 
-    def get_cur_src_op(self):
+    @property
+    def cur_src_op(self):
+        assert self._cur_src_op is not None
        return self._cur_src_op
 
     def prepare_context(self, src_op):
-        self.set_cur_src_op(src_op)
+        self._cur_src_op = src_op
 
         # build input varname mapping
         kinputs = {}
         for input_name in src_op.desc.input_names():
             varnames = []
             for varname in src_op.desc.input(input_name):
-                assert varname in self._varname_mapping
-                varnames.append(self._varname_mapping[varname])
+                assert varname in self.varname_mapping
+                varnames.append(self.varname_mapping[varname])
             kinputs[input_name] = varnames
 
         # build output varname mapping
@@ -579,8 +601,52 @@ class DistributedOperatorContext:
         for output_name in src_op.desc.output_names():
             varnames = []
             for varname in src_op.desc.output(output_name):
-                assert varname in self._varname_mapping
-                varnames.append(self._varname_mapping[varname])
+                assert varname in self.varname_mapping
+                varnames.append(self.varname_mapping[varname])
             koutputs[output_name] = varnames
 
         return kinputs, koutputs
+
+
+class BlockState(object):
+    def __init__(self):
+        self.nblock = 0
+        self.forward_indices = []
+        self.backward_indices = []
+        self.backward_to_forward_index_map = {}
+
+    def parse_forward_blocks(self, program):
+
+        while program.current_block_idx != 0:
+            program._rollback()
+
+        assert program.current_block_idx == 0
+
+        for idx, block in enumerate(program.blocks):
+
+            assert idx == block.idx, "index doesn't match"
+            assert block.forward_block_idx == -1, "forward_block_idx of forward block [{}] is not [{}]".format(
+                idx, block.forward_block_idx)
+            self.forward_indices.append(idx)
+            self.nblock += 1
+
+        assert self.nblock >= 1
+
+    def parse_backward_blocks(self, program):
+
+        assert 0 in self.forward_indices, "forward block idx are {}".format(
+            self.forward_indices)
+        self.backward_to_forward_index_map[0] = 0
+
+        for idx, block in enumerate(program.blocks):
+
+            if idx < len(self.forward_indices):
+                continue
+
+            assert idx == block.idx, "index doesn't match"
+            assert block.forward_block_idx in self.forward_indices
+            self.backward_indices.append(idx)
+            self.backward_to_forward_index_map[idx] = block.forward_block_idx
+            self.nblock += 1
+
+        assert self.nblock == len(program.blocks)
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
index 2870acfd367cab5236f8544c447bdd269b8e654b..b887de577b0a21818aa1165c9015fb33a13da037 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
@@ -76,9 +76,9 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
 
         # by now the backward function only insert the gradient allreduce for dist op itself
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        backward_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.main_block
+        backward_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
         assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(backward_op))
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index 48f9b5a78dd8a371962ed4b72babe01dcc1ac5d4..4e977007261a73e9b24a051f84e6e30f2bf9d860 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -32,6 +32,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY,
 from ..process_group import new_process_group
 from ..utils import _get_comm_group, _get_corresponding_rank
 
+__op_not_need_param_init__ = ["while", "cond"]
+
 
 class DistributedDefault(DistributedOperatorImplContainer):
     def __init__(self, op_type):
@@ -195,10 +197,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
     def forward(ctx, *args, **kwargs):
 
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
 
         # check validation of inputs / outputs
         for input_name in src_op.desc.input_names():
@@ -227,6 +229,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         main_block._sync_with_cpp()
 
         # param initialization sync
+        if src_op.type in __op_not_need_param_init__:
+            return
+
         for varname in dist_op_desc.input_arg_names():
             if startup_block.has_var(varname) and startup_block.var(
                     varname
@@ -278,12 +283,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
 
         # by now the backward function only insert the gradient allreduce for dist op itself
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        backward_op = dist_op_context.get_cur_src_op()
+        main_block = dist_op_context.work_block
+        backward_op = dist_op_context.cur_src_op
         dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
         assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(backward_op))
-        rank_id = dist_op_context.get_rank_id()
+        rank_id = dist_op_context.rank_id
 
         # check validation of inputs / outputs
         for input_name in backward_op.desc.input_names():
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
index eac4776f8f3bcdbffc85725a2280b30c6bcff060..94eb0d2d469f0595fdc8cb31821d6cded9ad064a 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -128,10 +128,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         """
 
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -311,9 +311,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
 
         # by now the backward function only insert the gradient allreduce for dist op itself
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        backward_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        backward_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
         assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(backward_op))
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
index cb59a6f25c48769a639094f0a14ac12b63036657..9eb24a65e608c22573342f32dfd0dc96a601e3ac 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -223,9 +223,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
 
     # by now the backward function only insert the gradient allreduce for dist op itself
     dist_op_context = ctx.dist_op_context
-    main_block = dist_op_context.get_dst_main_program().global_block()
-    backward_op = dist_op_context.get_cur_src_op()
-    rank_id = dist_op_context.get_rank_id()
+    main_block = dist_op_context.work_block
+    backward_op = dist_op_context.cur_src_op
+    rank_id = dist_op_context.rank_id
     dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
     assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
         str(backward_op))
@@ -257,7 +257,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         kwargs['Y@GRAD'])
 
     X_var = main_block.var(kwargs['X'][0])
-    Y_var = main_block.var(kwargs['Y'][0])
+    Y_var = main_block._var_recursive(kwargs['Y'][0])
     Out_grad = main_block.var(kwargs['Out@GRAD'][0])
     Y_grad = main_block.var(kwargs['Y@GRAD'][0])
 
@@ -433,7 +433,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
 
 def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
 
-    assert Weight_var.name not in dist_op_context.already_init_sync_vars
+    assert Weight_var.name not in dist_op_context.already_init_sync_vars, "{} is in {}.".format(
+        Weight_var.name, dist_op_context.already_init_sync_vars)
     assert startup_block.has_var(Weight_var.name)
     dist_op_context.already_init_sync_vars.add(Weight_var.name)
     param = startup_block.var(Weight_var.name)
@@ -528,10 +529,10 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         """
 
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -753,10 +754,10 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         """
 
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -1042,10 +1043,10 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         """
 
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -1071,7 +1072,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
             output_name)
 
         X_var = main_block.var(kwargs['X'][0])
-        Weight_var = main_block.var(kwargs['Y'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
         Out_var = main_block.var(kwargs['Out'][0])
 
         # TODO infer logic comm presentation
@@ -1261,10 +1262,10 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         """
 
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -1290,7 +1291,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
             output_name)
 
         X_var = main_block.var(kwargs['X'][0])
-        Weight_var = main_block.var(kwargs['Y'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
         Out_var = main_block.var(kwargs['Out'][0])
 
         # TODO infer logic comm presentation
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
index 93b0d91b7836d64ae6e1dc9b17161746bc6b8444..a72e304bb5b911eb89fd3e401f9a4e9abf58ceb2 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
@@ -130,9 +130,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
         """
 
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -287,9 +287,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
         """
 
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
index f216fce16f30d0d581248402740b27da41725904..4ea2e0a88471601a5c8051c4f58a41c2509bc033 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
@@ -65,9 +65,9 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
 
         # the backward function only filte the gradient with current rank id
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        backward_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.main_block
+        backward_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
         assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(backward_op))
diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 6278f0a2424a0fa89b5ae7ab2350aeec63a600a7..0f35ccd915f2ab394c0e7316196b9a03b43b9968 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -132,7 +132,7 @@ class AutoParallelizer:
             distop_context=self._dist_context.dist_op_context)
         self._completer = Completer(self._dist_context)
         self._completer.complete_backward_annotation(main_program)
-
+        self._dist_context.block_state.parse_backward_blocks(main_program)
         return params_grads
 
     def _apply_optimize(self, main_program, startup_program, params_grads):
@@ -174,6 +174,7 @@
         serial_main_program = self._main_program.clone()
         serial_startup_program = self._startup_program.clone()
         serial_loss = serial_main_program.global_block().var(self._loss.name)
 
+        # generating serial
         if dist_context is None:
             # Annotation completion
@@ -186,6 +187,9 @@
             completed_main_program = serial_main_program
             self._dist_context = copy.deepcopy(dist_context)
 
+        # parse forward sub block
+        self._dist_context.block_state.parse_forward_blocks(serial_main_program)
+
         # serial backward pass
         params_grads = self._generate_backward(
             completed_main_program, serial_startup_program, serial_loss,
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index e789d82632e073544b7efaac96f397bb9df9276c..2f88407c093a534d1d67133aece636127ff29626 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -29,6 +29,9 @@ from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
 from .operators.common import BACKWARD_ONLY_DIST_OPS
 
 __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
+__not_shape_var_type__ = [
+    core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
+]
 
 
 class Partitioner(object):
@@ -75,8 +78,8 @@ class Partitioner(object):
 
         # init distop helper
         dist_op_context = self._dist_context.dist_op_context
-        dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping)
-        dist_op_context.set_rank_id(self._rank_id)
+        dist_op_context.varname_mapping = self._serial2dist_varname_mapping
+        dist_op_context.rank_id = self._rank_id
 
         # partition startup program
         if serial_startup_program == None:
@@ -84,7 +87,7 @@ class Partitioner(object):
         else:
             partitioned_startup_prog = self.partition_startup_program(
                 serial_main_program, serial_startup_program)
-        dist_op_context.set_dst_startup_program(partitioned_startup_prog)
+        dist_op_context.dst_startup_program = partitioned_startup_prog
 
         # partition main program
         partitioned_main_prog, partitioned_params_grads = self.partition_main_program(
@@ -157,15 +160,45 @@ class Partitioner(object):
         2. replace local op with corresponding dist op
         """
 
-        dist_op_context = self._dist_context.dist_op_context
         partitioned_main_prog = fluid.Program()
-        dist_op_context.set_dst_main_program(partitioned_main_prog)
-        target_block = partitioned_main_prog.global_block()
-        ref_block = serial_main_program.global_block()
-        serial_ops = serial_main_program.global_block().ops
+        dist_op_context = self._dist_context.dist_op_context
+        dist_op_context.dst_main_program = partitioned_main_prog
+
+        for idx in range(self._dist_context.block_state.nblock):
+            ref_block = serial_main_program.blocks[idx]
+
+            if idx == 0:
+                target_block = partitioned_main_prog.blocks[0]
+            else:
+                target_block = partitioned_main_prog._create_block(
+                    parent_idx=ref_block.parent_idx)
+                assert ref_block.idx == target_block.idx
+                target_block._set_forward_block_idx(ref_block.forward_block_idx)
+            dist_op_context.work_block = target_block
+            self.partition_block(ref_block, target_block)
+
+        partitioned_main_prog.current_block_idx = 0
+
+        partitioned_params_and_grads = []
+        for p, g in params_and_grads:
+            assert p.name in self._serial2dist_varname_mapping
+            dist_p = self._get_dist_var_by_serial_var(p, partitioned_main_prog)
+            if g is None:
+                dist_g = None
+            else:
+                assert g.name in self._serial2dist_varname_mapping
+                dist_g = self._get_dist_var_by_serial_var(g,
+                                                          partitioned_main_prog)
+            partitioned_params_and_grads.append((dist_p, dist_g))
+
+        return partitioned_main_prog, partitioned_params_and_grads
+
+    def partition_block(self, ref_block, target_block):
+
+        dist_op_context = self._dist_context.dist_op_context
+        serial_ops = ref_block.ops
 
         # init mapping
-        first_backward_op_idx = -1
         forward_op_id2forward_op = {}
         for idx in range(len(serial_ops)):
             if is_forward_op(serial_ops[idx]):
@@ -218,23 +251,6 @@ class Partitioner(object):
                     "partitioner only support forward op and backward op, but got {}".
                     format(str(op)))
 
-        partitioned_params_and_grads = []
-        for p, g in params_and_grads:
-            assert p.name in self._serial2dist_varname_mapping
-            dist_p_name = self._serial2dist_varname_mapping[p.name]
-            assert target_block.has_var(dist_p_name)
-            dist_p = target_block.var(dist_p_name)
-            if g is None:
-                dist_g = None
-            else:
-                assert g.name in self._serial2dist_varname_mapping
-                dist_g_name = self._serial2dist_varname_mapping[g.name]
-                assert target_block.has_var(dist_g_name)
-                dist_g = target_block.var(dist_g_name)
-            partitioned_params_and_grads.append((dist_p, dist_g))
-
-        return partitioned_main_prog, partitioned_params_and_grads
-
     def _is_valid_annotated_program(self, program):
 
         # TODO (ZJ-LIANG) should check all block
@@ -245,7 +261,7 @@ class Partitioner(object):
         ]
         var_dist_attrs = [
             self._dist_context.get_tensor_dist_attr_for_program(var)
-            for var in vars_
+            for var in vars_ if (var.type not in __not_shape_var_type__)
         ]
 
         all_ops_annotated = all(dist_attr is not None
@@ -255,6 +271,14 @@ class Partitioner(object):
 
         return all_ops_annotated and all_vars_annotated
 
+    def _get_dist_var_by_serial_var(self, serial_var, partitioned_main_prog):
+
+        block_idx = serial_var.block.idx
+        target_block = partitioned_main_prog.blocks[block_idx]
+        dist_var_name = self._serial2dist_varname_mapping[serial_var.name]
+        assert target_block.has_var(dist_var_name)
+        return target_block.var(dist_var_name)
+
 
 def _get_dist_shape(var, dist_attr):
@@ -341,7 +365,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
     """
     src_var = src_block.var(src_varname)
 
-    if src_var.type == core.VarDesc.VarType.READER:
+    if src_var.type in __not_shape_var_type__:
         dst_block.create_var(
             type=src_var.type,
             name=dst_varname,
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd8f8f3e7083d61bd4a30ca114d0ac2a099ba47
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py
@@ -0,0 +1,440 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import numpy as np
+import paddle.nn as nn
+import paddle.utils as utils
+import paddle.fluid as fluid
+import paddle.static as static
+import paddle.nn.functional as F
+import paddle.distributed.auto_parallel as auto
+
+from paddle.distributed import fleet
+
+from paddle.distributed.auto_parallel.partitioner import Partitioner
+from paddle.distributed.auto_parallel.utils import make_data_unshard
+from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
+from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
+from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
+
+paddle.enable_static()
+
+batch_size = 4
+epoch_num = 10
+hidden_size = 1024
+sequence_len = 512
+_g_process_mesh = auto.ProcessMesh([0, 1])
+
+
+def get_random_inputs_and_labels(input_shape, label_shape):
+    input = np.random.random(size=input_shape).astype('float32')
+    label = np.random.random(size=label_shape).astype('float32')
+    return input, label
+
+
+def batch_generator_creator():
+    def __reader__():
+        for _ in range(batch_size):
+            batch_input, batch_label = get_random_inputs_and_labels(
+                [batch_size, sequence_len, hidden_size],
+                [batch_size, sequence_len, 1])
+            yield batch_input, batch_label
+
+    return __reader__
+
+
+class MLPLayer(nn.Layer):
+    def __init__(self,
+                 hidden_size=1024,
+                 intermediate_size=4 * 1024,
+                 dropout_ratio=0.1,
+                 initializer_range=0.02):
+        super(MLPLayer, self).__init__()
+        d_model = hidden_size
+        dim_feedforward = intermediate_size
+        param_initializer = nn.initializer.Normal(
+            mean=0.0, std=initializer_range)
+
+        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
+        self.linear0 = nn.Linear(
+            d_model,
+            dim_feedforward,
+            weight_attr=paddle.ParamAttr(initializer=param_initializer),
+            bias_attr=None)
+        self.linear1 = nn.Linear(
+            dim_feedforward,
+            d_model,
+            weight_attr=paddle.ParamAttr(initializer=param_initializer),
+            bias_attr=None)
+
+    def forward(self, input):
+
+        auto.shard_tensor(
+            self.norm.weight,
+            dist_attr={"process_mesh": _g_process_mesh,
+                       "dims_mapping": [-1]})
+        auto.shard_tensor(
+            self.norm.bias,
+            dist_attr={"process_mesh": _g_process_mesh,
+                       "dims_mapping": [-1]})
+        auto.shard_tensor(
+            self.linear0.weight,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, 0]
+            })
+        auto.shard_tensor(
+            self.linear0.bias,
+            dist_attr={"process_mesh": _g_process_mesh,
+                       "dims_mapping": [0]})
+        auto.shard_tensor(
+            self.linear1.weight,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [0, -1]
+            })
+        auto.shard_tensor(
+            self.linear1.bias,
+            dist_attr={"process_mesh": _g_process_mesh,
+                       "dims_mapping": [-1]})
+
+        out = self.norm(input)
+        auto.shard_tensor(
+            out,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, -1]
+            })
+        out = self.linear0(out)
+        auto.shard_tensor(
+            out,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, 0]
+            })
+        out = F.gelu(out, approximate=True)
+        auto.shard_tensor(
+            out,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, 0]
+            })
+        out = self.linear1(out)
+        auto.shard_tensor(
+            out,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, -1]
+            })
+
+        return out
+
+
+def get_program():
+    dist_strategy = fleet.DistributedStrategy()
+    dist_strategy.semi_auto = True
+    # fleet.init(is_collective=True, strategy=dist_strategy)
+
+    train_program = static.Program()
+    start_program = static.Program()
+    with fluid.program_guard(train_program, start_program):
+
+        # loop counter
+        i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
+        auto.shard_tensor(
+            i,
+            dist_attr={"process_mesh": _g_process_mesh,
+                       "dims_mapping": [-1]})
+
+        # number of loop iterations
+        loop_len = fluid.layers.fill_constant(
+            shape=[1], dtype='int64', value=epoch_num)
+        auto.shard_tensor(
+            loop_len,
+            dist_attr={"process_mesh": _g_process_mesh,
+                       "dims_mapping": [-1]})
+
+        # input
+        input = static.data(
+            name="input",
+            shape=[batch_size, sequence_len, hidden_size],
+            dtype='float32')
+        label = static.data(
+            name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
+        data_holder = [input, label]
+        # dataloader
+        dataloader = paddle.io.DataLoader.from_generator(
+            feed_list=data_holder, capacity=4 * batch_size, iterable=False)
+        dataloader.set_batch_generator(
+            batch_generator_creator(), places=paddle.static.cuda_places())
+        # data dist_attr
+        auto.shard_tensor(
+            input,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, -1]
+            })
+        auto.shard_tensor(
+            label,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, -1]
+            })
+
+        mlp_start = MLPLayer(
+            hidden_size=hidden_size,
+            intermediate_size=4 * hidden_size,
+            dropout_ratio=0.1,
+            initializer_range=0.02)
+        pred = mlp_start(input)
+
+        input_array = fluid.layers.array_write(pred, i)
+        auto.shard_tensor(
+            input_array,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, -1]
+            })
+
+        cond = fluid.layers.less_than(x=i, y=loop_len)
+        auto.shard_tensor(
+            cond,
+            dist_attr={"process_mesh": _g_process_mesh,
+                       "dims_mapping": [-1]})
+
+        while_op = fluid.layers.While(cond=cond)
+        with while_op.block():
+
+            pre_input = fluid.layers.array_read(array=input_array, i=i)
+            auto.shard_tensor(
+                pre_input,
+                dist_attr={
+                    "process_mesh": _g_process_mesh,
+                    "dims_mapping": [-1, -1, -1]
+                })
+
+            mlp_while = MLPLayer(
+                hidden_size=hidden_size,
+                intermediate_size=4 * hidden_size,
+                dropout_ratio=0.1,
+                initializer_range=0.02)
+            cur_pred = mlp_while(pre_input)
+
+            # update the loop condition
+            i = fluid.layers.increment(x=i, value=1, in_place=True)
+            fluid.layers.array_write(cur_pred, array=input_array, i=i)
+            fluid.layers.less_than(x=i, y=loop_len, cond=cond)
+
+        end_pred = fluid.layers.array_read(array=input_array, i=i)
+        auto.shard_tensor(
+            end_pred,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, -1]
+            })
+
+        mlp_end = MLPLayer(
+            hidden_size=hidden_size,
+            intermediate_size=4 * hidden_size,
+            dropout_ratio=0.1,
+            initializer_range=0.02)
+        pred = mlp_end(end_pred)
+
+        error_cost = paddle.nn.functional.square_error_cost(pred, label)
+        auto.shard_tensor(
+            error_cost,
+            dist_attr={
+                "process_mesh": _g_process_mesh,
+                "dims_mapping": [-1, -1, -1]
+            })
+
+        loss = paddle.mean(error_cost)
+        auto.shard_tensor(
+            loss,
+            dist_attr={"process_mesh": _g_process_mesh,
+                       "dims_mapping": [-1]})
+
+    return train_program, start_program, dataloader, i, loss
+
+
+def completion(train_program, start_program, dist_context):
+    blocks = train_program.blocks
+    # completion tensors
+    for block in blocks:
+        for op in block.ops:
+            if op.type == "layer_norm":
+                for out_name in op.output_arg_names:
+                    out_var = block.vars[out_name]
+                    tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                        out_var)
+                    if tensor_dist_attr:
+                        continue
+                    tensor_dist_attr = TensorDistributedAttribute()
+                    tensor_dist_attr.process_mesh = _g_process_mesh
+                    tensor_dist_attr.dims_mapping = [-1]
+                    dist_context.set_tensor_dist_attr_for_program(
+                        out_var, tensor_dist_attr)
+
+            elif op.type == "elementwise_sub":
+                for out_name in op.output_arg_names:
+                    out_var = block.vars[out_name]
+                    tensor_dist_attr = TensorDistributedAttribute()
+                    tensor_dist_attr.process_mesh = _g_process_mesh
+                    tensor_dist_attr.dims_mapping = [-1, -1, -1]
+                    dist_context.set_tensor_dist_attr_for_program(
+                        out_var, tensor_dist_attr)
+
+            elif op.type == "matmul_v2":
+                col = False
+                for in_name in op.input_arg_names:
+                    if ".w_" not in in_name:
+                        continue
+                    if in_name not in block.vars:
+                        in_var = blocks[0].vars[in_name]
+                    else:
+                        in_var = block.vars[in_name]
+                    tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                        in_var)
+                    assert tensor_dist_attr is not None
+                    if tensor_dist_attr.dims_mapping == [-1, 0]:
+                        col = True
+                for out_name in op.output_arg_names:
+                    out_var = block.vars[out_name]
+                    tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                        out_var)
+                    if tensor_dist_attr:
+                        continue
+                    tensor_dist_attr = TensorDistributedAttribute()
+                    tensor_dist_attr.process_mesh = _g_process_mesh
+                    if col:
+                        tensor_dist_attr.dims_mapping = [-1, -1, 0]
+                    else:
+                        tensor_dist_attr.dims_mapping = [-1, -1, -1]
+                    dist_context.set_tensor_dist_attr_for_program(
+                        out_var, tensor_dist_attr)
+            elif op.type == "while":
+                out_name = op.desc.output("StepScopes")[0]
+                out_var = block.vars[out_name]
+                tensor_dist_attr = TensorDistributedAttribute()
+                tensor_dist_attr.process_mesh = _g_process_mesh
+                tensor_dist_attr.dims_mapping = [-1]
+                dist_context.set_tensor_dist_attr_for_program(out_var,
+                                                              tensor_dist_attr)
+
+    # completion ops
+    for block in blocks:
+        for op in block.ops:
+            op_dist_attr = OperatorDistributedAttribute()
+            op_dist_attr.process_mesh = _g_process_mesh
+            if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
+                for in_name in op.input_arg_names:
+                    op_dist_attr.set_input_dims_mapping(in_name, [])
+                for out_name in op.output_arg_names:
+                    op_dist_attr.set_output_dims_mapping(out_name, [])
+            elif op.type == "read":
+                for in_name in op.input_arg_names:
+                    op_dist_attr.set_output_dims_mapping(in_name, [])
+                for out_name in op.output_arg_names:
+                    out_var = block.vars[out_name]
+                    out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                        out_var)
+                    op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
+            elif op.type == "while":
+                for in_name in op.input_arg_names:
+                    in_var = block.vars[in_name]
+                    in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                        in_var)
+                    op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
+                for out_name in op.output_arg_names:
+                    if out_name == op.desc.output("StepScopes")[0]:
+                        op_dist_attr.set_output_dims_mapping(out_name, [])
+                    else:
+                        out_var = block.vars[out_name]
+                        out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                            out_var)
+                        op_dist_attr.set_output_dist_attr(out_name,
+                                                          out_dist_attr)
+            else:
+                for in_name in op.input_arg_names:
+                    if in_name == "lod_tensor_blocking_queue_0":
+                        continue
+                    if in_name not in block.vars:
+                        in_var = blocks[0].vars[in_name]
+                    else:
+                        in_var = block.vars[in_name]
+                    in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                        in_var)
+                    op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
+                for out_name in op.output_arg_names:
+                    if out_name not in block.vars:
+                        out_var = blocks[0].vars[out_name]
+                    else:
+                        out_var = block.vars[out_name]
+                    out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                        out_var)
+                    op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
+
+            if op.type == "matmul_v2":
+                op_dist_attr.impl_type = "matmul_v2"
+                for in_name in op_dist_attr.inputs_dist_attrs.keys():
+                    in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
+                    if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
+                        op_dist_attr.impl_idx = 0
+                    else:
+                        op_dist_attr.impl_idx = 1
+            else:
+                op_dist_attr.impl_type = "default"
+                op_dist_attr.impl_idx = 0
+
+            dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
+    make_data_unshard(train_program, start_program, dist_context)
+
+    return train_program, start_program
+
+
+def partition(train_program, start_program, dist_context):
+
+    # optimizer = paddle.optimizer.SGD(learning_rate=0.00001)
+    rank = paddle.distributed.get_rank()
+    partitioner = Partitioner(dist_context, rank)
+    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
+        train_program, start_program, [])
+
+    return dist_main_prog, dist_startup_prog
+
+
+class TestMLP(unittest.TestCase):
+    def test_partitioner(self):
+
+        train_program, start_program, dataloader, i, loss = get_program()
+        dist_context = get_default_distributed_context()
+        train_program, start_program = completion(train_program, start_program,
+                                                  dist_context)
+        dist_context.block_state.parse_forward_blocks(train_program)
+
+        dist_main_prog, dist_startup_prog = partition(
+            train_program, start_program, dist_context)
+        global_block_ops = dist_main_prog.blocks[0].ops
+        global_block_ops = [op.type for op in global_block_ops]
+        sub_block_ops = dist_main_prog.blocks[1].ops
+        sub_block_ops = [op.type for op in sub_block_ops]
+
+        self.assertTrue("c_allreduce_sum" in global_block_ops)
+        self.assertTrue("c_allreduce_sum" in sub_block_ops)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
index 52397f51321f585784b52c4a39bd707cf97f7dc4..96ab0aecb75850de51e58e6d6a26271e54f800b4 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
@@ -158,6 +158,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
 
     params_grads = parallelizer._generate_backward(
         complete_train_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py
index 27de9f325063b05e8dd17b79c501d944a6e42d2b..29575dc76c2a1c6bdcdca4a42671f84196fe0a89 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py
@@ -47,9 +47,7 @@ def get_dist_prog(train_program,
     complete_train_program = completer.complete_forward_annotation(
         train_program
     ) if complete_train_program is None else complete_train_program
-
-    # parallelizer._apply_serial_forward_pass(complete_train_program,
-    #                                         startup_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
 
     params_grads = parallelizer._generate_backward(
         complete_train_program,
@@ -95,9 +93,9 @@ class TestDistributedTensor(unittest.TestCase):
         rank_id = 1
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
-        dist_main_prog, dist_startup_prog, _ = get_dist_prog(
-            train_program, startup_program, dist_context, rank_id,
-            complete_train_program)
+        dist_context = DistributedContext()
+        dist_main_prog, dist_startup_prog, complete_train_program = get_dist_prog(
+            train_program, startup_program, dist_context, rank_id, None)
         dist_context.dist_main_programs[rank_id] = dist_main_prog
         dist_context.dist_startup_programs[rank_id] = dist_startup_prog
         name = "layer_norm_1.tmp_2"
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
index 8869fd6a59e3772507aa6413afd7c872bab7a533..36a34815b681aa2de543061d62ea12493830d714 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
@@ -486,7 +486,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
-
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py
index deff2144411fccbff90a22f6639bc252da866d82..ef8780a020f33bb056fd2d596538fe44a5600492 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py
@@ -53,6 +53,7 @@ def get_programs(annotated_func):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
 
     rank_id = 3
     dist_strategy = fleet.DistributedStrategy()
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
index 01e62d886e2b7cb9fd7f71bae3b775e0698265ab..d0bed73f1b8c4e0585b096e3e1a21d49aee5a698 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
@@ -885,6 +885,7 @@ class TestGPTPartitioner(unittest.TestCase):
         completer = Completer(dist_context)
         complete_train_program = completer.complete_forward_annotation(
             train_program)
+        dist_context.block_state.parse_forward_blocks(complete_train_program)
 
         # serial backward pass
         params_grads = parallelizer._generate_backward(
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
index 1d8938785924cfadfdb232aeeb42b7af045af09a..1278ed68d959e4f076fec2f6077c47437a12c300 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
@@ -160,7 +160,7 @@ def get_dist_prog(train_program,
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
-
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
     if change_process_mesh:
         global PP_MESH_1
         dist_context.get_tensor_dist_attr_for_program(
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
index 5a79d1f9514ab2c8ce1f6de7956653df463a1f9d..e84cb68f437caa848e43921fda19ccc4b722a821 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
@@ -120,7 +120,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
-
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
index 6696a9d3006d2bdec61b14fc49a639060d5fa4cd..0636c083e54e00c6386fbbf7a4d93da222219287 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
@@ -136,7 +136,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
-
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
@@ -269,6 +269,7 @@ class TestMLPReshard(unittest.TestCase):
         completer = Completer(dist_context)
         complete_train_program = completer.complete_forward_annotation(
             train_program)
+        dist_context.block_state.parse_forward_blocks(complete_train_program)
         partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
             complete_train_program, startup_program, [])
         reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
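
The new `BlockState` API is order-sensitive: forward blocks must be recorded before `append_backward` creates backward sub-blocks (since `parse_forward_blocks` asserts `forward_block_idx == -1` for every block it visits), and `parse_backward_blocks` must run afterwards to map each backward block to its forward counterpart. Below is a minimal sketch of the end-to-end flow this patch enables, modeled on `AutoParallelizer.parallelize` and the new unit test; the helper name `build_dist_program` and the exact argument plumbing are illustrative only, not part of the patch:

```python
import paddle
from paddle.fluid.backward import append_backward
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.dist_context import DistributedContext


def build_dist_program(serial_main_prog, serial_startup_prog, loss, rank_id):
    """Sketch only: block-aware completion -> backward -> partition."""
    dist_context = DistributedContext()

    # 1. Forward annotation completion, then record the forward block
    #    indices (block 0 plus any while/cond sub-blocks).
    completer = Completer(dist_context)
    main_prog = completer.complete_forward_annotation(serial_main_prog)
    dist_context.block_state.parse_forward_blocks(main_prog)

    # 2. Serial backward pass; append_backward creates backward sub-blocks
    #    for the control-flow blocks. Then map each backward block back to
    #    the forward block it mirrors.
    loss_var = main_prog.global_block().var(loss.name)
    with paddle.static.program_guard(main_prog, serial_startup_prog):
        params_grads = append_backward(
            loss_var, distop_context=dist_context.dist_op_context)
    completer.complete_backward_annotation(main_prog)
    dist_context.block_state.parse_backward_blocks(main_prog)

    # 3. Partition block by block: partition_main_program now walks
    #    block_state.nblock blocks and mirrors each one (including its
    #    parent_idx / forward_block_idx links) into the dist program.
    partitioner = Partitioner(dist_context, rank_id)
    return partitioner.partition(main_prog, serial_startup_prog, params_grads)
```

This mirrors the ordering enforced in `parallelizer.py` above: `parse_forward_blocks` runs right after forward completion, and `parse_backward_blocks` at the end of `_generate_backward`, so that by the time `Partitioner.partition` runs, every serial block has a known forward/backward role.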