diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 5e4a8c7d040338c3e9e5a92ea01d74a7a84eaedc..3d5b91cd7faa760ec0e7230d888c474243c9748f 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -18,7 +18,6 @@ from collections import defaultdict import paddle import paddle.utils as utils -import paddle.distributed.auto_parallel as auto from paddle import fluid, static from paddle.io import Dataset @@ -72,7 +71,6 @@ class Engine: self._saver = DistributedSaver() self._logger = get_logger(logging.INFO) - self._default_strategy = None self._orig_main_prog = static.default_main_program() self._orig_startup_prog = static.default_startup_program() self._orig_dist_context = get_default_distributed_context() @@ -117,9 +115,11 @@ class Engine: self._planned_mode = None self._modes = ['train', 'eval', 'predict'] - self._build() - # Do auto parallel process + # Build program and do auto parallel process + for mode in self._modes: + # Build forward program + self._build(mode) for mode in self._modes: # Do the planning process self._plan(mode) @@ -129,56 +129,49 @@ class Engine: # Init comm and startup program self._initialize(mode) - def _build(self): - for mode in self._modes: - serial_main_prog = self._serial_main_progs.get(mode, None) - if serial_main_prog is not None: - return - - losses = [] - metrics = [] - serial_main_prog = self._orig_main_prog.clone() - serial_startup_prog = self._orig_startup_prog.clone() - with static.program_guard(serial_main_prog, serial_startup_prog), \ - utils.unique_name.guard(): - inputs_spec = self.inputs_spec - labels_spec = self.labels_spec if self.labels_spec else [] - inputs = [s._create_feed_layer() for s in inputs_spec] - labels = [s._create_feed_layer() for s in labels_spec] - outputs = to_list(self.model(*inputs)) - if mode != "predict" and self._loss: - losses = to_list(self._loss(*(outputs + labels))) - - if mode != "predict": - for metric in self._metrics: - metrics.extend( - to_list(metric.compute(*(outputs + labels)))) - - default_ctx = get_default_distributed_context() - if not default_ctx.has_annotation or self._default_strategy: - # We build the world process group because the data parallel - # needs all ranks by default. - new_process_group(list(range(self._nranks))) - default_ctx.data_parallel = True - - # self._feed_vars[mode] = {"inputs": inputs, "labels": labels} - feed_vars = {"inputs": inputs, "labels": labels} - - # self._fetch_vars[mode] = { - # "outputs": flatten(outputs), - # "loss": losses, - # "metrics": metrics - # } - fetch_vars = { - "outputs": flatten(outputs), - "loss": losses, - "metrics": metrics - } - - self._dist_contexts[mode] = DistributedContext( - serial_main_prog, serial_startup_prog, self._optimizer, losses, - feed_vars, fetch_vars, self.cluster, self.strategy) - self._dist_contexts[mode].gradient_scale = self._gradient_scale + def _build(self, mode): + + serial_main_prog = self._serial_main_progs.get(mode, None) + if serial_main_prog is not None: + return + + losses = [] + metrics = [] + serial_main_prog = self._orig_main_prog.clone() + serial_startup_prog = self._orig_startup_prog.clone() + with static.program_guard(serial_main_prog, serial_startup_prog), \ + utils.unique_name.guard(): + inputs_spec = self.inputs_spec + labels_spec = self.labels_spec if self.labels_spec else [] + inputs = [s._create_feed_layer() for s in inputs_spec] + labels = [s._create_feed_layer() for s in labels_spec] + outputs = to_list(self.model(*inputs)) + if mode != "predict" and self._loss: + losses = to_list(self._loss(*(outputs + labels))) + + if mode != "predict": + for metric in self._metrics: + metrics.extend(to_list(metric.compute(*(outputs + labels)))) + + default_ctx = get_default_distributed_context() + if not default_ctx.has_annotation: + # We build the world process group because the data parallel + # needs all ranks by default. + new_process_group(list(range(self._nranks))) + default_ctx.data_parallel = True + + feed_vars = {"inputs": inputs, "labels": labels} + + fetch_vars = { + "outputs": flatten(outputs), + "loss": losses, + "metrics": metrics + } + + self._dist_contexts[mode] = DistributedContext( + serial_main_prog, serial_startup_prog, self._optimizer, losses, + feed_vars, fetch_vars, self.cluster, self.strategy) + self._dist_contexts[mode].gradient_scale = self._gradient_scale def _plan(self, mode): if self._planned_mode is None: @@ -240,7 +233,6 @@ class Engine: continue process_group.instantiate() - # initialize self._place = _get_device() if isinstance(self._place, fluid.CUDAPlace): self._place = fluid.CUDAPlace(ParallelEnv().dev_id) @@ -273,8 +265,8 @@ class Engine: train_dataloader = self._create_dataloader(train_data, batch_size, epochs, steps_per_epoch) - usr_fetch = self._to_map_fetch(fetches) - fetch_loss = self._inner_fetch(self.fetch_vars["loss"]) + usr_fetch = self._validate_fetches(fetches) + fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch) for epoch in range(epochs): @@ -292,8 +284,7 @@ class Engine: user_outs = outs[len(fetch_loss):] user_fetch_list = fetch_list[len(fetch_loss):] for i, out in enumerate(user_outs): - train_logs["train_" + - fetch_map[user_fetch_list[i]]] = out[0] + train_logs["train_" + fetch_map[user_fetch_list[i]]] = out self._logger.info(train_logs) def evaluate(self, @@ -307,9 +298,9 @@ class Engine: "eval model is not ready, please call `engine.prepare()` first." eval_dataloader = self._create_dataloader(eval_data, batch_size) - usr_fetch = self._to_map_fetch(fetches) - fetch_loss = self._inner_fetch(self.fetch_vars["loss"]) - fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"]) + usr_fetch = self._validate_fetches(fetches) + fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) + fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"]) inner_fetch = dict(fetch_loss, **fetch_metrics) fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch) @@ -321,7 +312,7 @@ class Engine: return_numpy=return_numpy) # inner fetches if fetch_loss: - eval_logs["eval_loss"] = outs[0] + eval_logs["eval_loss"] = outs[0][0] # Metric if fetch_metrics: metric_out = outs[len(fetch_loss):len(inner_fetch)] @@ -331,9 +322,9 @@ class Engine: for i, res in enumerate(to_list(results)): eval_logs["eval_" + metric.name()[i]] = res # usr fetches - usr_out = outs[len(inner_fetch):] + usr_outs = outs[len(inner_fetch):] usr_fetch_list = fetch_list[len(inner_fetch):] - for i, out in enumerate(usr_out): + for i, out in enumerate(usr_outs): eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out # logger self._logger.info(eval_logs) @@ -349,8 +340,8 @@ class Engine: "predict model is not ready, please call `engine.prepare()` first." test_dataloader = self._create_dataloader(test_data, batch_size) - usr_fetch = self._to_map_fetch(fetches) - fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"]) + usr_fetch = self._validate_fetches(fetches) + fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"]) fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch) outputs = [] @@ -362,42 +353,11 @@ class Engine: return_numpy=return_numpy) outputs.append(outs[:len(fetch_outputs)]) for i, out in enumerate(outs): - predict_logs["pred_" + fetch_map[fetch_list[i]]] = out[0] + predict_logs["pred_" + fetch_map[fetch_list[i]]] = out self._logger.info(predict_logs) return outputs - def _local_var(self, var): - var_name = _to_name_str(var) - return var_name in self.main_program.global_block().vars - - def _to_map_fetch(self, fetches): - if not fetches: - return {} - if isinstance(fetches, dict): - fetch_var_names = list(map(_to_name_str, fetches.values())) - usr_fetches = dict(zip(fetch_var_names, list(fetches.keys()))) - elif isinstance(fetches, list): - fetch_var_names = list(map(_to_name_str, fetches)) - usr_fetches = dict(zip(fetch_var_names, fetch_var_names)) - return dict(filter(lambda x: self._local_var(x[0]), - usr_fetches.items())) - - def _inner_fetch(self, fetch_vars): - fetch_list = list( - map(lambda x: x.name, list(filter(self._local_var, fetch_vars)))) - inner_fetches = dict(zip(fetch_list, fetch_list)) - return inner_fetches - - def _fetch_map(self, inner_fetch, usr_fetch): - # replace inner fetch name if usr set for it - for iname in inner_fetch: - if iname in usr_fetch: - inner_fetch[iname] = usr_fetch[iname] - usr_fetch.pop(iname) - fetches = dict(inner_fetch, **usr_fetch) - return list(fetches.keys()), fetches - def _create_dataloader(self, dataset, batch_size, @@ -468,26 +428,35 @@ class Engine: .format(i, spec)) return specs - def _set_data_parallel(self, var): - if self._nranks == 1: - self._default_strategy = 'serial' - auto.shard_tensor(var, - dist_attr={ - "process_mesh": [0], - "dims_mapping": - [-1 for _ in range(len(var.shape))] - }) + def _is_local_var(self, var): + var_name = _to_name_str(var) + return var_name in self.main_program.global_block().vars + + def _validate_fetches(self, fetches): + # 1. Check user-defined fetches type + # 2. Prepare fetches_dict like {user_defined_name: var_name} + if not fetches: + return {} + if isinstance(fetches, dict): + fetch_var_names = list(map(_to_name_str, fetches.values())) + fetches_dict = dict(zip(fetch_var_names, list(fetches.keys()))) + elif isinstance(fetches, list): + fetch_var_names = list(map(_to_name_str, fetches)) + fetches_dict = dict(zip(fetch_var_names, fetch_var_names)) else: - self._default_strategy = 'dp' - auto.shard_tensor(var, - dist_attr={ - "process_mesh": - list(range(self._nranks)), - "dims_mapping": - [0] + [-1 for _ in range(len(var.shape) - 1)] - }) - - return var + raise TypeError("'fetches' only support 'dict' and 'list', " + "but got '{}'".format(str(type(fetches)))) + return dict( + filter(lambda x: self._is_local_var(x[0]), fetches_dict.items())) + + def _fetch_map(self, inner_fetch, usr_fetch): + # replace inner fetch name if usr set for it + for iname in inner_fetch: + if iname in usr_fetch: + inner_fetch[iname] = usr_fetch[iname] + usr_fetch.pop(iname) + fetches = dict(inner_fetch, **usr_fetch) + return list(fetches.keys()), fetches def _get_data_parallel_info(self, var, dist_context): # get data parallel world size and current data parallel rank diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py index 0a4bfb1213d46889885f7ba137915703f1827376..b00f1a589e31211d1a2ae3da67b737d01d58ae90 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py @@ -137,7 +137,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): attrs={ "in_dtype": inf_var.dtype, "out_dtype": inf_var_int32.dtype, - OP_ROLE_KEY: OpRole.Backward + OP_ROLE_KEY: OpRole.Optimize }) allreduce_op = main_block.append_op(type='c_allreduce_max', inputs={'X': inf_var_int32}, @@ -145,7 +145,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): attrs={ 'ring_id': group.id, 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Backward + OP_ROLE_KEY: OpRole.Optimize }) cast_op2 = main_block.append_op(type='cast', inputs={'X': inf_var_int32}, @@ -153,7 +153,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): attrs={ "in_dtype": inf_var_int32.dtype, "out_dtype": inf_var.dtype, - OP_ROLE_KEY: OpRole.Backward + OP_ROLE_KEY: OpRole.Optimize }) main_block._sync_with_cpp() diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 2272400e60ddfe82fd65703e42cd1cb873ecf2fc..80c9b8641ba36bf7c322a217b77701ebe8ae0138 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -222,7 +222,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 'W': [Weight_var] }, outputs={'Out': [intermediate_var_0]}, - attrs={"start_index": relative_idx}) + attrs={ + "start_index": relative_idx, + OP_ROLE_KEY: src_op.attr('op_role') + }) if intermediate_var_0.shape != ref_shape: intermediate_var_0.desc.set_shape(ref_shape) @@ -235,6 +238,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, + OP_ROLE_KEY: src_op.attr('op_role') }) if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) @@ -442,6 +446,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): dp_group = new_process_group(group_ranks) if need_gradient_allreduce: + added_ops = [] W_Grad_var = main_block.var(kwargs['W@GRAD'][0]) allreduce_op = main_block.append_op(type='c_allreduce_sum', inputs={'X': [W_Grad_var]}, @@ -451,19 +456,24 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Backward }) - scale_op = main_block.append_op(type='scale', - inputs={'X': W_Grad_var}, - outputs={'Out': W_Grad_var}, - attrs={ - 'scale': 1.0 / dp_degree, - OP_ROLE_KEY: OpRole.Backward - }) + added_ops.append(allreduce_op) + + if ctx.gradient_scale: + scale_op = main_block.append_op(type='scale', + inputs={'X': W_Grad_var}, + outputs={'Out': W_Grad_var}, + attrs={ + 'scale': 1.0 / dp_degree, + OP_ROLE_KEY: OpRole.Backward + }) + added_ops.append(scale_op) + main_block._sync_with_cpp() dims_mapping = ctx.get_tensor_dist_attr_for_program( W_Grad_var).dims_mapping process_mesh = dist_attr.process_mesh - for op in [allreduce_op, scale_op]: + for op in added_ops: op_attr = OperatorDistributedAttribute() op_attr.process_mesh = process_mesh op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 427932a77fbcd17569a3716c08d46036864f34c9..0826148208ec0233d61c324374b87b15a41979dc 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -405,6 +405,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): dp_group = new_process_group(group_ranks) if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block): + added_ops = [] Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0]) allreduce_op = main_block.append_op(type='c_allreduce_sum', inputs={'X': [Y_Grad_var]}, @@ -414,19 +415,24 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Backward }) - scale_op = main_block.append_op(type='scale', - inputs={'X': Y_Grad_var}, - outputs={'Out': Y_Grad_var}, - attrs={ - 'scale': 1.0 / dp_degree, - OP_ROLE_KEY: OpRole.Backward - }) + added_ops.append(allreduce_op) + + if ctx.gradient_scale: + scale_op = main_block.append_op(type='scale', + inputs={'X': Y_Grad_var}, + outputs={'Out': Y_Grad_var}, + attrs={ + 'scale': 1.0 / dp_degree, + OP_ROLE_KEY: OpRole.Backward + }) + added_ops.append(scale_op) + main_block._sync_with_cpp() dims_mapping = ctx.get_tensor_dist_attr_for_program( Y_Grad_var).dims_mapping process_mesh = dist_attr.process_mesh - for op in [allreduce_op, scale_op]: + for op in added_ops: op_attr = OperatorDistributedAttribute() op_attr.process_mesh = process_mesh op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping) @@ -617,6 +623,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, + OP_ROLE_KEY: src_op.attr('op_role') }) if intermediate_var_0.shape != ref_shape_x: intermediate_var_0.desc.set_shape(ref_shape_x) @@ -629,6 +636,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): 'transpose_X': False, 'transpose_Y': False, 'alpha': 1, + OP_ROLE_KEY: src_op('op_role') } inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} matmul_op = main_block.append_op(type='matmul', @@ -814,6 +822,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): 'transpose_X': False, 'transpose_Y': False, 'alpha': 1, + OP_ROLE_KEY: src_op.attr('op_role') } inputs = {'X': X_var, 'Y': Weight_var} @@ -853,7 +862,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): attrs={ 'ring_id': group.id, 'use_calc_stream': True, - 'use_model_parallel': True + 'use_model_parallel': True, + OP_ROLE_KEY: src_op.attr('op_role') }) if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) @@ -1137,6 +1147,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, + OP_ROLE_KEY: src_op.attr('op_role'), }) if intermediate_var_0.shape != ref_shape_x: intermediate_var_0.desc.set_shape(ref_shape_x) @@ -1145,7 +1156,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ['float16', 'float32', 'float64'], 'linear') check_dtype(intermediate_var_0.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') - attrs = {'trans_x': False, 'trans_y': False} + attrs = { + 'trans_x': False, + 'trans_y': False, + OP_ROLE_KEY: src_op.attr('op_role') + } inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} matmul_v2_op = main_block.append_op(type='matmul_v2', inputs=inputs, @@ -1322,7 +1337,11 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): 'linear') check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') - attrs = {'trans_x': False, 'trans_y': False} + attrs = { + 'trans_x': False, + 'trans_y': False, + OP_ROLE_KEY: src_op.attr('op_role') + } inputs = {'X': X_var, 'Y': Weight_var} # infer out var shape with op dist attr @@ -1361,7 +1380,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): attrs={ 'ring_id': group.id, 'use_calc_stream': True, - 'use_model_parallel': True + 'use_model_parallel': True, + OP_ROLE_KEY: src_op.attr('op_role') }) if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) @@ -1646,6 +1666,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, + OP_ROLE_KEY: src_op.attr('op_role') }) if intermediate_var_0.shape != ref_shape_x: intermediate_var_0.desc.set_shape(ref_shape_x) @@ -1657,7 +1678,8 @@ class DistributedMulImpl0(DistributedOperatorImpl): # attrs = {'trans_x': False, 'trans_y': False} attrs = { "x_num_col_dims": src_op.desc.attr("x_num_col_dims"), - "y_num_col_dims": src_op.desc.attr("y_num_col_dims") + "y_num_col_dims": src_op.desc.attr("y_num_col_dims"), + OP_ROLE_KEY: src_op.attr('op_role') } inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} mul_op = main_block.append_op(type='mul', @@ -1838,7 +1860,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): # attrs = {'trans_x': False, 'trans_y': False} attrs = { "x_num_col_dims": src_op.desc.attr("x_num_col_dims"), - "y_num_col_dims": src_op.desc.attr("y_num_col_dims") + "y_num_col_dims": src_op.desc.attr("y_num_col_dims"), + OP_ROLE_KEY: src_op.attr('op_role') } inputs = {'X': X_var, 'Y': Weight_var} @@ -1878,7 +1901,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): attrs={ 'ring_id': group.id, 'use_calc_stream': True, - 'use_model_parallel': True + 'use_model_parallel': True, + OP_ROLE_KEY: src_op.attr('op_role') }) if Out_var.shape != ref_shape: Out_var.desc.set_shape(ref_shape) diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 9056ab34fa71109f04317ee5be4c08e65d102e3e..97ff881ef95bfdc7e07b8da64a18aa417f493f60 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -264,10 +264,12 @@ class Partitioner(object): self._dist_context, **kinputs, **koutputs, **{"grad_var_to_var": grad_var_to_var}) elif is_optimize_op(op): + # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS kinputs, koutputs = dist_op_context.prepare_context(op) - dist_op_impl = get_distributed_operator_impl_container( - "default").get_impl(0) - dist_op_impl.backward(self._dist_context, **kinputs, **koutputs) + dist_op_opt_impl = _get_dist_op_backward_implement( + op, self._dist_context, forward_op_id2forward_op) + dist_op_opt_impl.backward(self._dist_context, **kinputs, + **koutputs) else: raise NotImplementedError( "partitioner only support forward and backward, optimize ops, but got {}" diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index e220b654e700a6e43473de4061ec86e7345b0188..c4f9ad8b6bc844d324070743511b8912741d19b0 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1065,7 +1065,7 @@ def set_grad_var_shape(program, dist_context): "softmax", "cross_entropy2", "dropout", "tanh", ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad", "elementwise_add_grad_grad", "shape", "sqrt", - "fused_softmax_mask_upper_triangle_grad" + "fused_softmax_mask_upper_triangle" ] if op.type in need_set_shape_list: for forward_op in block.ops: @@ -1096,11 +1096,9 @@ OpRole = core.op_proto_and_checker_maker.OpRole def is_forward_op(op): - ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward) - ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss) op_role = int(op.attr('op_role')) - return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1 - or op_role == ref_role2) + return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.Forward) + or op_role == int(OpRole.Loss)) def is_backward_op(op): @@ -1113,9 +1111,14 @@ def is_optimize_op(op): int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize) +def is_lr_sched_op(op): + return OP_ROLE_KEY in op.attr_names and \ + int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize.LRSched) + + def is_loss_op(op): return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss)) + int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss)) def is_prim_op(op): diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index f0d02451141aeb69f8dff87f0368352dc15aa7d6..7afba8c0f137745283af7b29a80128e9113dee28 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -452,7 +452,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): inputs = {'X': grads, 'Scale': loss_scaling} outputs = {'Out': grads, 'FoundInfinite': found_inf} - attrs = {'op_role': OpRole.Backward} + attrs = {'op_role': OpRole.Optimize} new_op = main_block.append_op(type='check_finite_and_unscale', inputs=inputs, outputs=outputs, @@ -732,7 +732,7 @@ class AMPPass(PassBase): 'incr_ratio': self.get_attr("incr_ratio"), 'decr_ratio': self.get_attr("decr_ratio"), 'stop_update': self.get_attr("stop_update"), - 'op_role': OpRole.Backward + 'op_role': OpRole.Optimize } new_op = main_block.append_op(type='update_loss_scaling', diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index 394d71706c4c494c908bd7568b99f6bf23c1be90..66cce97533efcf89d0191f543b8da6d56dea3602 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -21,20 +21,13 @@ from paddle.framework import core from paddle.fluid import layers from paddle.fluid.framework import program_guard, device_guard from .pass_base import PassBase, PassType, register_pass -from paddle.distributed.fleet.meta_optimizers.common import OpRole -from paddle.distributed.auto_parallel.utils import set_var_dist_attr +from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping from paddle.distributed.auto_parallel.process_group import get_world_process_group world_process_group = get_world_process_group() -def _is_the_optimizer_op(op): - OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize) - - def _remove_and_get_optimizer_op(main_program, dist_context): # 1 create tmp block # 2 mv optimizer op from global program to tmp block @@ -43,9 +36,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context): temp_block = main_program._create_block() removed_op_idx = [] optimize_ops_desc = [] - skip_ops = ["increment", "elementwise_mod", "equal"] for idx, op in enumerate(main_block.ops): - if _is_the_optimizer_op(op) and op.type not in skip_ops: + if is_optimize_op(op): # append optimizer op to tmp block new_op_desc = temp_block.desc.append_op() new_op_desc.copy_from(op.desc) @@ -57,7 +49,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context): dist_context.del_dist_op_for_program(op) for idx in removed_op_idx[::-1]: - main_block._remove_op(idx) + main_block._remove_op(idx, sync=False) + main_block._sync_with_cpp() return optimize_ops_desc @@ -109,7 +102,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): outputs={'Out': [step_var]}, attrs={ 'step': float(1.0), - 'op_role': OpRole.Optimize + OP_ROLE_KEY: OpRole.Backward }) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( increment_op, world_process_group.ranks, [-1], dist_context) @@ -123,7 +116,8 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): attrs={ 'axis': -1, 'use_mkldnn': False, - 'op_role': OpRole.Optimize + OP_ROLE_KEY: + OpRole.Backward }) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( elementwise_mod_op, world_process_group.ranks, [-1], dist_context) @@ -134,7 +128,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): 'Y': zero_var }, outputs={'Out': cond_var}, - attrs={'op_role': OpRole.Optimize}) + attrs={OP_ROLE_KEY: OpRole.Backward}) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( equal_op, world_process_group.ranks, [-1], dist_context) @@ -143,7 +137,6 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): def _append_gradient_merge_backward_op( main_program, startup_program, params_grads: List[Tuple[Any, Any]], - cond_var_name: str, dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]: main_block = main_program.global_block() startup_block = startup_program.global_block() @@ -201,7 +194,7 @@ def _append_gradient_merge_backward_op( attrs={ 'axis': -1, 'use_mkldnn': False, - 'op_role': OpRole.Optimize + OP_ROLE_KEY: OpRole.Backward }) new_params_to_grads.append([param, gradient_merge_var]) grad_to_gradient_merge[grad.name] = gradient_merge_var.name @@ -233,8 +226,7 @@ def _create_cond_block_and_update_optimizer( 'bias': 0.0, 'bias_after_scale': False }) - new_grad.op._set_attr(op_maker.kOpRoleAttrName(), - OpRole.Optimize) + new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize) # append optimizer ops for op_desc in optimize_ops_desc: @@ -272,29 +264,27 @@ def _create_cond_block_and_update_optimizer( dtype=new_grad.dtype, value=0.0, out=new_grad) - new_grad.op._set_attr(op_maker.kOpRoleAttrName(), - op_maker.OpRole.Optimize) + new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize) layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None) cond_op = main_program.global_block().ops[-1] - cond_op._set_attr('op_role', OpRole.Optimize) + cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize) def parse_program(main_program, startup_program, params_grads, k_steps, avg, dist_context): - # 1 create gradient_merge_cond - cond_var = _get_gm_cond_var(main_program, k_steps, dist_context) - - # 2 remove optimizer_op from main_program + # 1 remove optimizer_op from main_program optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context) # back to block 0 main_program._rollback() - # 3 append gradient merge backward op to main_program + # 2 append gradient merge backward op to main_program new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op( - main_program, startup_program, params_grads, cond_var.name, - dist_context) + main_program, startup_program, params_grads, dist_context) + + # 3 create gradient_merge_cond + cond_var = _get_gm_cond_var(main_program, k_steps, dist_context) # 4 create ConditionalBlock and append gradient merge optimizer ops _create_cond_block_and_update_optimizer(main_program, cond_var,