Unverified commit e5cda6fa, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Use the new completion algorithm (#39086)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem

* [Auto Parallel] Update the completion algorithm

* Fix the bug of auto_searcher unittest
Parent commit: f68ef9d2
@@ -15,12 +15,6 @@
 from .interface import shard_tensor  # noqa: F401
 from .interface import shard_op  # noqa: F401
 from .process_mesh import ProcessMesh
-# from .interface import set_shard_mask  # noqa: F401
-# from .interface import set_offload_device  # noqa: F401
-# from .interface import set_pipeline_stage  # noqa: F401
-# from .interface import ProcessMesh  # noqa: F401
-from .completion import complete_annotation  # noqa: F401
-from .completion import complete_backward_annotation  # noqa: F401
 from .reshard import reshard  # noqa: F401
 from .cost_model import estimate_cost
...
@@ -247,23 +247,23 @@ class DistributedContext:
         # new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
         # self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
-    # def get_dist_attr_for_graph(self, serial_node):
-    #     if serial_node.is_var() and serial_node.var() is not None:
-    #         serial_tensor_node_id = serial_node.id()
-    #         dist_tensor = self._dist_tensors_for_graph.get(
-    #             serial_tensor_node_id, None)
-    #         if dist_tensor:
-    #             return dist_tensor.dist_attr
-    #         else:
-    #             return None
-    #     if serial_node.is_op() and serial_node.op() is not None:
-    #         serial_op_node_id = serial_node.id()
-    #         dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
-    #         if dist_op:
-    #             return dist_op.dist_attr
-    #         else:
-    #             return None
-    #     return None
+    def get_dist_attr_for_graph(self, serial_node):
+        if serial_node.is_var() and serial_node.var() is not None:
+            serial_tensor_node_id = serial_node.id()
+            dist_tensor = self._dist_tensors_for_graph.get(
+                serial_tensor_node_id, None)
+            if dist_tensor:
+                return dist_tensor.dist_attr
+            else:
+                return None
+        if serial_node.is_op() and serial_node.op() is not None:
+            serial_op_node_id = serial_node.id()
+            dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
+            if dist_op:
+                return dist_op.dist_attr
+            else:
+                return None
+        return None
     def init_dist_attr_for_program(self):
         assert self._serial_program, \
...
@@ -32,7 +32,7 @@ from paddle.distributed.passes import new_pass, PassContext
 from .dist_context import DistributedContext
 from .dist_context import get_default_distributed_context
 from .dist_context import set_default_distributed_context
-from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
+from .completion import Completer
 from .partitioner import Partitioner
 from .process_group import get_all_process_groups
 from .process_group import get_process_group
@@ -130,8 +130,8 @@ class AutoParallelizer:
             no_grad_set,
             callbacks,
             distop_context=self._dist_context.dist_op_context)
-        complete_backward_annotation(
-            main_program, dist_context=self._dist_context)
+        self._completer = Completer(self._dist_context)
+        self._completer.complete_backward_annotation(main_program)
         return params_grads
@@ -142,8 +142,8 @@ class AutoParallelizer:
             params_grads)
         # update completion
-        complete_update_annotation(
-            main_program, dist_context=self._dist_context)
+        self._completer = Completer(self._dist_context)
+        self._completer.complete_update_annotation(main_program)
         return optimize_ops
@@ -179,8 +179,9 @@ class AutoParallelizer:
             # Annotation completion
            self._dist_context = DistributedContext()
            _logger.info("Start annotation dist attr.")
-            completed_main_program = complete_annotation(serial_main_program,
-                                                         self._dist_context)
+            self._completer = Completer(self._dist_context)
+            completed_main_program = self._completer.complete_forward_annotation(
+                serial_main_program)
         else:
             completed_main_program = serial_main_program
             self._dist_context = copy.deepcopy(dist_context)
...
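The hunks above replace the module-level complete_annotation, complete_backward_annotation, and complete_update_annotation functions with a single Completer class. A minimal sketch of the new call pattern as it appears in this diff; serial_main_program is assumed to be an already-built (and annotated) static Program, its construction is not shown here:

    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    dist_context = DistributedContext()
    completer = Completer(dist_context)

    # Forward completion returns the program with completed dist attributes,
    # mirroring how AutoParallelizer drives it in the hunk above.
    completed_main_program = completer.complete_forward_annotation(serial_main_program)

    # After backward and optimizer ops have been appended, the same object
    # completes their annotations; the diff shows no return value being used.
    completer.complete_backward_annotation(completed_main_program)
    completer.complete_update_annotation(completed_main_program)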
@@ -27,6 +27,7 @@ import paddle.tensor as tensor
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -154,10 +155,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
     def test_mlp_mp(self):
@@ -171,10 +171,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
     def test_mlp_dp_mp(self):
@@ -189,10 +188,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
     # def test_mlp_misc(self):
@@ -212,8 +210,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
     #     train_program, start_program = mlp_pretrain_forward(train_program,
     #                                                         start_program)
     #     # pdb.set_trace()
-    #     complete_train_program = auto.complete_annotation(train_program,
-    #                                                       dist_context)
+    #     completer = Completer(dist_context)
+    #     complete_train_program = auto.completer.complete_forward_annotation(train_program)
     #     # print_program_with_dist_attr(complete_train_program,
     #     #                              dist_context)
     #     dist_context.finalize_distributed_attr_for_program(
@@ -423,8 +421,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         # print_program_with_dist_attr(complete_train_program,
         #                              dist_context)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
@@ -440,10 +439,9 @@
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
     def test_attn_dp_mp(self):
@@ -458,10 +456,9 @@
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
@@ -747,10 +744,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
     def test_decoder_mp(self):
@@ -764,10 +760,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
     def test_decoder_dp_mp(self):
@@ -782,10 +777,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
...
@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed.fleet import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -817,10 +818,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
     def test_gpt_mp(self):
@@ -834,10 +834,9 @@
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
     def test_gpt_dp_mp(self):
@@ -852,10 +851,9 @@
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
...
@@ -23,6 +23,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
...
@@ -18,6 +18,7 @@ import unittest
 import paddle
 from paddle.fluid import core
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -42,8 +43,9 @@ def get_dist_prog(train_program,
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(
-        train_program, dist_context
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program
     ) if complete_train_program is None else complete_train_program
     # parallelizer._apply_serial_forward_pass(complete_train_program,
...
@@ -36,6 +36,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed import fleet
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -433,6 +434,12 @@ class MLPLayer(nn.Layer):
         out = F.gelu(out, approximate=True)
         out = self.linear1(out)
+        auto.shard_tensor(
+            out,
+            dist_attr={
+                "process_mesh": _global_process_mesh[1],
+                "dims_mapping": [0, -1]
+            })
         out = self.linear2(out)
         out = F.gelu(out, approximate=True)
         out = self.linear3(out)
@@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # auto completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
...
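The six added lines above annotate the output of linear1 with a sharding spec. For reference, a standalone sketch of the same shard_tensor interface; the 2x2 process mesh and the placeholder static tensor here are assumptions for illustration, not taken from this diff:

    import paddle
    import paddle.distributed.auto_parallel as auto

    paddle.enable_static()
    # Assumed 2x2 mesh over ranks 0-3, purely for illustration.
    mesh = auto.ProcessMesh([[0, 1], [2, 3]])
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    # Shard dim 0 of x along mesh dim 0 and leave dim 1 replicated (-1 means
    # not sharded), mirroring the dims_mapping=[0, -1] added in the hunk above.
    auto.shard_tensor(x, dist_attr={"process_mesh": mesh, "dims_mapping": [0, -1]})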
@@ -28,6 +28,7 @@ import paddle.tensor as tensor
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -49,8 +50,9 @@ def get_programs(annotated_func):
     global _global_process_mesh
     dist_context.process_mesh = _global_process_mesh
     train_program, start_program = annotated_func(train_program, start_program)
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     rank_id = 3
     dist_strategy = fleet.DistributedStrategy()
...
@@ -31,6 +31,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
 from paddle.distributed import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -881,8 +882,9 @@ class TestGPTPartitioner(unittest.TestCase):
         dist_context.process_mesh = _global_process_mesh
         train_program, startup_program, loss = gpt_pretrain_forward(
             train_program, startup_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         # serial backward pass
         params_grads = parallelizer._generate_backward(
@@ -913,8 +915,9 @@
                   "w") as fw:
             fw.write(str(auto_parallel_startup_prog))
         # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
-        #     from paddle.distributed.auto_parallel.completion import complete_backward_annotation
-        #     complete_backward_annotation(auto_parallel_main_prog)
+        #     from paddle.distributed.auto_parallel.completion import Completer
+        #     completer = Completer()
+        #     completer.complete_forward_annotation(auto_parallel_main_prog)
         #     fw.write(str(auto_parallel_main_prog))
         nrank = 4
         # col parallel
...
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -152,8 +153,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
@@ -299,7 +301,6 @@ class TestMLPReshard(unittest.TestCase):
         for key in list(_g_process_group_map.keys()):
             del _g_process_group_map[key]
         reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
-        # print_program_with_dist_attr(dist_main_prog, dist_context)
         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
...
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -116,8 +117,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
...
@@ -22,6 +22,7 @@ import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -132,8 +133,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context
     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
@@ -263,8 +265,9 @@ class TestMLPReshard(unittest.TestCase):
         dist_context = DistributedContext()
         dist_strategy = fleet.DistributedStrategy()
         partitioner = Partitioner(dist_context, rank_id)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
             complete_train_program, startup_program, [])
         reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
...
@@ -154,7 +154,7 @@ class TestMLPSearcher(unittest.TestCase):
         ops = train_program.global_block().ops
         vars = train_program.global_block().vars
         from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
-        from paddle.distributed.auto_parallel.completion import is_elementwise_like_op
+        from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
         from paddle.distributed.auto_parallel.dist_op import DistributedOperator
         for op in ops:
@@ -163,7 +163,7 @@ class TestMLPSearcher(unittest.TestCase):
             if dist_op_impl_container is None:
                 op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                 dist_op = DistributedOperator(op, op_dist_attr)
-                if is_elementwise_like_op(op.type):
+                if is_elementwise_op(op.type):
                     changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                         dist_op)
                     self.assertFalse(changed)
...