Unverified commit e5cda6fa authored by Yulong Ao, committed by GitHub

[Auto Parallel] Use the new completion algorithm (#39086)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem

* [Auto Parallel] Update the completion algorithm

* Fix the bug of auto_searcher unittest
Parent f68ef9d2
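
For reference, the change repeated throughout the hunks below replaces the module-level completion functions (complete_annotation, complete_backward_annotation, complete_update_annotation) with a Completer object bound to a DistributedContext. A minimal before/after sketch of the migration, using only names that appear in this diff and assuming dist_context, train_program, and main_program are already constructed as in the tests:

# Old interface (removed by this commit): module-level completion functions.
#   complete_train_program = auto.complete_annotation(train_program, dist_context)
#   complete_backward_annotation(main_program, dist_context=dist_context)
#   complete_update_annotation(main_program, dist_context=dist_context)

# New interface: a single Completer instance drives all completion passes.
from paddle.distributed.auto_parallel.completion import Completer

completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(train_program)
completer.complete_backward_annotation(main_program)
completer.complete_update_annotation(main_program)
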
......@@ -15,12 +15,6 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .process_mesh import ProcessMesh
# from .interface import set_shard_mask # noqa: F401
# from .interface import set_offload_device # noqa: F401
# from .interface import set_pipeline_stage # noqa: F401
# from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
from .cost_model import estimate_cost
......
......@@ -247,23 +247,23 @@ class DistributedContext:
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None
def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = serial_node.id()
dist_tensor = self._dist_tensors_for_graph.get(
serial_tensor_node_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
if serial_node.is_op() and serial_node.op() is not None:
serial_op_node_id = serial_node.id()
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
return None
def init_dist_attr_for_program(self):
assert self._serial_program, \
......
......@@ -32,7 +32,7 @@ from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
from .completion import Completer
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
......@@ -130,8 +130,8 @@ class AutoParallelizer:
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
self._completer = Completer(self._dist_context)
self._completer.complete_backward_annotation(main_program)
return params_grads
......@@ -142,8 +142,8 @@ class AutoParallelizer:
params_grads)
# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
return optimize_ops
......@@ -179,8 +179,9 @@ class AutoParallelizer:
# Annotation completion
self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(serial_main_program,
self._dist_context)
self._completer = Completer(self._dist_context)
completed_main_program = self._completer.complete_forward_annotation(
serial_main_program)
else:
completed_main_program = serial_main_program
self._dist_context = copy.deepcopy(dist_context)
......
......@@ -27,6 +27,7 @@ import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
......@@ -154,10 +155,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_mp(self):
......@@ -171,10 +171,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_dp_mp(self):
......@@ -189,10 +188,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
# def test_mlp_misc(self):
......@@ -212,8 +210,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
# train_program, start_program = mlp_pretrain_forward(train_program,
# start_program)
# # pdb.set_trace()
# complete_train_program = auto.complete_annotation(train_program,
# dist_context)
# completer = Completer(dist_context)
# complete_train_program = auto.completer.complete_forward_annotation(train_program)
# # print_program_with_dist_attr(complete_train_program,
# # dist_context)
# dist_context.finalize_distributed_attr_for_program(
......@@ -423,8 +421,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program())
......@@ -440,10 +439,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_attn_dp_mp(self):
......@@ -458,10 +456,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
......@@ -747,10 +744,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_decoder_mp(self):
......@@ -764,10 +760,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_decoder_dp_mp(self):
......@@ -782,10 +777,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
......
......@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed.fleet import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
......@@ -817,10 +818,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_gpt_mp(self):
......@@ -834,10 +834,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_gpt_dp_mp(self):
......@@ -852,10 +851,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())
......
......@@ -23,6 +23,7 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
......@@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......
......@@ -18,6 +18,7 @@ import unittest
import paddle
from paddle.fluid import core
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
......@@ -42,8 +43,9 @@ def get_dist_prog(train_program,
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(
train_program, dist_context
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program
) if complete_train_program is None else complete_train_program
# parallelizer._apply_serial_forward_pass(complete_train_program,
......
......@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
......@@ -433,6 +434,12 @@ class MLPLayer(nn.Layer):
out = F.gelu(out, approximate=True)
out = self.linear1(out)
auto.shard_tensor(
out,
dist_attr={
"process_mesh": _global_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear2(out)
out = F.gelu(out, approximate=True)
out = self.linear3(out)
......@@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......
......@@ -28,6 +28,7 @@ import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
......@@ -49,8 +50,9 @@ def get_programs(annotated_func):
global _global_process_mesh
dist_context.process_mesh = _global_process_mesh
train_program, start_program = annotated_func(train_program, start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
rank_id = 3
dist_strategy = fleet.DistributedStrategy()
......
......@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
......@@ -881,8 +882,9 @@ class TestGPTPartitioner(unittest.TestCase):
dist_context.process_mesh = _global_process_mesh
train_program, startup_program, loss = gpt_pretrain_forward(
train_program, startup_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
# serial backward pass
params_grads = parallelizer._generate_backward(
......@@ -913,8 +915,9 @@ class TestGPTPartitioner(unittest.TestCase):
"w") as fw:
fw.write(str(auto_parallel_startup_prog))
# with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
# from paddle.distributed.auto_parallel.completion import complete_backward_annotation
# complete_backward_annotation(auto_parallel_main_prog)
# from paddle.distributed.auto_parallel.completion import Completer
# completer = Completer()
# completer.complete_forward_annotation(auto_parallel_main_prog)
# fw.write(str(auto_parallel_main_prog))
nrank = 4
# col parallel
......
......@@ -22,6 +22,7 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
......@@ -152,8 +153,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......@@ -299,7 +301,6 @@ class TestMLPReshard(unittest.TestCase):
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
......@@ -22,6 +22,7 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
......@@ -116,8 +117,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......
......@@ -22,6 +22,7 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
......@@ -132,8 +133,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......@@ -263,8 +265,9 @@ class TestMLPReshard(unittest.TestCase):
dist_context = DistributedContext()
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_context, rank_id)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, [])
reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
......
......@@ -154,7 +154,7 @@ class TestMLPSearcher(unittest.TestCase):
ops = train_program.global_block().ops
vars = train_program.global_block().vars
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.distributed.auto_parallel.completion import is_elementwise_like_op
from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
for op in ops:
......@@ -163,7 +163,7 @@ class TestMLPSearcher(unittest.TestCase):
if dist_op_impl_container is None:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_op = DistributedOperator(op, op_dist_attr)
if is_elementwise_like_op(op.type):
if is_elementwise_op(op.type):
changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_op)
self.assertFalse(changed)
......