From 9edf8502e302f5b6b6c1e908b2a8141be36a0892 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Sat, 8 Oct 2022 20:44:25 +0800 Subject: [PATCH] [Auto Parallel]Update comp cost and completion for gpt auto search (#46387) * update comp cost and completion for gpt auto search * add unittest --- .../distributed/auto_parallel/completion.py | 72 +++++++++++++++++++ .../auto_parallel/cost/comp_op_cost.py | 55 ++++++++++++++ .../unittests/auto_parallel/test_comp_cost.py | 19 +++++ .../auto_parallel/test_while_op_completion.py | 8 +++ 4 files changed, 154 insertions(+) diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index c8633b4a73..977e5fb9fc 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -142,6 +142,7 @@ class Completer: def __init__(self, dist_context): assert dist_context is not None self._dist_context = dist_context + self._has_prepared = False def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True): changed = False @@ -719,6 +720,8 @@ class Completer: self._update_process_mesh_between_graphs() def _prepare(self): + if self._has_prepared: + return self._while_op_nodes = {} self._array_nodes = {} self._node_pairs_between_graphs = [] @@ -732,6 +735,8 @@ class Completer: if self._array_nodes.get(array_var_name, None) is None: self._array_nodes[array_var_name] = [] self._array_nodes[array_var_name].append(node) + # Add the array input node + self._array_nodes[array_var_name].append(node.inputs[0]) if node.op().type() == "write_to_array": array_var_name = node.op().output("Out")[0] if self._array_nodes.get(array_var_name, None) is None: @@ -752,6 +757,7 @@ class Completer: and after_node.var().name() == node.var().name(): self._node_pairs_between_graphs.append( (after_node, node)) + self._has_prepared = True def complete_forward_annotation(self, serial_main_program=None): """ Complete annotation for the partial annotated serial_main_program. @@ -899,6 +905,72 @@ class Completer: else: dist_op.dist_attr = original_op_dist_attr + def _complete_tensor_dist_attr_by_op(self, serial_main_program=None): + if serial_main_program is None: + serial_main_program = self._dist_context.serial_main_program + else: + self._dist_context._serial_main_program = serial_main_program + + self._dist_context.initialize() + + self._prepare() + + has_set_dist_attr = set() + + all_nodes = self._dist_context.serial_ordered_nodes + for node in all_nodes: + if node.is_op(): + if node.op().type() in ["while"]: + continue + dist_op = self._dist_context.get_dist_op_for_graph(node) + op_dist_attr = dist_op.dist_attr + for tensor_node in node.inputs: + if tensor_node.is_var() and tensor_node.var() is not None: + # Skip the non-leaf var node + if len(tensor_node.inputs) != 0: + continue + tensor_desc = tensor_node.var() + tensor_name = tensor_desc.name() + tensor = dist_op.get_serial_input(tensor_name) + # Use the first op to set the tensor dist attr + if tensor_name in has_set_dist_attr: + continue + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + tensor_dist_attr.process_mesh = op_dist_attr.process_mesh + tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_name) if tensor.is_parameter else [ + -1 for i in tensor_desc.shape() + ] + has_set_dist_attr.add(tensor_name) + for tensor_node in node.outputs: + if tensor_node.is_var() and tensor_node.var() is not None: + tensor_name = tensor_node.var().name() + if tensor_name in has_set_dist_attr: + continue + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + tensor_dist_attr.process_mesh = op_dist_attr.process_mesh + tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping( + tensor_name) + has_set_dist_attr.add(tensor_name) + + self._update_process_mesh_for_specials() + + self._update_process_mesh_between_graphs() + + self._update_dims_mapping_for_special() + + self._update_dims_mapping_between_graphs() + + # Copy the corresponding distributed attribute from graph to serial_main_program + self._dist_context.copy_dist_attr_from_graph_to_program() + + # Do the validation check and amend some completion + self._dist_context.amend_dist_attr_for_program() + + self._dist_context.validate_dist_attr_for_program() + def _complete_high_order_grad_annotation(self, serial_main_program=None): """ NOTE: diff --git a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py index b4ac972bcf..c5bdc85e1b 100644 --- a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py @@ -167,6 +167,25 @@ class DropoutOpCost(CompOpCost): return 0 +@register_op_cost +class DropoutGradOpCost(CompOpCost): + OP_TYPE = "dropout_grad" + + def __init__(self, op=None, op_desc=None, cluster=None): + super(DropoutGradOpCost, self).__init__(op=op, + op_desc=op_desc, + cluster=cluster) + + # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided + def calc_flops(self): + # NOTE: The actual formula will be filled in the future + return 0 + + def calc_time(self): + # NOTE: The actual formula will be filled in the future + return 0 + + @register_op_cost class ElementwiseAddOpCost(CompOpCost): OP_TYPE = "elementwise_add" @@ -395,6 +414,42 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost): return 0 +@register_op_cost +class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost): + OP_TYPE = "fused_softmax_mask_upper_triangle" + + def __init__(self, op=None, op_desc=None, cluster=None): + super(FusedSoftmaxMaskUpperTriangleOpCost, + self).__init__(op=op, op_desc=op_desc, cluster=cluster) + + # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided + def calc_flops(self): + # NOTE: The actual formula will be filled in the future + return 0 + + def calc_time(self): + # NOTE: The actual formula will be filled in the future + return 0 + + +@register_op_cost +class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost): + OP_TYPE = "fused_softmax_mask_upper_triangle_grad" + + def __init__(self, op=None, op_desc=None, cluster=None): + super(FusedSoftmaxMaskUpperTriangleGradOpCost, + self).__init__(op=op, op_desc=op_desc, cluster=cluster) + + # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided + def calc_flops(self): + # NOTE: The actual formula will be filled in the future + return 0 + + def calc_time(self): + # NOTE: The actual formula will be filled in the future + return 0 + + @register_op_cost class GatherOpCost(CompOpCost): OP_TYPE = "gather" diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_comp_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_comp_cost.py index 0a3a5993ff..c0f7c87819 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_comp_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_comp_cost.py @@ -82,6 +82,9 @@ from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2OpCost from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2GradOpCost from paddle.distributed.auto_parallel.cost.comp_op_cost import Unsqueeze2OpCost from paddle.distributed.auto_parallel.cost.comp_op_cost import WriteToArrayOpCost +from paddle.distributed.auto_parallel.cost.comp_op_cost import DropoutGradOpCost +from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleOpCost +from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleGradOpCost from test_cluster import cluster_json @@ -417,6 +420,22 @@ class TestCompOpCost(unittest.TestCase): self.assertTrue(op_cost.flops >= 0) self.assertTrue(op_cost.time >= 0) self.assertTrue(op_cost.memory >= 0) + + op_cost = DropoutGradOpCost(cluster=cluster) + self.assertTrue(op_cost.flops >= 0) + self.assertTrue(op_cost.time >= 0) + self.assertTrue(op_cost.memory >= 0) + + op_cost = FusedSoftmaxMaskUpperTriangleOpCost(cluster=cluster) + self.assertTrue(op_cost.flops >= 0) + self.assertTrue(op_cost.time >= 0) + self.assertTrue(op_cost.memory >= 0) + + op_cost = FusedSoftmaxMaskUpperTriangleGradOpCost(cluster=cluster) + self.assertTrue(op_cost.flops >= 0) + self.assertTrue(op_cost.time >= 0) + self.assertTrue(op_cost.memory >= 0) + # Remove unnecessary files if os.path.exists(cluster_json_path): os.remove(cluster_json_path) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py index d31b34cacc..f0edf8d6e2 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py @@ -187,6 +187,14 @@ class TestMLP(unittest.TestCase): train_program) # print_program_with_dist_attr(complete_train_program, dist_context) + def test_completer_by_dist_op(self): + train_program, start_program, dataloader, i, loss = get_program() + dist_context = DistributedContext() + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) + complete_train_program = completer._complete_tensor_dist_attr_by_op() + if __name__ == "__main__": unittest.main() -- GitLab