Unverified commit a3790606, authored by lilong12, committed by GitHub

remove distributed attributes at the last stage for auto parallel (#35605)

* update
Parent: 2c70b844
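
In short, the AutoParallelizer gains a final cleanup pass that strips every op attribute whose name carries the auto-parallel suffix from the partitioned main program, keeping the program compatible with inference. Below is a minimal standalone sketch of that pass, using only the calls that appear in the diff (core.kAutoParallelSuffix, Block.ops, Operator.attr_names, Operator._remove_attr); the helper name is illustrative, not part of the commit.

import paddle.fluid.core as core


def strip_auto_parallel_attrs(program):
    # Illustrative helper mirroring the new cleanup step: remove every
    # op attribute whose name contains the auto-parallel suffix.
    suffix = core.kAutoParallelSuffix()
    for block in program.blocks:
        for op in block.ops:
            # list() takes a snapshot so removal cannot disturb iteration.
            for attr_name in list(op.attr_names):
                if suffix in attr_name:
                    op._remove_attr(attr_name)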
@@ -14,6 +14,7 @@
import paddle
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from .context import DistributedContext
from .context import get_default_distributed_context
from .completion import complete_annotation
@@ -38,6 +39,16 @@ class AutoParallelizer:
        # self._dist_context = DistributedContext()
        self._dist_context = get_default_distributed_context()

    def _remove_distributed_attrs(self, main_program):
        suffix = core.kAutoParallelSuffix()
        # Distributed attributes on variables have already been removed
        # in an earlier pass, so only op attributes need to be stripped here.
        for block in main_program.blocks:
            for op in block.ops:
                for attr_name in op.attr_names:
                    if suffix in attr_name:
                        op._remove_attr(attr_name)
    def parallelize(self,
                    loss,
                    startup_program=None,
@@ -76,4 +87,8 @@ class AutoParallelizer:
        for process_group in all_process_groups:
            process_group.instantiate()

        # The last step: remove all distributed attributes to be compatible
        # with inference.
        self._remove_distributed_attrs(partitioned_main_prog)

        return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog
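
For context, a hedged usage sketch (not part of this commit) of why the cleanup matters: once parallelize() returns, partitioned_main_prog carries no auto-parallel attributes and can be treated like an ordinary static program, e.g. saved for inference. The executor placement, the feed/fetch variables input and predict, and the path prefix are illustrative assumptions, and paddle.static.save_inference_model is assumed to accept a program keyword as in Paddle 2.x.

import paddle

# Assumed to run after parallelize() has produced partitioned_startup_prog
# and partitioned_main_prog (names taken from the diff above).
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(partitioned_startup_prog)

# input/predict stand in for the model's feed and fetch variables.
paddle.static.save_inference_model(
    "./auto_parallel_infer",
    [input],
    [predict],
    exe,
    program=partitioned_main_prog)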
@@ -30,6 +30,7 @@ from paddle.fluid import layers
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
import paddle.fluid.core as core
paddle.enable_static()
_global_parallel_strategy = None
@@ -83,6 +84,7 @@ def mlp_pretrain_forward(train_program, start_program):
            name="label", shape=[batch_size, sequence_len, 1], dtype='float32')

        auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1])
        auto.set_pipeline_stage(1)

        mlp = MLPLayer(
            hidden_size=hidden_size,
@@ -129,6 +131,11 @@ class TestMLPAutoParallelizer(unittest.TestCase):
        optimizer = fleet.distributed_optimizer(optimizer)
        _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
            loss, start_program)

        suffix = core.kAutoParallelSuffix()
        for block in distributed_main_program.blocks:
            for op in block.ops:
                for attr_name in op.attr_names:
                    self.assertTrue(suffix not in attr_name)

        # print_program_with_distributed_attr(distributed_main_program)
        self.assertIsNotNone(distributed_startup_program)
        self.assertIsNotNone(distributed_main_program)