Unverified commit 5be7a1ff authored by zhaoyingli, committed by GitHub

retain dist op returns (#44634)

Parent f49b0cb9
@@ -267,6 +267,4 @@ class DistributedModule:
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op)
- if isinstance(output, Variable):
-     output = [output]
- return list(output)
+ return output
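This first hunk carries the behavioral change named in the commit title: `DistributedModule.__call__` now returns the wrapped module's output as-is instead of wrapping a single `Variable` in a list, so the call sites in the tests below drop the trailing `[0]`. The following is a minimal, framework-free sketch of the old and new return handling; the class names and the list/tuple check are illustrative assumptions, not Paddle's actual implementation.

```python
# Minimal sketch of the return-value change in DistributedModule.__call__
# shown above. _ShardOpBefore/_ShardOpAfter and the list/tuple check are
# illustrative stand-ins, not Paddle's API.

class _ShardOpBefore:
    """Old behavior: a single output is wrapped in a list, so callers index [0]."""

    def __init__(self, module):
        self._module = module

    def __call__(self, *args, **kwargs):
        output = self._module(*args, **kwargs)
        if not isinstance(output, (list, tuple)):  # stand-in for the Variable check
            output = [output]
        return list(output)


class _ShardOpAfter:
    """New behavior: the wrapped module's return value is passed through unchanged."""

    def __init__(self, module):
        self._module = module

    def __call__(self, *args, **kwargs):
        return self._module(*args, **kwargs)


if __name__ == "__main__":
    double = lambda x: 2 * x
    assert _ShardOpBefore(double)(3) == [6]  # old call sites: out = wrapped(x)[0]
    assert _ShardOpAfter(double)(3) == 6     # new call sites: out = wrapped(x)
```

This is why every trailing `[0]` in the test files below disappears while the rest of each call stays the same.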
@@ -89,11 +89,11 @@ class MLPLayer(nn.Layer):
def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh":
- PP_MESH_0})(input)[0]
+ PP_MESH_0})(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
- PP_MESH_1})(out)[0]
+ PP_MESH_1})(out)
out = self.dropout(out)
out = self.linear2(out)
self.out = out
@@ -391,7 +391,7 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx]
- })(output, memory, tgt_mask, use_cache, cache)[0]
+ })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
@@ -405,7 +405,7 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
- })(output, memory, tgt_mask, use_cache, cache)[0]
+ })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
@@ -419,7 +419,7 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
- })(output, memory, tgt_mask, use_cache, cache)[0]
+ })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
@@ -433,7 +433,7 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
- })(output, memory, tgt_mask, use_cache, cache)[0]
+ })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
@@ -456,7 +456,7 @@ class TransformerDecoder(nn.Layer):
"process_mesh":
PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
- use_cache, cache)[0]
+ use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
@@ -471,7 +471,7 @@ class TransformerDecoder(nn.Layer):
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
- use_cache, cache)[0]
+ use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
@@ -486,7 +486,7 @@ class TransformerDecoder(nn.Layer):
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
- use_cache, cache)[0]
+ use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
@@ -500,7 +500,7 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
- })(output, memory, tgt_mask, use_cache, cache)[0]
+ })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
@@ -255,12 +255,6 @@ class TestMLPReshard(unittest.TestCase):
"dims_mapping": [-1, -1]
})
- # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
- # x.name: [-1, -1],
- # w.name: [-1, -1]
- # }, **{"x": x,
- # "y": w})[0]
y = paddle.distributed.shard_op(paddle.matmul,
dist_attr={
"process_mesh": process_mesh,
@@ -270,7 +264,7 @@ class TestMLPReshard(unittest.TestCase):
w: {
"dims_mapping": [-1, -1]
}
- })(x, w)[0]
+ })(x, w)
rank_id = 0
dist_context = DistributedContext()