Unverified commit 9edf8502, authored by caozhou, committed by GitHub

[Auto Parallel] Update comp cost and completion for gpt auto search (#46387)

* update comp cost and completion for gpt auto search

* add unittest
Parent 6e9bb9f9
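This change adds a tensor-level completion pass, Completer._complete_tensor_dist_attr_by_op, plus several new comp-op cost classes used by the GPT auto-search planner, and covers both with unit tests. Before the diff, here is a minimal sketch of how the new completion entry point is driven; it mirrors the unit test added at the end of this change, and build_serial_program() is a hypothetical stand-in for that test's get_program() helper.

# Minimal sketch (not part of the diff): run forward completion, then the new
# op-driven tensor completion. build_serial_program() is a hypothetical helper.
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

def run_completion(build_serial_program):
    train_program, start_program = build_serial_program()  # hypothetical helper
    dist_context = DistributedContext()
    completer = Completer(dist_context)
    # Complete the partially annotated forward program first.
    completer.complete_forward_annotation(train_program)
    # Then derive each tensor's process_mesh and dims_mapping from the first op
    # that touches it.
    completer._complete_tensor_dist_attr_by_op()
    return dist_context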
@@ -142,6 +142,7 @@ class Completer:

    def __init__(self, dist_context):
        assert dist_context is not None
        self._dist_context = dist_context
        self._has_prepared = False

    def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
        changed = False
@@ -719,6 +720,8 @@ class Completer:
        self._update_process_mesh_between_graphs()

    def _prepare(self):
        if self._has_prepared:
            return
        self._while_op_nodes = {}
        self._array_nodes = {}
        self._node_pairs_between_graphs = []
@@ -732,6 +735,8 @@ class Completer:
                    if self._array_nodes.get(array_var_name, None) is None:
                        self._array_nodes[array_var_name] = []
                    self._array_nodes[array_var_name].append(node)
                    # Add the array input node
                    self._array_nodes[array_var_name].append(node.inputs[0])
                if node.op().type() == "write_to_array":
                    array_var_name = node.op().output("Out")[0]
                    if self._array_nodes.get(array_var_name, None) is None:
@@ -752,6 +757,7 @@ class Completer:
                                and after_node.var().name() == node.var().name():
                            self._node_pairs_between_graphs.append(
                                (after_node, node))
        self._has_prepared = True

    def complete_forward_annotation(self, serial_main_program=None):
        """ Complete annotation for the partial annotated serial_main_program.

@@ -899,6 +905,72 @@ class Completer:
                else:
                    dist_op.dist_attr = original_op_dist_attr

    def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
        if serial_main_program is None:
            serial_main_program = self._dist_context.serial_main_program
        else:
            self._dist_context._serial_main_program = serial_main_program

        self._dist_context.initialize()
        self._prepare()

        has_set_dist_attr = set()
        all_nodes = self._dist_context.serial_ordered_nodes
        for node in all_nodes:
            if node.is_op():
                if node.op().type() in ["while"]:
                    continue
                dist_op = self._dist_context.get_dist_op_for_graph(node)
                op_dist_attr = dist_op.dist_attr
                for tensor_node in node.inputs:
                    if tensor_node.is_var() and tensor_node.var() is not None:
                        # Skip the non-leaf var node
                        if len(tensor_node.inputs) != 0:
                            continue
                        tensor_desc = tensor_node.var()
                        tensor_name = tensor_desc.name()
                        tensor = dist_op.get_serial_input(tensor_name)
                        # Use the first op to set the tensor dist attr
                        if tensor_name in has_set_dist_attr:
                            continue
                        tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                            tensor_node)
                        tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                        tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
                            tensor_name) if tensor.is_parameter else [
                                -1 for i in tensor_desc.shape()
                            ]
                        has_set_dist_attr.add(tensor_name)
                for tensor_node in node.outputs:
                    if tensor_node.is_var() and tensor_node.var() is not None:
                        tensor_name = tensor_node.var().name()
                        if tensor_name in has_set_dist_attr:
                            continue
                        tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                            tensor_node)
                        tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                        tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
                            tensor_name)
                        has_set_dist_attr.add(tensor_name)

        self._update_process_mesh_for_specials()
        self._update_process_mesh_between_graphs()
        self._update_dims_mapping_for_special()
        self._update_dims_mapping_between_graphs()

        # Copy the corresponding distributed attribute from graph to serial_main_program
        self._dist_context.copy_dist_attr_from_graph_to_program()

        # Do the validation check and amend some completion
        self._dist_context.amend_dist_attr_for_program()
        self._dist_context.validate_dist_attr_for_program()

    def _complete_high_order_grad_annotation(self, serial_main_program=None):
        """
        NOTE:
...

@@ -167,6 +167,25 @@ class DropoutOpCost(CompOpCost):
        return 0

@register_op_cost
class DropoutGradOpCost(CompOpCost):
    OP_TYPE = "dropout_grad"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(DropoutGradOpCost, self).__init__(op=op,
                                                op_desc=op_desc,
                                                cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future
        return 0

@register_op_cost
class ElementwiseAddOpCost(CompOpCost):
    OP_TYPE = "elementwise_add"

@@ -395,6 +414,42 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
        return 0

@register_op_cost
class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
    OP_TYPE = "fused_softmax_mask_upper_triangle"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(FusedSoftmaxMaskUpperTriangleOpCost,
              self).__init__(op=op, op_desc=op_desc, cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future
        return 0

@register_op_cost
class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
    OP_TYPE = "fused_softmax_mask_upper_triangle_grad"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(FusedSoftmaxMaskUpperTriangleGradOpCost,
              self).__init__(op=op, op_desc=op_desc, cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future
        return 0

@register_op_cost
class GatherOpCost(CompOpCost):
    OP_TYPE = "gather"
...
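The calc_flops and calc_time bodies of the new cost classes intentionally return 0 until real formulas are filled in (see the NOTE comments above). Purely as an illustration of the kind of formula that could eventually go there, and not something this change ships, dropout_grad cost might be approximated as elementwise work over the incoming gradient, converted to time with a device-throughput figure; the shape and throughput below are assumptions.

# Hypothetical sketch only: one multiply per element of the incoming gradient
# (grad_x = grad_out * mask), converted to microseconds with an assumed peak
# throughput. Neither the formula nor the numbers come from this PR.
from functools import reduce
from operator import mul

def dropout_grad_flops(out_grad_shape):
    # Element count of the incoming gradient tensor.
    return reduce(mul, out_grad_shape, 1)

def dropout_grad_time_us(out_grad_shape, device_gflops=15000.0):
    # Roofline-style conversion: 1 GFLOPS == 1e3 FLOPs per microsecond.
    return dropout_grad_flops(out_grad_shape) / (device_gflops * 1e3)

# Example: an [8, 1024, 4096] activation gradient on an assumed 15 TFLOPS device.
print(dropout_grad_time_us([8, 1024, 4096]))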

@@ -82,6 +82,9 @@ from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2OpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2GradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import Unsqueeze2OpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import WriteToArrayOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import DropoutGradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleGradOpCost
from test_cluster import cluster_json

@@ -417,6 +420,22 @@ class TestCompOpCost(unittest.TestCase):
        self.assertTrue(op_cost.flops >= 0)
        self.assertTrue(op_cost.time >= 0)
        self.assertTrue(op_cost.memory >= 0)

        op_cost = DropoutGradOpCost(cluster=cluster)
        self.assertTrue(op_cost.flops >= 0)
        self.assertTrue(op_cost.time >= 0)
        self.assertTrue(op_cost.memory >= 0)

        op_cost = FusedSoftmaxMaskUpperTriangleOpCost(cluster=cluster)
        self.assertTrue(op_cost.flops >= 0)
        self.assertTrue(op_cost.time >= 0)
        self.assertTrue(op_cost.memory >= 0)

        op_cost = FusedSoftmaxMaskUpperTriangleGradOpCost(cluster=cluster)
        self.assertTrue(op_cost.flops >= 0)
        self.assertTrue(op_cost.time >= 0)
        self.assertTrue(op_cost.memory >= 0)

        # Remove unnecessary files
        if os.path.exists(cluster_json_path):
            os.remove(cluster_json_path)
...

@@ -187,6 +187,14 @@ class TestMLP(unittest.TestCase):
            train_program)
        # print_program_with_dist_attr(complete_train_program, dist_context)

    def test_completer_by_dist_op(self):
        train_program, start_program, dataloader, i, loss = get_program()
        dist_context = DistributedContext()
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        complete_train_program = completer._complete_tensor_dist_attr_by_op()

if __name__ == "__main__":
    unittest.main()