Unverified commit a3790606, authored by lilong12, committed by GitHub

remove distributed attributes at the last stage for auto parallel (#35605)

* update
Parent 2c70b844
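
Editor's note: the diff below adds a final cleanup pass to AutoParallelizer that strips every op attribute whose name contains the marker suffix returned by core.kAutoParallelSuffix(), so the partitioned main program stays usable for inference. A minimal, self-contained sketch of the same idea follows; the helper name strip_auto_parallel_attrs and the tiny demo program are illustrative only and are not part of this commit.

import paddle
import paddle.fluid.core as core

paddle.enable_static()


def strip_auto_parallel_attrs(program):
    # Drop every op attribute whose name carries the auto-parallel suffix,
    # mirroring the private helper the commit adds to AutoParallelizer.
    suffix = core.kAutoParallelSuffix()
    for block in program.blocks:
        for op in block.ops:
            for attr_name in op.attr_names:
                if suffix in attr_name:
                    op._remove_attr(attr_name)  # same private API the commit uses


# Demo on a plain static program; with no auto-parallel annotations this is a no-op.
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = paddle.static.nn.fc(x, size=2)
strip_auto_parallel_attrs(main_prog)
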
@@ -14,6 +14,7 @@
 import paddle
 from paddle.distributed.fleet import cloud_utils
+import paddle.fluid.core as core
 from .context import DistributedContext
 from .context import get_default_distributed_context
 from .completion import complete_annotation
@@ -38,6 +39,16 @@ class AutoParallelizer:
         # self._dist_context = DistributedContext()
         self._dist_context = get_default_distributed_context()
 
+    def _remove_distributed_attrs(self, main_program):
+        suffix = core.kAutoParallelSuffix()
+        # distributed attributes for variables have been removed
+        # in a previous step
+        for block in main_program.blocks:
+            for op in block.ops:
+                for attr_name in op.attr_names:
+                    if suffix in attr_name:
+                        op._remove_attr(attr_name)
+
     def parallelize(self,
                     loss,
                     startup_program=None,
@@ -76,4 +87,8 @@ class AutoParallelizer:
         for process_group in all_process_groups:
             process_group.instantiate()
 
+        # The last step: remove all distributed attributes to be compatible
+        # with inference.
+        self._remove_distributed_attrs(partitioned_main_prog)
+
         return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog
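
The comment in the hunk above ties the cleanup to inference compatibility: once the auto-parallel attributes are gone, the partitioned main program can go through the usual static-graph export path. A hedged sketch under stated assumptions: partitioned_main_prog is the program returned by parallelize(), and input, loss, and exe are the feed variable, fetch variable, and executor of the current rank; none of these names are defined by this commit.

import paddle

# Assumption: partitioned_main_prog, input, loss and exe come from a finished
# auto-parallel run (see the test file below for how the program is obtained).
paddle.static.save_inference_model(
    path_prefix="./auto_parallel_infer",
    feed_vars=[input],
    fetch_vars=[loss],
    executor=exe,
    program=partitioned_main_prog)
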
@@ -30,6 +30,7 @@ from paddle.fluid import layers
 from paddle.distributed import fleet
 import paddle.distributed.auto_parallel as auto
 from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
+import paddle.fluid.core as core
 
 paddle.enable_static()
 _global_parallel_strategy = None
@@ -83,6 +84,7 @@ def mlp_pretrain_forward(train_program, start_program):
             name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
 
         auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1])
+        auto.set_pipeline_stage(1)
 
         mlp = MLPLayer(
             hidden_size=hidden_size,
@@ -129,6 +131,11 @@ class TestMLPAutoParallelizer(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer)
         _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
             loss, start_program)
+        suffix = core.kAutoParallelSuffix()
+        for block in distributed_main_program.blocks:
+            for op in block.ops:
+                for attr_name in op.attr_names:
+                    self.assertTrue(suffix not in attr_name)
         # print_program_with_distributed_attr(distributed_main_program)
         self.assertIsNotNone(distributed_startup_program)
         self.assertIsNotNone(distributed_main_program)