From 5be7a1ffc82fa805214217fa7610661c22259402 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Wed, 27 Jul 2022 17:29:09 +0800
Subject: [PATCH] retain dist op returns (#44634)

---
 .../paddle/distributed/auto_parallel/dist_op.py  |  4 +---
 .../tests/unittests/auto_parallel/engine_api.py  |  4 ++--
 .../tests/unittests/auto_parallel_gpt_model.py   | 16 ++++++++--------
 .../unittests/test_auto_parallel_reshard_mppp.py |  8 +-------
 4 files changed, 12 insertions(+), 20 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py
index d48804b71f..b6a77b7788 100644
--- a/python/paddle/distributed/auto_parallel/dist_op.py
+++ b/python/paddle/distributed/auto_parallel/dist_op.py
@@ -267,6 +267,4 @@ class DistributedModule:
         dist_op = DistributedOperator(op, self._dist_attr)
         dist_op.dist_attr.mark_annotated_as(self._dist_attr)
         default_dist_ctx.add_dist_op_for_program(dist_op)
-        if isinstance(output, Variable):
-            output = [output]
-        return list(output)
+        return output
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
index ec757c0347..9335d7d9d2 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
@@ -89,11 +89,11 @@ class MLPLayer(nn.Layer):
 
     def forward(self, input):
         out = auto.shard_op(self.norm, dist_attr={"process_mesh":
-                                                  PP_MESH_0})(input)[0]
+                                                  PP_MESH_0})(input)
         out = self.linear0(out)
         out = F.gelu(out, approximate=True)
         out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
-                                                     PP_MESH_1})(out)[0]
+                                                     PP_MESH_1})(out)
         out = self.dropout(out)
         out = self.linear2(out)
         self.out = out
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
index 4695f6a4a9..87c746ab5d 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
@@ -391,7 +391,7 @@ class TransformerDecoder(nn.Layer):
                         mod,
                         dist_attr={
                             "process_mesh": PP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -405,7 +405,7 @@ class TransformerDecoder(nn.Layer):
                         mod,
                         dist_attr={
                             "process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -419,7 +419,7 @@ class TransformerDecoder(nn.Layer):
                         mod,
                         dist_attr={
                             "process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -433,7 +433,7 @@ class TransformerDecoder(nn.Layer):
                         mod,
                         dist_attr={
                             "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -456,7 +456,7 @@ class TransformerDecoder(nn.Layer):
                             "process_mesh":
                             PP_MESH_LIST[mod.mesh_idx]
                         })(output, memory, tgt_mask,
-                           use_cache, cache)[0]
+                           use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -471,7 +471,7 @@ class TransformerDecoder(nn.Layer):
                             "process_mesh":
                             DPPP_MESH_LIST[mod.mesh_idx]
                         })(output, memory, tgt_mask,
-                           use_cache, cache)[0]
+                           use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -486,7 +486,7 @@ class TransformerDecoder(nn.Layer):
                             "process_mesh":
                             MPPP_MESH_LIST[mod.mesh_idx]
                         })(output, memory, tgt_mask,
-                           use_cache, cache)[0]
+                           use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -500,7 +500,7 @@ class TransformerDecoder(nn.Layer):
                         mod,
                         dist_attr={
                             "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
index 0e647a3db5..dfb314796a 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
@@ -255,12 +255,6 @@ class TestMLPReshard(unittest.TestCase):
                               "dims_mapping": [-1, -1]
                           })
 
-        # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
-        #     x.name: [-1, -1],
-        #     w.name: [-1, -1]
-        # }, **{"x": x,
-        #       "y": w})[0]
-
         y = paddle.distributed.shard_op(paddle.matmul,
                                         dist_attr={
                                             "process_mesh": process_mesh,
@@ -270,7 +264,7 @@ class TestMLPReshard(unittest.TestCase):
                                             w: {
                                                 "dims_mapping": [-1, -1]
                                             }
-                                        })(x, w)[0]
+                                        })(x, w)
 
         rank_id = 0
         dist_context = DistributedContext()
--
GitLab
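Editor's sketch (not part of the patch): the patch makes DistributedModule.__call__ return the wrapped module's output unchanged instead of boxing a single Variable into a list, which is why every call site above drops its trailing `[0]`. The minimal, Paddle-free sketch below illustrates that return-handling change under stated assumptions; `ToyModuleOld`, `ToyModuleNew`, and `double` are hypothetical stand-ins, and only the return logic mirrors the dist_op.py hunk.

# Minimal sketch of the behavioural change, assuming hypothetical stand-ins
# for DistributedModule (real dispatch, dist_attr handling, etc. omitted).

class ToyModuleOld:
    """Pre-patch behaviour: a single output is boxed, so callers need [0]."""

    def __init__(self, serial_module):
        self._serial_module = serial_module

    def __call__(self, *args, **kwargs):
        output = self._serial_module(*args, **kwargs)
        # Old dist_op.py wrapped non-list outputs into a list before returning.
        if not isinstance(output, (list, tuple)):
            output = [output]
        return list(output)


class ToyModuleNew:
    """Post-patch behaviour: the wrapped module's return value is retained."""

    def __init__(self, serial_module):
        self._serial_module = serial_module

    def __call__(self, *args, **kwargs):
        # New dist_op.py passes the output through unchanged.
        return self._serial_module(*args, **kwargs)


def double(x):  # hypothetical stand-in for a layer's forward
    return 2 * x


assert ToyModuleOld(double)(21)[0] == 42   # old call sites had to index [0]
assert ToyModuleNew(double)(21) == 42      # new call sites drop the [0]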