diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index e303ce1216822b26bb58813c37239ae3e3fec043..408a1fdaafeefefb5065de53093da0da7a92587c 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -648,6 +648,9 @@ class Completer: self._dist_context.copy_dist_attr_from_graph_to_program() self._dist_context.clear_dist_info_for_graph() + # NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient + self.complete_high_order_grad_annotation(serial_main_program) + # Do the validation check and amend some completion self._dist_context.amend_dist_attr_for_program() @@ -655,6 +658,164 @@ class Completer: return serial_main_program + def complete_high_order_grad_annotation(self, serial_main_program): + """ + NOTE: + [HighOrderGrad] Complete the annotation of vars and ops only for high order gradient. + This function is temporary to support high order gradient, and will be removed in the future. + """ + + def _is_grad_var_name(name): + if "@GRAD" in name: + return True + return False + + def _get_op_by_id(ops, id): + for op in ops: + if op.desc.id() == id: + return op + return None + + ops = list(serial_main_program.global_block().ops) + vars = serial_main_program.global_block().vars + dist_op_context = self._dist_context.dist_op_context + grad_var_to_var = dist_op_context.grad_var_to_var + + appended_grad_times = 0 + for idx in range(0, len(ops)): + op = ops[idx] + if int(op.attr('op_role')) == int( + core.op_proto_and_checker_maker.OpRole.Forward): + continue + + if int(op.attr('op_role')) == int( + core.op_proto_and_checker_maker.OpRole.Backward) and int( + ops[idx - 1].attr('op_role')) == int( + core.op_proto_and_checker_maker.OpRole.Forward): + appended_grad_times += 1 + + # complete the annotation of grad op (xxx_grad op or sum op) + # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id + grad_op = ops[idx] + if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: + # TODO support the case where one forward op corresponding to multiple xxx_grad op + forward_op = _get_op_by_id( + ops, dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]) + assert forward_op is not None + + fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( + forward_op) + fwd_op_process_mesh = fwd_op_dist_attr.process_mesh + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = fwd_op_process_mesh + + for input_name in grad_op.input_arg_names: + if input_name not in forward_op.input_arg_names and input_name not in forward_op.output_arg_names: + if input_name in grad_var_to_var[appended_grad_times]: + fwd_name = grad_var_to_var[appended_grad_times][ + input_name] + ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping( + fwd_name) + else: + input_var = vars[input_name] + ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( + input_var).dims_mapping + else: + if fwd_op_dist_attr.get_input_dims_mapping(input_name): + ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( + input_name) + else: + ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping( + input_name) + assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( + input_name) + grad_op_dist_attr.set_input_dims_mapping(input_name, + ref_dims_mapping) + + for output_name in grad_op.output_arg_names: + assert output_name in grad_var_to_var[appended_grad_times] + fwd_name = grad_var_to_var[appended_grad_times][output_name] + ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( + fwd_name) + # var + output_var = vars[output_name] + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_dims_mapping + tensor_dist_attr.process_mesh = fwd_op_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + output_var, tensor_dist_attr) + # op + grad_op_dist_attr.set_output_dims_mapping(output_name, + ref_dims_mapping) + + self._dist_context.set_op_dist_attr_for_program( + grad_op, grad_op_dist_attr) + + # grad ops that have not a corresponding mapping in grad_op_id_to_op_id + else: + + if grad_op.type == 'sum': + assert all(map(_is_grad_var_name, grad_op.input_arg_names)) + output_name = grad_op.output_arg_names[0] + assert output_name in grad_var_to_var[appended_grad_times], \ + "sum op's output '{}' has no corresponding var".format( + output_name) + ref_fwd_var_name = grad_var_to_var[appended_grad_times][ + output_name] + ref_fwd_var = vars[ref_fwd_var_name] + ref_fwd_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + ref_fwd_var) + ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping + ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh + # output + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping + tensor_dist_attr.process_mesh = ref_fwd_process_mesh + output_var = vars[output_name] + self._dist_context.set_tensor_dist_attr_for_program( + output_var, tensor_dist_attr) + # op + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = ref_fwd_process_mesh + for var_name in grad_op.input_arg_names: + grad_op_dist_attr.set_input_dims_mapping( + var_name, ref_fwd_dims_mapping) + grad_op_dist_attr.set_output_dims_mapping( + output_name, ref_fwd_dims_mapping) + + elif grad_op.type == 'fill_zeros_like': + ref_var_name = grad_op.input_arg_names[0] + ref_var = vars[ref_var_name] + ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + ref_var) + ref_dims_mapping = ref_dist_attr.dims_mapping + ref_process_mesh = ref_dist_attr.process_mesh + # output + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_dims_mapping + tensor_dist_attr.process_mesh = ref_process_mesh + output_var_name = grad_op.output_arg_names[0] + output_var = vars[output_var_name] + self._dist_context.set_tensor_dist_attr_for_program( + output_var, tensor_dist_attr) + # op + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = ref_process_mesh + grad_op_dist_attr.set_input_dims_mapping(ref_var_name, + ref_dims_mapping) + grad_op_dist_attr.set_output_dims_mapping(output_var_name, + ref_dims_mapping) + + elif grad_op.type in ['shape', 'fill_constant']: + continue + + else: + raise ValueError("got unexpect op [{}]".format( + str(grad_op.type))) + + self._dist_context.set_op_dist_attr_for_program( + grad_op, grad_op_dist_attr) + def complete_backward_annotation(self, serial_main_program): """Complete the annotation of vars and ops in the backward phase for parallel program.""" @@ -689,6 +850,8 @@ class Completer: ops = list(serial_main_program.global_block().ops) vars = serial_main_program.global_block().vars dist_op_context = self._dist_context.dist_op_context + grad_var_to_var = dist_op_context.grad_var_to_var[len( + dist_op_context.grad_var_to_var)] for idx in range(first_backward_op_idx, len(ops)): @@ -765,102 +928,111 @@ class Completer: grad_op, grad_op_dist_attr) continue - # op dist attr - forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( + fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( forward_op) - forward_op_process_mesh = forward_op_dist_attr.process_mesh + fwd_op_process_mesh = fwd_op_dist_attr.process_mesh grad_op_dist_attr = OperatorDistributedAttribute() - grad_op_dist_attr.process_mesh = forward_op_process_mesh + grad_op_dist_attr.process_mesh = fwd_op_process_mesh - # var for input_name in grad_op.input_arg_names: - input_var = vars[input_name] - ref_dims_mapping = None - if "@GRAD" in input_name: - forward_name = _get_forward_varname_from_grad_varname( - input_name) - ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( - forward_name) + if input_name not in forward_op.input_arg_names and input_name not in forward_op.output_arg_names: + if input_name in grad_var_to_var: + fwd_name = grad_var_to_var[input_name] + ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping( + fwd_name) + else: + input_var = vars[input_name] + ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( + input_var).dims_mapping else: - if forward_op_dist_attr.get_input_dims_mapping( - input_name): - ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + if fwd_op_dist_attr.get_input_dims_mapping(input_name): + ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( input_name) else: - ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( + ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping( input_name) - assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( - input_var.name) + input_name) grad_op_dist_attr.set_input_dims_mapping(input_name, ref_dims_mapping) - for output_name in grad_op.desc.output_names(): - assert len(grad_op.desc.output(output_name)) in [0, 1] - if _is_grad_var_name(output_name): - input_name = _get_forward_varname_from_grad_varname( - output_name) - else: - assert grad_op.type in [ - "cast", "c_identity", "c_allreduce_sum" - ] - input_name = "X" - assert input_name in forward_op.desc.input_names( - ), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format( - output_name, grad_op.type, input_name) - if len(grad_op.desc.output(output_name)) == 1: - # tensor dist attr - output_var = vars[grad_op.desc.output(output_name)[0]] - forward_name = _get_forward_varname_from_grad_varname( - output_var.name) - ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( - forward_name) - - output_var_dist_attr = TensorDistributedAttribute() - output_var_dist_attr.dims_mapping = ref_dims_mapping - output_var_dist_attr.process_mesh = forward_op_process_mesh - self._dist_context.set_tensor_dist_attr_for_program( - output_var, output_var_dist_attr) - - grad_op_dist_attr.set_output_dims_mapping( - output_var.name, ref_dims_mapping) + for output_name in grad_op.output_arg_names: + assert output_name in grad_var_to_var + fwd_name = grad_var_to_var[output_name] + ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( + fwd_name) + # var + output_var = vars[output_name] + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_dims_mapping + tensor_dist_attr.process_mesh = fwd_op_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + output_var, tensor_dist_attr) + # op + grad_op_dist_attr.set_output_dims_mapping(output_name, + ref_dims_mapping) self._dist_context.set_op_dist_attr_for_program( grad_op, grad_op_dist_attr) - # only sum op for merge mutiple version grad has no a corresponding mapping in grad_op_id_to_op_id + # grad ops that have not a corresponding mapping in grad_op_id_to_op_id else: - assert grad_op.type == "sum", "got unexpect op [{}]".format( - str(grad_op.type)) - assert all(map(_is_grad_var_name, grad_op.input_arg_names)) - assert len(grad_op.output_arg_names) == 1 - - ref_forward_var_name = _get_forward_varname_from_grad_varname( - grad_op.output_arg_names[0]) - forward_var = vars[ref_forward_var_name] - ref_forward_var_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( - forward_var).dims_mapping - ref_forward_var_process_mesh = self._dist_context.get_tensor_dist_attr_for_program( - forward_var).process_mesh + if grad_op.type == 'sum': + assert all(map(_is_grad_var_name, grad_op.input_arg_names)) + output_name = grad_op.output_arg_names[0] + assert output_name in grad_var_to_var, "sum op's output '{}' has no corresponding var".format( + output_name) + ref_fwd_var_name = grad_var_to_var[output_name] + ref_fwd_var = vars[ref_fwd_var_name] + ref_fwd_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + ref_fwd_var) + ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping + ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh + + # output + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping + tensor_dist_attr.process_mesh = ref_fwd_process_mesh + output_var = vars[output_name] + self._dist_context.set_tensor_dist_attr_for_program( + output_var, tensor_dist_attr) - # output - tensor_dist_attr = TensorDistributedAttribute() - tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping - tensor_dist_attr.process_mesh = ref_forward_var_process_mesh - self._dist_context.set_tensor_dist_attr_for_program( - vars[grad_op.output_arg_names[0]], tensor_dist_attr) + # op + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = ref_fwd_process_mesh + for var_name in grad_op.input_arg_names: + grad_op_dist_attr.set_input_dims_mapping( + var_name, ref_fwd_dims_mapping) + grad_op_dist_attr.set_output_dims_mapping( + output_name, ref_fwd_dims_mapping) + + elif grad_op.type == 'fill_zeros_like': + ref_var_name = grad_op.input_arg_names[0] + ref_var = vars[ref_var_name] + ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + ref_var) + ref_dims_mapping = ref_dist_attr.dims_mapping + ref_process_mesh = ref_dist_attr.process_mesh + # output + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_dims_mapping + tensor_dist_attr.process_mesh = ref_process_mesh + output_var_name = grad_op.output_arg_names[0] + output_var = vars[output_var_name] + self._dist_context.set_tensor_dist_attr_for_program( + output_var, tensor_dist_attr) + # op + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = ref_process_mesh + grad_op_dist_attr.set_input_dims_mapping(ref_var_name, + ref_dims_mapping) + grad_op_dist_attr.set_output_dims_mapping(output_var_name, + ref_dims_mapping) + + else: + raise ValueError("got unexpect op [{}]".format( + str(grad_op.type))) - # op - grad_op_dist_attr = OperatorDistributedAttribute() - grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh - for var_name in grad_op.input_arg_names: - assert _get_forward_varname_from_grad_varname( - var_name) == ref_forward_var_name - grad_op_dist_attr.set_input_dims_mapping( - var_name, ref_forward_var_dims_mapping) - - grad_op_dist_attr.set_output_dims_mapping( - grad_op.output_arg_names[0], ref_forward_var_dims_mapping) self._dist_context.set_op_dist_attr_for_program( grad_op, grad_op_dist_attr) diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 2807c46540ab1e52f7490c850faa34eac00c04db..7e245358d4bccaad4b6ffeb0648350459d6212e9 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -120,6 +120,11 @@ class DistributedContext: def dist_startup_programs(self): return self._dist_startup_programs + @property + def is_annotation(self): + return len(self._dist_tensors_for_program) or len( + self._dist_ops_for_program) + def add_process_mesh(self, process_mesh): assert isinstance(process_mesh, ProcessMesh), \ 'The type of dim_mapping must be ProcessMesh.' @@ -577,6 +582,7 @@ class DistributedOperatorContext: self._cur_src_op = None self._cur_dist_attr = None self.grad_op_id_to_op_id = {} + self.grad_var_to_var = defaultdict(dict) self._work_block = None self.already_init_sync_vars = set() self.varname_mapping = None diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py index 9449b52952cd844439d4a3254820cb9ca80a5a8a..cc08bc1a901b78e2a487aae86eaffff710c2ac95 100644 --- a/python/paddle/distributed/auto_parallel/dist_loader.py +++ b/python/paddle/distributed/auto_parallel/dist_loader.py @@ -16,6 +16,7 @@ import abc import numpy as np import paddle from .utils import to_list +from paddle.fluid.layers.utils import flatten from paddle.io import DataLoader, DistributedBatchSampler @@ -56,16 +57,17 @@ class NonIterableGeneratorLoader(DistributedDataLoader): data_parallel_world_size=None, data_parallel_rank=None, drop_last=False, - inputs=[]): + sample_generator=True): self.feed_list = feed_list self.places = places self.steps_per_epoch = steps_per_epoch + self._sample_generator = sample_generator + super(NonIterableGeneratorLoader, self).__init__( dataset, batch_size, epochs, data_parallel_world_size, data_parallel_rank, drop_last) self._inner_dataloader = self._create_inner_dataloader() self._steps = self._infer_steps() - self._inputs = inputs def __iter__(self): self._cur_step = 0 @@ -91,27 +93,28 @@ class NonIterableGeneratorLoader(DistributedDataLoader): return steps_per_epoch def _create_inner_dataloader(self): - def data_generator(): + def sample_data_generator(): batch_data = None for step, data in enumerate(self.dataset): - if not isinstance(data, list): - data = to_list(data) - - if self.batch_size == 1: - yield data + data = flatten(data) + if batch_data is None: + batch_data = [[] for i in range(len(data))] + for idx in range(len(data)): + batch_data[idx].append(data[idx]) + if (step + 1) % self.batch_size == 0: + yield batch_data batch_data = None - else: - if batch_data is None: - batch_data = [[] for i in range(len(data))] - - for idx in range(len(data)): - batch_data[idx].append(data[idx]) - if (step + 1) % self.batch_size == 0: - yield batch_data - batch_data = None + def batch_data_generator(): + for data in self.dataset: + data = flatten(data) + yield data dataloader = paddle.fluid.io.DataLoader.from_generator( feed_list=self.feed_list, capacity=70, iterable=False) - dataloader.set_batch_generator(data_generator, self.places) + if self._sample_generator: + dataloader.set_batch_generator(sample_data_generator, self.places) + else: + dataloader.set_batch_generator(batch_data_generator, self.places) + return dataloader diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index a5fec789dfb37177b81a7bd04c973f8ffea39865..2cd841ef80979bb89b90460fbb106f464d74145f 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -17,18 +17,22 @@ import logging from collections import defaultdict import paddle +import paddle.distributed.auto_parallel as auto + from paddle import fluid from paddle.io import Dataset from paddle.metric import Metric from paddle.static import InputSpec from paddle.fluid import core from paddle.fluid import program_guard +from paddle.fluid.layers.utils import flatten +from paddle.fluid.executor import global_scope from paddle.fluid.backward import append_backward from paddle.fluid.framework import Operator from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.dygraph.parallel import ParallelEnv -from paddle.distributed.passes import new_pass, PassContext from paddle.distributed.utils import get_logger +from paddle.distributed.passes import new_pass, PassContext from .mapper import mapping from .cluster import Cluster @@ -61,6 +65,12 @@ class Engine: self.strategy = strategy self._executor = None + self._cur_rank = paddle.distributed.get_rank() + self._nranks = paddle.distributed.get_world_size() + self._saver = DistributedSaver() + self._logger = get_logger(logging.INFO) + + self._default_strategy = None self._orig_main_prog = fluid.default_main_program() self._orig_startup_prog = fluid.default_startup_program() self._orig_dist_context = get_default_distributed_context() @@ -70,9 +80,6 @@ class Engine: self._dist_startup_progs = defaultdict(dict) # dist startup programs self._dist_contexts = {} self._pass_contexts = {} - self._cur_rank = paddle.distributed.get_rank() - self._logger = get_logger(logging.INFO) - self._saver = DistributedSaver() self._feed_vars = {} self._fetch_vars = {} @@ -86,13 +93,11 @@ class Engine: # TODO: check loss type self._loss = loss self._metrics = to_list(metrics) - for m in ['train', 'predict']: - self.mode = m - self._build(m) # build forward program - self._plan(m) # completion & planner - self._parallel(m, all_ranks) # parallel - self._initialize(m) # init comm and startup program - self.mode = mode + self._mode = mode + self._build(mode) # build forward program + self._plan(mode) # completion & planner + self._parallel(mode, all_ranks) # parallel + self._initialize(mode) # init comm and startup program def _build(self, mode): serial_main_prog = self._serial_main_progs.get(mode, None) @@ -112,10 +117,16 @@ class Engine: if mode != "predict" and self._loss: losses = to_list(self._loss(*(outputs + labels))) + default_ctx = get_default_distributed_context() + if not default_ctx.is_annotation or self._default_strategy: + inputs = [self._set_data_parallel(var) for var in inputs] + labels = [self._set_data_parallel(var) for var in labels] + + # print(serial_main_prog) self._feed_vars[mode] = {"inputs": inputs, "labels": labels} self._fetch_vars[mode] = { - "outputs": outputs, + "outputs": flatten(outputs), "loss": losses, "metrics": metrics } @@ -128,6 +139,12 @@ class Engine: self._pass_contexts[mode] = PassContext() def _plan(self, mode): + + # NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need + # dependency of backward-forward ops in forward completition. + defualt_ctx = get_default_distributed_context() + self._dist_contexts[mode]._dist_op_context = defualt_ctx.dist_op_context + # Complete the distributed annotation serial_main_prog = self._serial_main_progs[mode] self._completer = Completer(self._dist_contexts[mode]) @@ -147,13 +164,14 @@ class Engine: self._parallel_program(mode, rank) def _initialize(self, mode): - # Traverse different rank programs and traverse each op of them, - # instantiate communication by process_mapping. - all_process_groups = get_all_process_groups() - for process_group in all_process_groups: - if self._cur_rank not in process_group.ranks: - continue - process_group.instantiate() + if self._nranks > 1: + # Traverse different rank programs and traverse each op of them, + # instantiate communication by process_mapping. + all_process_groups = get_all_process_groups() + for process_group in all_process_groups: + if self._cur_rank not in process_group.ranks: + continue + process_group.instantiate() # initialize self._place = _get_device() @@ -161,8 +179,16 @@ class Engine: self._place = fluid.CUDAPlace(ParallelEnv().dev_id) if self._executor is None: self._executor = paddle.static.Executor(self._place) - dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] - self._executor.run(dist_startup_prog) + uninitialized = [] + dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] + for var in dist_startup_prog.list_vars(): + scope_var = global_scope().find_var(var.name) + if scope_var and scope_var.get_tensor()._is_initialized(): + continue + uninitialized.append(var) + if uninitialized: + prune_startup_prog = dist_startup_prog._prune(uninitialized) + self._executor.run(prune_startup_prog) def _parallel_program(self, mode, rank): serial_main_program = self._serial_main_progs[mode] @@ -246,12 +272,13 @@ class Engine: if config["use_pure_fp16"]: config["base_opt"] = self._optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) - auto_parallel_fp16_pass.apply( - [main_program], [startup_program], self._pass_context) + auto_parallel_fp16_pass.apply([main_program], + [startup_program], + self._pass_contexts[self.mode]) else: auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass.apply([main_program], [startup_program], - self._pass_context) + self._pass_contexts[self.mode]) # apply recompute pass if self.strategy.recompute: @@ -288,18 +315,26 @@ class Engine: [main_program], [startup_program], self._pass_contexts[self.mode]) - def fit(self, train_data, batch_size=1, epochs=1, steps_per_epoch=None): + def fit(self, + train_data, + batch_size=1, + epochs=1, + steps_per_epoch=None, + use_program_cache=False, + return_numpy=True, + sample_generator=True): # TODO: callbacks # TODO: evaluate after training self.mode = 'train' - assert isinstance(train_data, Dataset) - train_dataloader = self._create_dataloader(train_data, batch_size, - epochs, steps_per_epoch) + assert self.mode in self._dist_main_progs, "train model is not ready, please call `engine.prepare(mode='train')` first." + train_dataloader = self._create_dataloader( + train_data, batch_size, epochs, steps_per_epoch, sample_generator) outputs = [] for epoch in range(epochs): for step, data in enumerate(train_dataloader): - logs, loss = self._train_step(data) + logs, loss = self._train_step(data, use_program_cache, + return_numpy) outputs.append(loss) train_logs = { "train_" + name: val @@ -308,14 +343,35 @@ class Engine: self._logger.info(train_logs) return outputs + def evaluate(self, + eval_data, + batch_size=1, + use_program_cache=False, + return_numpy=True, + sample_generator=True): + self.mode = 'eval' + assert self.mode in self._dist_main_progs, "eval model is not ready, please call `engine.prepare(mode='eval')` first." + eval_dataloader = self._create_dataloader( + eval_data, batch_size, sample_generator=sample_generator) + + outputs = [] + for step, data in enumerate(eval_dataloader): + logs, outs = self._eval_step(data, use_program_cache, return_numpy) + outputs.append(outs) + predict_logs = {"eval_" + name: val for name, val in logs.items()} + self._logger.info(predict_logs) + return outputs + def predict(self, test_data, batch_size=1, use_program_cache=False, - return_numpy=True): + return_numpy=True, + sample_generator=True): self.mode = 'predict' - # TODO: need check dataset - test_dataloader = self._create_dataloader(test_data, batch_size) + assert self.mode in self._dist_main_progs, "predict model is not ready, please call `engine.prepare(mode='predict')` first." + test_dataloader = self._create_dataloader( + test_data, batch_size, sample_generator=sample_generator) outputs = [] for step, data in enumerate(test_dataloader): @@ -329,19 +385,39 @@ class Engine: self._logger.info(predict_logs) return outputs - def _train_step(self, data): + def _train_step(self, data, use_program_cache=False, return_numpy=True): logs = {} dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] fetch_var = self._fetch_vars[self.mode]["loss"][0] if fetch_var.name not in dist_main_prog.global_block().vars: - loss = self._executor.run(dist_main_prog) + loss = self._executor.run(dist_main_prog, + use_program_cache=use_program_cache) logs["loss"] = None else: loss = self._executor.run(dist_main_prog, - fetch_list=to_list(fetch_var)) + fetch_list=to_list(fetch_var), + use_program_cache=use_program_cache, + return_numpy=return_numpy) logs["loss"] = loss return logs, loss + def _eval_step(self, data, use_program_cache=False, return_numpy=True): + logs = {} + dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] + fetch_var = self._fetch_vars[self.mode]["loss"][0] + + if fetch_var.name not in dist_main_prog.global_block().vars: + outs = self._executor.run(dist_main_prog, + use_program_cache=use_program_cache) + logs["loss"] = outs + else: + outs = self._executor.run(dist_main_prog, + fetch_list=fetch_var, + use_program_cache=use_program_cache, + return_numpy=return_numpy) + logs["loss"] = outs + return logs, outs + def _predict_step(self, data, use_program_cache=False, return_numpy=True): logs = {} dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] @@ -366,7 +442,8 @@ class Engine: dataset, batch_size, epochs=1, - steps_per_epoch=None): + steps_per_epoch=None, + sample_generator=True): feed_list = self._feed_vars[self.mode]["inputs"] + self._feed_vars[ self.mode]["labels"] dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] @@ -376,9 +453,12 @@ class Engine: serial_main_prog = self._serial_main_progs[self.mode] serial_main_block = serial_main_prog.global_block() op_size = len(dist_main_block.ops) + if dist_main_block.ops[0].type == 'create_py_reader': + op_size -= 3 + for _ in range(3): + dist_main_block._remove_op(0, sync=False) places = paddle.static.cuda_places() with fluid.program_guard(dist_main_prog, dist_startup_prog): - inputs = self._feed_vars[self.mode]["inputs"] dataloader = NonIterableGeneratorLoader( dataset, feed_list, @@ -386,7 +466,7 @@ class Engine: batch_size, epochs, steps_per_epoch, - inputs=inputs) + sample_generator=sample_generator) new_op_size = len(dist_main_block.ops) for _ in range(new_op_size - 1, op_size - 1, -1): op = dist_main_block.ops[new_op_size - 1] @@ -396,7 +476,7 @@ class Engine: dist_main_block, new_op_desc, type=new_op_desc.type()) dist_main_block.ops.insert(0, new_op) for in_name in new_op.input_arg_names: - if in_name == "lod_tensor_blocking_queue_0": + if "lod_tensor_blocking_queue" in in_name: continue if in_name not in dist_main_block.vars: in_var = serial_main_block._var_recursive(in_name) @@ -424,6 +504,27 @@ class Engine: .format(i, spec)) return specs + def _set_data_parallel(self, var): + if self._nranks == 1: + self._default_strategy = 'serial' + auto.shard_tensor( + var, + dist_attr={ + "process_mesh": [0], + "dims_mapping": [-1 for _ in range(len(var.shape))] + }) + else: + self._default_strategy = 'dp' + auto.shard_tensor( + var, + dist_attr={ + "process_mesh": list(range(self._nranks)), + "dims_mapping": + [0] + [-1 for _ in range(len(var.shape) - 1)] + }) + + return var + def save(self, path, training=True, mode=None): if not mode: mode = self.mode @@ -459,3 +560,35 @@ class Engine: dist_context = self._dist_contexts[mode] self._saver.load(path, dist_main_prog, dist_context, strict, load_optimizer) + + @property + def mode(self): + return self._mode + + @mode.setter + def mode(self, mode): + self._mode = mode + + @property + def metrics(self): + return self._metrics + + @property + def main_program(self): + return self._dist_main_progs[self.mode][self._cur_rank] + + @property + def startup_program(self): + return self._dist_startup_progs[self.mode][self._cur_rank] + + @property + def dist_context(self): + return self._dist_contexts[self.mode] + + @property + def serial_main_program(self): + return self._serial_main_progs[self.mode] + + @property + def serial_startup_program(self): + return self._serial_startup_progs[self.mode] diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 9fb200f4d2db9353ba1f5419bf06aca7488640d3..4795050d15dcc0b60328e0a5be97bac46cfdea88 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -53,6 +53,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr + input_names = op_desc.input_names() + xshape_arg_names = [] + if "XShape" in input_names: + xshape_arg_names = op_desc.input("XShape") for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) @@ -63,10 +67,18 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): # continue # if len(dims_mapping) < 1: # continue - if len(dims_mapping) > 1: - for mapping in dims_mapping[1:]: - if mapping != -1: - return False + if arg_name not in xshape_arg_names: + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + else: + if dims_mapping[0] != -1: + return False + if len(dims_mapping) > 2: + for mapping in dims_mapping[2:]: + if mapping != -1: + return False return True def is_output_compatible(self, dist_op): @@ -105,17 +117,31 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr batch_dim_mappings = [] # Check input compatibility + input_names = op_desc.input_names() + xshape_arg_names = [] + if "XShape" in input_names: + xshape_arg_names = op_desc.input("XShape") for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if len(dims_mapping) > 1: - for mapping in dims_mapping[1:]: - if mapping != -1: - return False - if len(dims_mapping) >= 1: - batch_dim_mappings.append(dims_mapping[0]) + if arg_name not in xshape_arg_names: + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + if len(dims_mapping) >= 1: + batch_dim_mappings.append(dims_mapping[0]) + else: + if dims_mapping[0] != -1: + return False + if len(dims_mapping) > 2: + for mapping in dims_mapping[2:]: + if mapping != -1: + return False + if len(dims_mapping) >= 2: + batch_dim_mappings.append(dims_mapping[1]) # Check output compatibility output_names = op_desc.output_names() @@ -160,24 +186,39 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): or op_desc.type() == "slice" \ or op_desc.type() == "while": return False + + input_names = op_desc.input_names() + input_xshape_arg_names = [] + if "XShape" in input_names: + input_xshape_arg_names = op_desc.input("XShape") + output_names = op_desc.output_names() - xshape_arg_names = [] + output_xshape_arg_names = [] if "XShape" in output_names: - xshape_arg_names = op_desc.output("XShape") + output_xshape_arg_names = op_desc.output("XShape") + batch_dim_mappings = [] for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if len(dims_mapping) >= 1: - batch_dim_mappings.append(dims_mapping[0]) + if arg_name not in input_xshape_arg_names: + if len(dims_mapping) >= 1: + batch_dim_mappings.append(dims_mapping[0]) + else: + batch_dim_mappings.append(dims_mapping[1]) for arg_name in op_desc.output_arg_names(): + if op_desc.type() == "fill_zeros_like": + input_tensor = dist_op.get_serial_input(op_desc.input_arg_names( + )[0]) + if input_tensor.is_parameter: + continue serial_tensor = dist_op.get_serial_output(arg_name) if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if arg_name not in xshape_arg_names: + if arg_name not in output_xshape_arg_names: if len(dims_mapping) >= 1: batch_dim_mappings.append(dims_mapping[0]) else: @@ -194,16 +235,27 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if len(dims_mapping - ) >= 1 and compatible_dim_mapping != dims_mapping[0]: - dims_mapping[0] = compatible_dim_mapping - changed = True + if arg_name not in input_xshape_arg_names: + if len(dims_mapping) >= 1 and \ + compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + else: + if len(dims_mapping) >= 2 and \ + compatible_dim_mapping != dims_mapping[1]: + dims_mapping[1] = compatible_dim_mapping + changed = True for arg_name in op_desc.output_arg_names(): + if op_desc.type() == "fill_zeros_like": + input_tensor = dist_op.get_serial_input(op_desc.input_arg_names( + )[0]) + if input_tensor.is_parameter: + continue serial_tensor = dist_op.get_serial_output(arg_name) if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if arg_name not in xshape_arg_names: + if arg_name not in output_xshape_arg_names: if len(dims_mapping ) >= 1 and compatible_dim_mapping != dims_mapping[0]: dims_mapping[0] = compatible_dim_mapping @@ -371,30 +423,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): if need_gradient_allreduce: allreduce_vars = [] - for input_name in backward_op.desc.input_names(): - for varname in backward_op.desc.input(input_name): - if "@GRAD" not in varname and is_parameter_related( - varname, main_block): - # NOTE: When amp and recompute pass are effective at the same time, - # if a parameter is casted and recomputed, the 'parameter@GARD' can not - # be found in the grad_op's output. - if "subprog_" in varname: - varname = varname[:varname.index(".subprog_")] - - assert len( - backward_op.desc.input(input_name) - ) == 1, "parameter input to grad op should be length 1, but got [{}]".format( - backward_op.desc.input(input_name)) - - assert varname + "@GRAD" in backward_op.desc.output_arg_names( - ), "parameter's grad [{}] not found in the grad op's output".format( - varname + "@GRAD") - assert len( - backward_op.desc.output(input_name + "@GRAD") - ) == 1, "parameter grad of grad op should be length 1, but got [{}]".format( - backward_op.desc.output(input_name + "@GRAD")) - allreduce_vars.append( - backward_op.desc.output(input_name + "@GRAD")[0]) + for output_name in backward_op.desc.output_names(): + for varname in backward_op.desc.output(output_name): + if varname in kwargs["grad_var_to_var"]: + fwd_name = kwargs["grad_var_to_var"][varname] + if fwd_name not in main_block.vars: + continue + if is_parameter_related(fwd_name, main_block): + allreduce_vars.append(varname) if len(allreduce_vars) > 0: diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index c03ef9c06d80fd6a9f49c4bcbd03864c62d4b949..fe091cd08b72b111707429f9f8439fa3d008f32d 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext, Di from .dist_attribute import OperatorDistributedAttribute from .process_group import new_process_group from .utils import set_dist_op_desc_original_id -from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op +from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op from .operators.common import BACKWARD_ONLY_DIST_OPS __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] @@ -198,15 +198,29 @@ class Partitioner(object): dist_op_context = self._dist_context.dist_op_context serial_ops = ref_block.ops + last_fwd_op_idx = -1 + for idx, op in enumerate(ref_block.ops): + if is_loss_op(op): + last_fwd_op_idx = idx + break + + if last_fwd_op_idx == -1: + last_fwd_op_idx = len(ref_block.ops) + # init mapping forward_op_id2forward_op = {} for idx in range(len(serial_ops)): - if is_forward_op(serial_ops[idx]): + if idx <= last_fwd_op_idx: forward_op_id2forward_op[serial_ops[idx].desc.id( )] = serial_ops[idx] + appended_grad_times = 0 # partiiton - for op in serial_ops: + for idx, op in enumerate(serial_ops): + + if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1]) or + is_loss_op(serial_ops[idx - 1])): + appended_grad_times += 1 # partititon input variables for serial_input_varname in op.desc.input_arg_names(): @@ -244,8 +258,11 @@ class Partitioner(object): kinputs, koutputs = dist_op_context.prepare_context(op) dist_op_backward_impl = _get_dist_op_backward_implement( op, self._dist_context, forward_op_id2forward_op) - dist_op_backward_impl.backward(self._dist_context, **kinputs, - **koutputs) + grad_var_to_var = self._dist_context.dist_op_context.grad_var_to_var[ + appended_grad_times] + dist_op_backward_impl.backward( + self._dist_context, **kinputs, **koutputs, + **{"grad_var_to_var": grad_var_to_var}) else: raise NotImplementedError( "partitioner only support forward op and backward op, but got {}". diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index fc85cd04d4010ed826ea198f0c4b44a7c461ea86..9c40034498dbc504cd106fef35a886fd1054990a 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -996,69 +996,87 @@ def set_grad_var_shape(program, dist_context): block = program.global_block() vars = block.vars - for op in block.ops: + appended_grad_times = 0 + grad_var_to_var = dist_context.dist_op_context.grad_var_to_var + + for idx, op in enumerate(block.ops): + + if int(op.attr('op_role')) != int(OpRole.Backward): + continue + + if int(block.ops[idx-1].attr('op_role')) == int(OpRole.Forward) or \ + int(block.ops[idx-1].attr('op_role')) == 257: + appended_grad_times += 1 if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: break - if op.type in ["sum", "concat"]: + if op.type in ["sum", "concat", "shape"]: continue - if int(op.attr('op_role')) == int(OpRole.Backward): - op_dist_attr = dist_context.get_op_dist_attr_for_program(op) - assert op_dist_attr is not None - for var_name in op.output_arg_names: - if "@GRAD" not in var_name: - continue + op_dist_attr = dist_context.get_op_dist_attr_for_program(op) + assert op_dist_attr is not None + + for var_name in op.output_arg_names: + + if "@GRAD" not in var_name: + continue + if var_name in grad_var_to_var[appended_grad_times]: + forward_var_name = grad_var_to_var[appended_grad_times][ + var_name] + else: forward_var_name = var_name[:var_name.find("@GRAD")] - if op.type in [ - "c_allreduce_sum", "c_identity", "scale", "cast" - ]: - forward_var_name = op.input_arg_names[0] - elif op.type == "matmul_v2_grad": - forward_var_name = None - for output_name in op.output_names: - if var_name in op.output(output_name): - assert "@GRAD" in output_name - input_name = output_name[:output_name.find("@GRAD")] - assert len(op.input(input_name)) == 1 - forward_var_name = op.input(input_name)[0] - assert forward_var_name is not None - - need_set_shape_list = [ - "reshape2_grad", "softmax_with_cross_entropy_grad", - "transpose2_grad", "softmax_grad", "cross_entropy_grad2", - "dropout_grad" - ] - forward_list = [ - "reshape2", "softmax_with_cross_entropy", "transpose2", - "softmax", "cross_entropy2", "dropout" - ] - if op.type in need_set_shape_list: - for forward_op in block.ops: - assert int(forward_op.attr('op_role')) != int( - OpRole.Backward) - idx = need_set_shape_list.index(op.type) - forward_op_name = forward_list[idx] - if forward_op.type == forward_op_name and forward_var_name in forward_op.input_arg_names: - op_dist_attr = dist_context.get_op_dist_attr_for_program( - forward_op) - break - - forward_input_dist_attr = op_dist_attr.get_input_dist_attr( - forward_var_name) - assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}" - forward_var = vars[forward_var_name] - forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( - forward_var) - assert forward_var_dist_attr is not None - grad_var = vars[var_name] - ref_shape = infer_shape(block, forward_var, - forward_var_dist_attr, - forward_input_dist_attr) - - if list(grad_var.shape) != ref_shape: - grad_var.desc.set_shape(ref_shape) + + if op.type in [ + "c_allreduce_sum", "c_identity", "scale", "cast", + "fill_zeros_like" + ]: + forward_var_name = op.input_arg_names[0] + elif op.type == "matmul_v2_grad": + forward_var_name = None + for output_name in op.output_names: + if var_name in op.output(output_name): + assert "@GRAD" in output_name + input_name = output_name[:output_name.find("@GRAD")] + assert len(op.input(input_name)) == 1 + forward_var_name = op.input(input_name)[0] + assert forward_var_name is not None + + need_set_shape_list = [ + "reshape2_grad", "softmax_with_cross_entropy_grad", + "transpose2_grad", "softmax_grad", "cross_entropy_grad2", + "dropout_grad", "tanh_grad", "slice", "assign", + "matmul_v2_triple_grad", "elementwise_add_triple_grad", + "fill_constant", "sqrt_grad" + ] + forward_list = [ + "reshape2", "softmax_with_cross_entropy", "transpose2", + "softmax", "cross_entropy2", "dropout", "tanh", + ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad", + "elementwise_add_grad_grad", "shape", "sqrt" + ] + if op.type in need_set_shape_list: + for forward_op in block.ops: + idx = need_set_shape_list.index(op.type) + forward_op_name = forward_list[idx] + if forward_op.type in forward_op_name and forward_var_name in forward_op.input_arg_names: + op_dist_attr = dist_context.get_op_dist_attr_for_program( + forward_op) + break + + forward_input_dist_attr = op_dist_attr.get_input_dist_attr( + forward_var_name) + assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}" + forward_var = vars[forward_var_name] + forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( + forward_var) + assert forward_var_dist_attr is not None + grad_var = vars[var_name] + ref_shape = infer_shape(block, forward_var, forward_var_dist_attr, + forward_input_dist_attr) + + if list(grad_var.shape) != ref_shape: + grad_var.desc.set_shape(ref_shape) OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 5fdbbb4d7ed18981364864cb7b721d4cf96d6faa..bc53c130286aa96e68a1485ef4c203e45e27d878 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -478,12 +478,16 @@ def _accumulate_gradients_by_add_ops_(var_name, renamed_vars[var_name] = [var_name] -def _addup_repetitive_outputs_(op_descs, block_idx): +def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None): """ In backward part, an variable may be the output of more than one ops. And one op may yield its multiple outputs to the same variable. In these cases, the variable should be the accumulation of all the outputs. `sum_op`s are added to implement the accumulate. + + Args: + grad_var_to_var(dict): used to build the mapping between grad var name and forward var name. + Only for auto parallel. """ _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add'] #pending_sum_ops = [] @@ -531,6 +535,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx): new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \ str(var_rename_count[var_name]) var_rename_count[var_name] += 1 + # Build the mapping between the new_name and var_name (Only for auto parallel) + if grad_var_to_var is not None: + if var_name in grad_var_to_var: + grad_var_to_var[new_name] = grad_var_to_var[ + var_name] + else: + grad_var_to_var[new_name] = var_name # rename original var_name renamed_vars[var_name][0] = new_name # before change: _rename_arg_(op_descs, var_name, @@ -557,6 +568,13 @@ def _addup_repetitive_outputs_(op_descs, block_idx): new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \ str(var_rename_count[var_name]) var_rename_count[var_name] += 1 + # Build the mapping between the new_name and var_name (Only for auto parallel) + if grad_var_to_var is not None: + if var_name in grad_var_to_var: + grad_var_to_var[new_name] = grad_var_to_var[ + var_name] + else: + grad_var_to_var[new_name] = var_name arg_names[arg_idx] = new_name op_desc.set_output(param_name, arg_names) renamed_vars[var_name].append(new_name) @@ -1081,6 +1099,16 @@ def _append_backward_ops_(block, rename_var_map(dict): used to associate target_grad var name with first grad_op input name. Only used in for high order gradient. """ + + # Build the mapping between the forward op and backward op (Only for auto parallel) + def update_distop_context(distop_context, op_grad_to_var, + appending_grad_times): + distop_context.grad_var_to_var[appending_grad_times].update( + op_grad_to_var) + for op_desc in grad_op_desc: + assert op_desc.id() not in distop_context.grad_op_id_to_op_id + distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id() + if callbacks is not None: assert (isinstance(callbacks, (list, tuple))) for cb in callbacks: @@ -1118,11 +1146,18 @@ def _append_backward_ops_(block, # Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list) + # Build the mapping between the forward op and backward op (Only for auto parallel) if distop_context is not None: - for op_desc in grad_op_desc: - assert op_desc.id() not in distop_context.grad_op_id_to_op_id - distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id() + update_distop_context(distop_context, op_grad_to_var, + program._appending_grad_times) + else: + default_ctx = getattr(paddle.distributed.auto_parallel.dist_context, + '_g_default_distributed_context', None) + if default_ctx is not None: + distop_context = default_ctx.dist_op_context + update_distop_context(distop_context, op_grad_to_var, + program._appending_grad_times) # Set device for grad_op according to forward Op device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() @@ -1155,6 +1190,11 @@ def _append_backward_ops_(block, rename_var_map[name] = new_name if name in op_grad_to_var: + # Build the mapping between the grad var name and var name (Only for auto parallel) + if distop_context is not None: + distop_context.grad_var_to_var[ + program._appending_grad_times][ + new_name] = op_grad_to_var[name] op_grad_to_var[new_name] = op_grad_to_var[name] op_grad_to_var.pop(name) @@ -1187,8 +1227,14 @@ def _append_backward_ops_(block, grad_op_descs.extend(grad_op_desc) grad_to_var.update(op_grad_to_var) + # record mapping bewteen grad var name and var name (Only for auto parallel) + grad_var_to_var = None + if distop_context is not None: + grad_var_to_var = distop_context.grad_var_to_var[ + program._appending_grad_times] # sum parameter's gradients' var given multiple var gradient - grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx) + grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx, + grad_var_to_var) # if all outputs of the grad op are in no_grad_set, then just remove and fill zero # if all inputs of the grad op are in no_grad_set, just remove this op diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 97a3092f11fd244e2c7330ed30470a27cb63e447..4d052f7e90cd3d11e8607bd4a60546f9eca27ae1 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -12,6 +12,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS}) set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS}) + set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS}) py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index d7321066ed9d96400577b422c3ef1ac8f9d9de9b..b039bb76dcb03ee5eef9d0aecbd6719a6ab2dff4 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -127,9 +127,16 @@ def train(): engine.prepare(optimizer, loss) engine.fit(dataset, batch_size=batch_size, - steps_per_epoch=batch_num * batch_size) - engine.save('./mlp') - engine.load('./mlp') + steps_per_epoch=batch_num * batch_size, + sample_generator=True) + + eval_dataset = MyDataset(batch_size) + engine.prepare(optimizer, loss, mode='eval') + engine.evaluate(eval_dataset, batch_size) + + test_dataset = MyDataset(batch_size) + engine.prepare(mode='predict') + engine.predict(test_dataset, batch_size) engine.save('./mlp_inf', training=False, mode='predict') diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_predict_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_predict_api.py deleted file mode 100644 index 5f7c018ee4f16a58e408c6ce08415d4e3bbaaca8..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_predict_api.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import time -import paddle.fluid as fluid -import copy -import os -import numpy as np -import subprocess -import paddle -import paddle.nn as nn -import paddle.fluid as fluid -import paddle.static as static -import paddle.nn.functional as F -import paddle.utils as utils -from paddle.fluid import layers -from paddle.io import Dataset, IterableDataset, DataLoader -from paddle.static import InputSpec -from paddle.distributed import fleet -import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.engine import Engine - -paddle.enable_static() -global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) -batch_size = 1 -batch_num = 10 -hidden_size = 1024 -image_size = hidden_size - -paddle.seed(44) - - -class MyDataset(Dataset): - def __init__(self, num_samples): - super(MyDataset, self).__init__() - self.num_samples = num_samples - - def __getitem__(self, index): - input = np.random.uniform(size=image_size).astype("float32") - return input - - def __len__(self): - return self.num_samples - - -class MLPLayer(nn.Layer): - def __init__(self, - hidden_size=1024, - intermediate_size=4 * 1024, - dropout_ratio=0.1, - initializer_range=0.02): - super(MLPLayer, self).__init__() - d_model = hidden_size - dim_feedforward = intermediate_size - weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range)) - bias_attr = None - - self.linear0 = nn.Linear( - d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) - self.linear1 = nn.Linear( - dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) - self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) - self.norm = nn.LayerNorm(d_model, epsilon=1e-5) - self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") - - def forward(self, input): - out = self.norm(input) - out = self.linear0(input) - auto.shard_tensor( - self.linear0.weight, - dist_attr={ - "process_mesh": global_process_mesh, - "dims_mapping": [-1, 0] - }) - out = F.gelu(out, approximate=True) - out = self.linear1(out) - auto.shard_tensor( - self.linear1.weight, - dist_attr={ - "process_mesh": global_process_mesh, - "dims_mapping": [0, -1] - }) - out = self.dropout(out) - out = self.linear2(out) - return out - - -def train(): - mlp = MLPLayer( - hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) - - dataset = MyDataset(batch_num * batch_size) - inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') - - dist_strategy = fleet.DistributedStrategy() - # init parallel optimizer - dist_strategy.semi_auto = True - fleet.init(is_collective=True, strategy=dist_strategy) - - engine = Engine(mlp, inputs_spec=inputs_spec, strategy=dist_strategy) - engine.prepare(mode='predict') - engine.predict(dataset, batch_size=batch_size) - - -if __name__ == "__main__": - train() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9efe7ab2dd0a48517690383779a4ace2a8107c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py @@ -0,0 +1,157 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import paddle +import unittest +import numpy as np +import paddle.distributed.auto_parallel as auto + +from paddle.static import InputSpec +from paddle.distributed import fleet +from paddle.incubate.autograd import Hessian +from paddle.distributed.auto_parallel.engine import Engine + + +class FCNet: + def __init__(self, num_ins, num_outs, num_layers, hidden_size): + self.num_ins = num_ins + self.num_outs = num_outs + self.num_layers = num_layers + self.hidden_size = hidden_size + self.activation = paddle.tanh + + self.weights = [] + self.biases = [] + for i in range(self.num_layers): + if i == 0: + lsize = self.num_ins + rsize = self.hidden_size + elif i == (self.num_layers - 1): + lsize = self.hidden_size + rsize = self.num_outs + else: + lsize = self.hidden_size + rsize = self.hidden_size + + w = paddle.static.create_parameter( + shape=[lsize, rsize], dtype="float32", is_bias=False) + b = paddle.static.create_parameter( + shape=[rsize], dtype="float32", is_bias=True) + self.weights.append(w) + self.biases.append(b) + + def nn_func(self, ins): + u = ins + for i in range(self.num_layers - 1): + u = paddle.nn.functional.linear(u, self.weights[i], self.biases[i]) + u = self.activation(u) + u = paddle.nn.functional.linear(u, self.weights[-1], self.biases[-1]) + return u + + +class LaplaceModel(paddle.nn.Layer): + def __init__(self, num_ins=2, num_outs=1, num_layers=5, hidden_size=20): + super(LaplaceModel, self).__init__() + self.net = FCNet( + num_ins=num_ins, + num_outs=num_outs, + num_layers=num_layers, + hidden_size=hidden_size) + + def forward(self, inputs, bc_index): + inputs.stop_gradient = False + outputs = self.net.nn_func(inputs) + # eq_loss + hes = Hessian(self.net.nn_func, inputs, is_batched=True) + eq_loss = paddle.norm(hes[:, 0, 0] + hes[:, 1, 1], p=2) + # bc_loss + bc_u = paddle.index_select(outputs, bc_index) + return eq_loss, bc_u + + +class LaplaceDataset: + def __init__(self, num_sample): + self.num_sample = num_sample + + def __getitem__(self, index): + x = np.linspace(0, 0.9, 10) + y = np.linspace(0, 0.9, 10) + bc_value = np.random.rand(36).reshape(36, 1).astype('float32') + + domain_space = [] + bc_index = [] + for j in range(len(y)): + for i in range(len(x)): + domain_space.append([x[i], y[j]]) + if i == 0 or i == 9 or j == 0 or j == 9: + bc_index.append(i + 10 * j) + domain_space = np.array(domain_space, dtype='float32') + bc_index = np.array(bc_index, dtype='int64') + + return domain_space, bc_index, bc_value + + def __len__(self): + return self.num_sample + + +def loss_func(eq_loss, bc_u, bc_value): + bc_diff = bc_u - bc_value + bc_loss = paddle.norm(bc_diff, p=2) + loss = eq_loss + bc_loss + return loss + + +def main(): + # dataset + train_dataset = LaplaceDataset(10) + # optimizer + optimizer = paddle.optimizer.Adam(learning_rate=0.001) + # model + laplace = LaplaceModel() + + # spec + inputs_spec = [ + InputSpec([100, 2], 'float32', 'x'), InputSpec([36], 'int64', 'bc_idx') + ] + labels_spec = InputSpec([36, 1], 'float32', 'bc_v') + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + + engine = Engine( + laplace, + inputs_spec=inputs_spec, + labels_spec=labels_spec, + strategy=dist_strategy) + paddle.seed(1234 + engine._cur_rank) + engine.prepare(optimizer=optimizer, loss=loss_func) + res = engine.fit(train_dataset, sample_generator=False) + assert np.allclose(res[-1], 2.840593) + + dist_context = engine.dist_context + block = engine.main_program.global_block() + ops = block.ops + for op in ops: + if op.type == 'p_norm': + op_dist_attr = dist_context.get_op_dist_attr_for_program(op) + assert op_dist_attr.impl_type == 'p_norm' + if 'x' in op.input_arg_names: + out_name = op.output_arg_names[0] + assert block.vars[out_name].shape[0] == 50 + + +if __name__ == "__main__": + main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py index 5ca12bc1e0e177a1477f8415ccc7032dcd85d925..efcad7eb11268a9995da9e40c0750a40da7bc227 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py @@ -49,28 +49,6 @@ class TestEngineAPI(unittest.TestCase): if os.path.exists('rank_mapping.csv'): os.remove('rank_mapping.csv') - def test_engine_predict(self): - file_dir = os.path.dirname(os.path.abspath(__file__)) - launch_model_path = os.path.join(file_dir, "engine_predict_api.py") - - if os.environ.get("WITH_COVERAGE", "OFF") == "ON": - coverage_args = ["-m", "coverage", "run", "--branch", "-p"] - else: - coverage_args = [] - - cmd = [sys.executable, "-u"] + coverage_args + [ - "-m", "launch", "--gpus", "0,1", launch_model_path - ] - - process = subprocess.Popen(cmd) - process.wait() - self.assertEqual(process.returncode, 0) - - # Remove unnecessary files - log_path = os.path.join(file_dir, "log") - if os.path.exists(log_path): - shutil.rmtree(log_path) - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_high_order_grad.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_high_order_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4a34cf99cbfa17567e91e63454dc2b09ef38fd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_high_order_grad.py @@ -0,0 +1,48 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestHighOrderGrad(unittest.TestCase): + def test_dp2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "high_order_grad.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "launch", "--gpus", "0,1", launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + # Remove unnecessary files + log_path = os.path.join(file_dir, "log") + if os.path.exists(log_path): + shutil.rmtree(log_path) + + +if __name__ == "__main__": + unittest.main()