diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py
index d48804b71fc3e0c146814a36e9367498efb1d484..b6a77b778885f552438fb654ab3c4a88ebd79a05 100644
--- a/python/paddle/distributed/auto_parallel/dist_op.py
+++ b/python/paddle/distributed/auto_parallel/dist_op.py
@@ -267,6 +267,4 @@ class DistributedModule:
         dist_op = DistributedOperator(op, self._dist_attr)
         dist_op.dist_attr.mark_annotated_as(self._dist_attr)
         default_dist_ctx.add_dist_op_for_program(dist_op)
-        if isinstance(output, Variable):
-            output = [output]
-        return list(output)
+        return output
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
index ec757c03478defee80b56322ee5f5f533a466934..9335d7d9d2e03f6674513dac9c341d8898486613 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
@@ -89,11 +89,11 @@ class MLPLayer(nn.Layer):
 
     def forward(self, input):
         out = auto.shard_op(self.norm, dist_attr={"process_mesh":
-                                                  PP_MESH_0})(input)[0]
+                                                  PP_MESH_0})(input)
         out = self.linear0(out)
         out = F.gelu(out, approximate=True)
         out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
-                                                     PP_MESH_1})(out)[0]
+                                                     PP_MESH_1})(out)
         out = self.dropout(out)
         out = self.linear2(out)
         self.out = out
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
index 4695f6a4a9425aeba1aecbeabebf738410965d9e..87c746ab5d3b506ba865904d15bf04ac0310f85d 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
@@ -391,7 +391,7 @@ class TransformerDecoder(nn.Layer):
                         mod,
                         dist_attr={
                             "process_mesh": PP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -405,7 +405,7 @@
                         mod,
                         dist_attr={
                             "process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -419,7 +419,7 @@
                         mod,
                         dist_attr={
                             "process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -433,7 +433,7 @@
                         mod,
                         dist_attr={
                             "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
-                        })(output, memory, tgt_mask, use_cache, cache)[0]
+                        })(output, memory, tgt_mask, use_cache, cache)
                     auto.shard_tensor(
                         output,
                         dist_attr={
@@ -456,7 +456,7 @@
                                 "process_mesh":
                                 PP_MESH_LIST[mod.mesh_idx]
                             })(output, memory, tgt_mask,
-                               use_cache, cache)[0]
+                               use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
@@ -471,7 +471,7 @@
                                 "process_mesh":
                                 DPPP_MESH_LIST[mod.mesh_idx]
                             })(output, memory, tgt_mask,
-                               use_cache, cache)[0]
+                               use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
@@ -486,7 +486,7 @@
                                 "process_mesh":
                                 MPPP_MESH_LIST[mod.mesh_idx]
                             })(output, memory, tgt_mask,
-                               use_cache, cache)[0]
+                               use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
@@ -500,7 +500,7 @@
                             mod,
                             dist_attr={
                                 "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
-                            })(output, memory, tgt_mask, use_cache, cache)[0]
+                            })(output, memory, tgt_mask, use_cache, cache)
                         auto.shard_tensor(
                             output,
                             dist_attr={
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
index 0e647a3db5b64bcc649002e8f618108b781d2d13..dfb314796a9ff2166fd6970c46ebec2455280a1c 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
@@ -255,12 +255,6 @@ class TestMLPReshard(unittest.TestCase):
                 "dims_mapping": [-1, -1]
             })
 
-        # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
-        #     x.name: [-1, -1],
-        #     w.name: [-1, -1]
-        # }, **{"x": x,
-        #       "y": w})[0]
-
         y = paddle.distributed.shard_op(paddle.matmul,
                                         dist_attr={
                                             "process_mesh": process_mesh,
@@ -270,7 +264,7 @@
                                             w: {
                                                 "dims_mapping": [-1, -1]
                                             }
-                                        })(x, w)[0]
+                                        })(x, w)
 
         rank_id = 0
         dist_context = DistributedContext()
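Note: a minimal sketch of what this patch means at a call site, mirroring the `paddle.distributed.shard_op` / `dist_attr` usage in the touched unit tests. The static-graph scaffolding and the `auto.ProcessMesh([0, 1])` mesh below are illustrative assumptions for the sketch, not part of the patch.

```python
# Illustrative sketch only, assuming a static-graph program as in
# test_auto_parallel_reshard_mppp.py; mesh and shapes are made up.
import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name="x", shape=[4, 8], dtype="float32")
    w = static.data(name="w", shape=[8, 8], dtype="float32")
    # Assumed 1-D mesh of two ranks, analogous to the PP_MESH_* meshes in the tests.
    process_mesh = auto.ProcessMesh([0, 1])

    # Before this patch, DistributedModule.__call__ wrapped a single Variable
    # in a list, so callers had to unpack it:
    #     y = paddle.distributed.shard_op(paddle.matmul, dist_attr=...)(x, w)[0]
    # After this patch, the wrapped op's output is returned unchanged, so the
    # trailing [0] is dropped at every call site:
    y = paddle.distributed.shard_op(paddle.matmul,
                                    dist_attr={
                                        "process_mesh": process_mesh,
                                        x: {"dims_mapping": [0, -1]},
                                        w: {"dims_mapping": [-1, -1]},
                                    })(x, w)
```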