diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 2e36e92b3445a04d99d32c316d0fe456458ad983..a08da13a39cafaf135ccc8855537721421da0053 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -14,6 +14,7 @@
 
 import paddle
 from paddle.distributed.fleet import cloud_utils
+import paddle.fluid.core as core
 from .context import DistributedContext
 from .context import get_default_distributed_context
 from .completion import complete_annotation
@@ -38,6 +39,16 @@ class AutoParallelizer:
         # self._dist_context = DistributedContext()
         self._dist_context = get_default_distributed_context()
 
+    def _remove_distributed_attrs(self, main_program):
+        suffix = core.kAutoParallelSuffix()
+        # Distributed attributes for variables have been removed
+        # in the previous process.
+        for block in main_program.blocks:
+            for op in block.ops:
+                for attr_name in op.attr_names:
+                    if suffix in attr_name:
+                        op._remove_attr(attr_name)
+
     def parallelize(self,
                     loss,
                     startup_program=None,
@@ -76,4 +87,8 @@ class AutoParallelizer:
         for process_group in all_process_groups:
             process_group.instantiate()
 
+        # The last step: remove all distributed attributes to be compatible
+        # with inference.
+        self._remove_distributed_attrs(partitioned_main_prog)
+
         return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py
index 6db7fbf807568c7e278b236bb53426f22ee91a3a..a92e1e2f338b1008957f5cf56221b4ebe052917e 100755
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py
@@ -30,6 +30,7 @@ from paddle.fluid import layers
 from paddle.distributed import fleet
 import paddle.distributed.auto_parallel as auto
 from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
+import paddle.fluid.core as core
 
 paddle.enable_static()
 _global_parallel_strategy = None
@@ -83,6 +84,7 @@ def mlp_pretrain_forward(train_program, start_program):
             name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
 
         auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1])
+        auto.set_pipeline_stage(1)
 
         mlp = MLPLayer(
             hidden_size=hidden_size,
@@ -129,6 +131,11 @@ class TestMLPAutoParallelizer(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer)
         _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
             loss, start_program)
+        suffix = core.kAutoParallelSuffix()
+        for block in distributed_main_program.blocks:
+            for op in block.ops:
+                for attr_name in op.attr_names:
+                    self.assertTrue(suffix not in attr_name)
         # print_program_with_distributed_attr(distributed_main_program)
         self.assertIsNotNone(distributed_startup_program)
         self.assertIsNotNone(distributed_main_program)