未验证 提交 30b66f03 编写于 作者: Z zhaoyingli 提交者: GitHub

fix conflict (#44891)

上级 247002ec
...@@ -26,6 +26,7 @@ from .dist_attribute import get_op_dist_attr_field_keys ...@@ -26,6 +26,7 @@ from .dist_attribute import get_op_dist_attr_field_keys
class DistributedOperator: class DistributedOperator:
def __init__(self, serial_op, dist_attr=None): def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op self._serial_op = serial_op
self._serial_inputs = {} self._serial_inputs = {}
...@@ -248,6 +249,7 @@ class DistributedOperator: ...@@ -248,6 +249,7 @@ class DistributedOperator:
class DistributedModule: class DistributedModule:
def __init__(self, serial_module, dist_attr=None): def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module self._serial_module = serial_module
self._dist_attr = dist_attr self._dist_attr = dist_attr
...@@ -265,6 +267,4 @@ class DistributedModule: ...@@ -265,6 +267,4 @@ class DistributedModule:
dist_op = DistributedOperator(op, self._dist_attr) dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr) dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op) default_dist_ctx.add_dist_op_for_program(dist_op)
if isinstance(output, Variable): return output
output = [output]
return list(output)
...@@ -47,6 +47,7 @@ paddle.seed(44) ...@@ -47,6 +47,7 @@ paddle.seed(44)
class MyDataset(Dataset): class MyDataset(Dataset):
def __init__(self, num_samples): def __init__(self, num_samples):
super(MyDataset, self).__init__() super(MyDataset, self).__init__()
self.num_samples = num_samples self.num_samples = num_samples
...@@ -61,6 +62,7 @@ class MyDataset(Dataset): ...@@ -61,6 +62,7 @@ class MyDataset(Dataset):
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
def __init__(self, def __init__(self,
hidden_size=1024, hidden_size=1024,
intermediate_size=4 * 1024, intermediate_size=4 * 1024,
...@@ -69,39 +71,41 @@ class MLPLayer(nn.Layer): ...@@ -69,39 +71,41 @@ class MLPLayer(nn.Layer):
super(MLPLayer, self).__init__() super(MLPLayer, self).__init__()
d_model = hidden_size d_model = hidden_size
dim_feedforward = intermediate_size dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( weight_attr = paddle.ParamAttr(
mean=0.0, std=initializer_range)) initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
bias_attr = None bias_attr = None
self.linear0 = nn.Linear( self.linear0 = nn.Linear(d_model,
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) dim_feedforward,
self.linear1 = nn.Linear( weight_attr,
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) bias_attr=bias_attr)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input): def forward(self, input):
out = auto.shard_op( out = auto.shard_op(self.norm, dist_attr={"process_mesh":
self.norm, dist_attr={"process_mesh": PP_MESH_0})(input)[0] PP_MESH_0})(input)
out = self.linear0(input) out = self.linear0(out)
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
out = auto.shard_op( out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
self.linear1, dist_attr={"process_mesh": PP_MESH_1})(out)[0] PP_MESH_1})(out)
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
return out return out
def train(): def train():
mlp = MLPLayer( mlp = MLPLayer(hidden_size=hidden_size,
hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
dropout_ratio=0.1, dropout_ratio=0.1,
initializer_range=0.02) initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss() loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.fluid.optimizer.AdamOptimizer( optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
learning_rate=0.00001,
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-08, epsilon=1e-08,
...@@ -119,8 +123,7 @@ def train(): ...@@ -119,8 +123,7 @@ def train():
dist_strategy.semi_auto = True dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy) fleet.init(is_collective=True, strategy=dist_strategy)
engine = Engine( engine = Engine(mlp,
mlp,
inputs_spec=inputs_spec, inputs_spec=inputs_spec,
labels_spec=labels_spec, labels_spec=labels_spec,
strategy=dist_strategy) strategy=dist_strategy)
......
...@@ -76,26 +76,27 @@ class MultiHeadAttention(nn.Layer): ...@@ -76,26 +76,27 @@ class MultiHeadAttention(nn.Layer):
if self.fuse: if self.fuse:
assert self.kdim == embed_dim assert self.kdim == embed_dim
assert self.vdim == embed_dim assert self.vdim == embed_dim
self.qkv_proj = nn.Linear( self.qkv_proj = nn.Linear(embed_dim,
embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr) 3 * embed_dim,
weight_attr,
bias_attr=bias_attr)
else: else:
self.q_proj = nn.Linear( self.q_proj = nn.Linear(embed_dim,
embed_dim, embed_dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.k_proj = nn.Linear(self.kdim,
embed_dim, embed_dim,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr) bias_attr=bias_attr)
self.k_proj = nn.Linear( self.v_proj = nn.Linear(self.vdim,
self.kdim,
embed_dim, embed_dim,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr) bias_attr=bias_attr)
self.v_proj = nn.Linear( self.out_proj = nn.Linear(embed_dim,
self.vdim,
embed_dim, embed_dim,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr) bias_attr=bias_attr)
self.out_proj = nn.Linear(
embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias_attr)
def _fuse_prepare_qkv(self, query): def _fuse_prepare_qkv(self, query):
mix_layer = self.qkv_proj(query) mix_layer = self.qkv_proj(query)
...@@ -113,31 +114,28 @@ class MultiHeadAttention(nn.Layer): ...@@ -113,31 +114,28 @@ class MultiHeadAttention(nn.Layer):
""" """
q = self.q_proj(query) q = self.q_proj(query)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(self.q_proj.weight,
self.q_proj.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0] "dims_mapping": [-1, 0]
}) })
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(self.q_proj.weight,
self.q_proj.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1] "dims_mapping": [-1, 1]
}) })
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor( auto.shard_tensor(self.q_proj.weight,
self.q_proj.weight,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx], "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0] "dims_mapping": [-1, 0]
}) })
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(self.q_proj.weight,
self.q_proj.weight,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1] "dims_mapping": [-1, 1]
}) })
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
...@@ -167,60 +165,54 @@ class MultiHeadAttention(nn.Layer): ...@@ -167,60 +165,54 @@ class MultiHeadAttention(nn.Layer):
""" """
k = self.k_proj(key) k = self.k_proj(key)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(self.k_proj.weight,
self.k_proj.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0] "dims_mapping": [-1, 0]
}) })
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(self.k_proj.weight,
self.k_proj.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1] "dims_mapping": [-1, 1]
}) })
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor( auto.shard_tensor(self.k_proj.weight,
self.k_proj.weight,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx], "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0] "dims_mapping": [-1, 0]
}) })
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(self.k_proj.weight,
self.k_proj.weight,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1] "dims_mapping": [-1, 1]
}) })
v = self.v_proj(value) v = self.v_proj(value)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(self.v_proj.weight,
self.v_proj.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0] "dims_mapping": [-1, 0]
}) })
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(self.v_proj.weight,
self.v_proj.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1] "dims_mapping": [-1, 1]
}) })
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor( auto.shard_tensor(self.v_proj.weight,
self.v_proj.weight,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx], "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0] "dims_mapping": [-1, 0]
}) })
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(self.v_proj.weight,
self.v_proj.weight,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1] "dims_mapping": [-1, 1]
}) })
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
...@@ -276,14 +268,15 @@ class MultiHeadAttention(nn.Layer): ...@@ -276,14 +268,15 @@ class MultiHeadAttention(nn.Layer):
else: else:
q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
cache) cache)
product = layers.matmul( product = layers.matmul(x=q,
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) y=k,
transpose_y=True,
alpha=self.head_dim**-0.5)
if attn_mask is not None: if attn_mask is not None:
product = product + attn_mask product = product + attn_mask
weights = F.softmax(product) weights = F.softmax(product)
if self.dropout: if self.dropout:
weights = F.dropout( weights = F.dropout(weights,
weights,
self.dropout, self.dropout,
training=self.training, training=self.training,
mode="upscale_in_train") mode="upscale_in_train")
...@@ -294,31 +287,28 @@ class MultiHeadAttention(nn.Layer): ...@@ -294,31 +287,28 @@ class MultiHeadAttention(nn.Layer):
# project to output # project to output
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(self.out_proj.weight,
self.out_proj.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [0, -1] "dims_mapping": [0, -1]
}) })
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(self.out_proj.weight,
self.out_proj.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [1, -1] "dims_mapping": [1, -1]
}) })
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor( auto.shard_tensor(self.out_proj.weight,
self.out_proj.weight,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx], "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1] "dims_mapping": [0, -1]
}) })
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(self.out_proj.weight,
self.out_proj.weight,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1] "dims_mapping": [1, -1]
}) })
outs = [out] outs = [out]
...@@ -362,35 +352,36 @@ class TransformerDecoder(nn.Layer): ...@@ -362,35 +352,36 @@ class TransformerDecoder(nn.Layer):
new_caches = [] new_caches = []
self.checkpoints = [] self.checkpoints = []
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor( auto.shard_tensor(output,
output,
dist_attr={ dist_attr={
"process_mesh": PP_MESH_LIST[0], "process_mesh":
"dims_mapping": [-1 for i in range(len(output.shape))] PP_MESH_LIST[0],
"dims_mapping":
[-1 for i in range(len(output.shape))]
}) })
if _global_parallel_strategy == "dp_pp": if _global_parallel_strategy == "dp_pp":
auto.shard_tensor( auto.shard_tensor(output,
output,
dist_attr={ dist_attr={
"process_mesh": DPPP_MESH_LIST[0], "process_mesh":
"dims_mapping": DPPP_MESH_LIST[0],
[0] + [-1 for i in range(len(output.shape) - 1)] "dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
}) })
if _global_parallel_strategy == "mp_pp": if _global_parallel_strategy == "mp_pp":
auto.shard_tensor( auto.shard_tensor(output,
output,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[0], "process_mesh":
"dims_mapping": MPPP_MESH_LIST[0],
[-1] + [-1 for i in range(len(output.shape) - 1)] "dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
}) })
if _global_parallel_strategy == "dp_mp_pp": if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(output,
output,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[0], "process_mesh":
"dims_mapping": DPMPPP_MESH_LIST[0],
[0] + [-1 for i in range(len(output.shape) - 1)] "dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
}) })
for i, mod in enumerate(self.layers): for i, mod in enumerate(self.layers):
if cache is None: if cache is None:
...@@ -400,11 +391,12 @@ class TransformerDecoder(nn.Layer): ...@@ -400,11 +391,12 @@ class TransformerDecoder(nn.Layer):
mod, mod,
dist_attr={ dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx] "process_mesh": PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0] })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx], "process_mesh":
PP_MESH_LIST[mod.mesh_idx],
"dims_mapping": "dims_mapping":
[-1 for i in range(len(output.shape))] [-1 for i in range(len(output.shape))]
}) })
...@@ -413,11 +405,12 @@ class TransformerDecoder(nn.Layer): ...@@ -413,11 +405,12 @@ class TransformerDecoder(nn.Layer):
mod, mod,
dist_attr={ dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx] "process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0] })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] + "dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)] [-1 for i in range(len(output.shape) - 1)]
}) })
...@@ -426,11 +419,12 @@ class TransformerDecoder(nn.Layer): ...@@ -426,11 +419,12 @@ class TransformerDecoder(nn.Layer):
mod, mod,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx] "process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0] })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [-1] + "dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)] [-1 for i in range(len(output.shape) - 1)]
}) })
...@@ -439,11 +433,12 @@ class TransformerDecoder(nn.Layer): ...@@ -439,11 +433,12 @@ class TransformerDecoder(nn.Layer):
mod, mod,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0] })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] + "dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)] [-1 for i in range(len(output.shape) - 1)]
}) })
...@@ -456,41 +451,47 @@ class TransformerDecoder(nn.Layer): ...@@ -456,41 +451,47 @@ class TransformerDecoder(nn.Layer):
new_caches.append(new_cache) new_caches.append(new_cache)
else: else:
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
output = auto.shard_op( output = auto.shard_op(mod,
mod,
dist_attr={ dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx] "process_mesh":
})(output, memory, tgt_mask, use_cache, cache)[0] PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx], "process_mesh":
PP_MESH_LIST[mod.mesh_idx],
"dims_mapping": "dims_mapping":
[-1 for i in range(len(output.shape))] [-1 for i in range(len(output.shape))]
}) })
elif _global_parallel_strategy == "dp_pp": elif _global_parallel_strategy == "dp_pp":
output = auto.shard_op( output = auto.shard_op(mod,
mod,
dist_attr={ dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx] "process_mesh":
})(output, memory, tgt_mask, use_cache, cache)[0] DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] + "dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)] [-1 for i in range(len(output.shape) - 1)]
}) })
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
output = auto.shard_op( output = auto.shard_op(mod,
mod,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx] "process_mesh":
})(output, memory, tgt_mask, use_cache, cache)[0] MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [-1] + "dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)] [-1 for i in range(len(output.shape) - 1)]
}) })
...@@ -499,11 +500,12 @@ class TransformerDecoder(nn.Layer): ...@@ -499,11 +500,12 @@ class TransformerDecoder(nn.Layer):
mod, mod,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0] })(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] + "dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)] [-1 for i in range(len(output.shape) - 1)]
}) })
...@@ -517,8 +519,9 @@ class TransformerDecoder(nn.Layer): ...@@ -517,8 +519,9 @@ class TransformerDecoder(nn.Layer):
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
output, new_cache = auto.shard_op( output, new_cache = auto.shard_op(
mod, mod,
dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx]})( dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx]
output, memory, tgt_mask, use_cache, cache) })(output, memory, tgt_mask, use_cache,
cache)
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
...@@ -535,7 +538,8 @@ class TransformerDecoder(nn.Layer): ...@@ -535,7 +538,8 @@ class TransformerDecoder(nn.Layer):
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": "dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)] [0] + [-1 for i in range(len(output.shape) - 1)]
}) })
...@@ -548,7 +552,8 @@ class TransformerDecoder(nn.Layer): ...@@ -548,7 +552,8 @@ class TransformerDecoder(nn.Layer):
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": "dims_mapping":
[-1] + [-1 for i in range(len(output.shape) - 1)] [-1] + [-1 for i in range(len(output.shape) - 1)]
}) })
...@@ -561,7 +566,8 @@ class TransformerDecoder(nn.Layer): ...@@ -561,7 +566,8 @@ class TransformerDecoder(nn.Layer):
auto.shard_tensor( auto.shard_tensor(
output, output,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": "dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)] [0] + [-1 for i in range(len(output.shape) - 1)]
}) })
...@@ -619,17 +625,20 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -619,17 +625,20 @@ class TransformerDecoderLayer(nn.Layer):
self.normalize_before = normalize_before self.normalize_before = normalize_before
weight_attrs = _convert_param_attr_to_list(weight_attr, 3) weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3) bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
self.self_attn = MultiHeadAttention( self.self_attn = MultiHeadAttention(d_model,
d_model,
nhead, nhead,
dropout=attn_dropout, dropout=attn_dropout,
weight_attr=weight_attrs[0], weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0], bias_attr=bias_attrs[0],
mesh_idx=self.mesh_idx) mesh_idx=self.mesh_idx)
self.linear1 = nn.Linear( self.linear1 = nn.Linear(d_model,
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]) dim_feedforward,
self.linear2 = nn.Linear( weight_attrs[2],
dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2]) bias_attr=bias_attrs[2])
self.linear2 = nn.Linear(dim_feedforward,
d_model,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5) self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
...@@ -652,72 +661,65 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -652,72 +661,65 @@ class TransformerDecoderLayer(nn.Layer):
if self.normalize_before: if self.normalize_before:
tgt = self.norm2(tgt) tgt = self.norm2(tgt)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(self.linear1.weight,
self.linear1.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0] "dims_mapping": [-1, 0]
}) })
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(self.linear1.weight,
self.linear1.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1] "dims_mapping": [-1, 1]
}) })
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor( auto.shard_tensor(self.linear1.weight,
self.linear1.weight,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx], "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0] "dims_mapping": [-1, 0]
}) })
if _global_parallel_strategy == "dp_mp_pp": if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(self.linear1.weight,
self.linear1.weight,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1] "dims_mapping": [-1, 1]
}) })
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(self.linear2.weight,
self.linear2.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [0, -1] "dims_mapping": [0, -1]
}) })
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(self.linear2.weight,
self.linear2.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [1, -1] "dims_mapping": [1, -1]
}) })
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor( auto.shard_tensor(self.linear2.weight,
self.linear2.weight,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx], "process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1] "dims_mapping": [0, -1]
}) })
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(self.linear2.weight,
self.linear2.weight,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx], "process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1] "dims_mapping": [1, -1]
}) })
tgt = self.dropout2( tgt = self.dropout2(
self.linear2(F.gelu( self.linear2(F.gelu(self.linear1(tgt), approximate=True)))
self.linear1(tgt), approximate=True)))
tgt = residual + tgt tgt = residual + tgt
if not self.normalize_before: if not self.normalize_before:
tgt = self.norm2(tgt) tgt = self.norm2(tgt)
return tgt if use_cache is False else (tgt, incremental_cache) return tgt if use_cache is False else (tgt, incremental_cache)
def gen_cache(self, memory): def gen_cache(self, memory):
incremental_cache = self.self_attn.gen_cache( incremental_cache = self.self_attn.gen_cache(memory,
memory, type=self.self_attn.Cache) type=self.self_attn.Cache)
return incremental_cache return incremental_cache
...@@ -737,15 +739,13 @@ class GPTEmbeddings(nn.Layer): ...@@ -737,15 +739,13 @@ class GPTEmbeddings(nn.Layer):
self.word_embeddings = nn.Embedding( self.word_embeddings = nn.Embedding(
vocab_size, vocab_size,
hidden_size, hidden_size,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(name="word_embeddings",
name="word_embeddings",
initializer=nn.initializer.Normal( initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))) mean=0.0, std=initializer_range)))
self.position_embeddings = nn.Embedding( self.position_embeddings = nn.Embedding(
max_position_embeddings, max_position_embeddings,
hidden_size, hidden_size,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(name="pos_embeddings",
name="pos_embeddings",
initializer=nn.initializer.Normal( initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))) mean=0.0, std=initializer_range)))
self.dropout = nn.Dropout(hidden_dropout_prob) self.dropout = nn.Dropout(hidden_dropout_prob)
...@@ -757,29 +757,25 @@ class GPTEmbeddings(nn.Layer): ...@@ -757,29 +757,25 @@ class GPTEmbeddings(nn.Layer):
position_ids = seq_length - ones position_ids = seq_length - ones
input_embedings = self.word_embeddings(input_ids) input_embedings = self.word_embeddings(input_ids)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(self.word_embeddings.weight,
self.word_embeddings.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [0, -1] "dims_mapping": [0, -1]
}) })
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(self.word_embeddings.weight,
self.word_embeddings.weight,
dist_attr={ dist_attr={
"process_mesh": _global_process_mesh, "process_mesh": _global_process_mesh,
"dims_mapping": [1, -1] "dims_mapping": [1, -1]
}) })
elif _global_parallel_strategy == "mp_pp": elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor( auto.shard_tensor(self.word_embeddings.weight,
self.word_embeddings.weight,
dist_attr={ dist_attr={
"process_mesh": MPPP_MESH_LIST[0], "process_mesh": MPPP_MESH_LIST[0],
"dims_mapping": [0, -1] "dims_mapping": [0, -1]
}) })
elif _global_parallel_strategy == "dp_mp_pp": elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(self.word_embeddings.weight,
self.word_embeddings.weight,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[0], "process_mesh": DPMPPP_MESH_LIST[0],
"dims_mapping": [1, -1] "dims_mapping": [1, -1]
...@@ -821,9 +817,10 @@ class GPTModel(nn.Layer): ...@@ -821,9 +817,10 @@ class GPTModel(nn.Layer):
self.pipline_mode = (pp_degree is not None and pp_degree > 1) self.pipline_mode = (pp_degree is not None and pp_degree > 1)
if self.pipline_mode: if self.pipline_mode:
self.layer_per_stage = num_hidden_layers // pp_degree self.layer_per_stage = num_hidden_layers // pp_degree
self.embeddings = GPTEmbeddings( self.embeddings = GPTEmbeddings(vocab_size, hidden_size,
vocab_size, hidden_size, hidden_dropout_prob, hidden_dropout_prob,
max_position_embeddings, type_vocab_size, self.initializer_range) max_position_embeddings,
type_vocab_size, self.initializer_range)
decoder_layers = nn.LayerList() decoder_layers = nn.LayerList()
for i in range(num_hidden_layers): for i in range(num_hidden_layers):
mesh_index = None mesh_index = None
...@@ -831,8 +828,7 @@ class GPTModel(nn.Layer): ...@@ -831,8 +828,7 @@ class GPTModel(nn.Layer):
if self.layer_per_stage is not None: if self.layer_per_stage is not None:
mesh_index = i // self.layer_per_stage mesh_index = i // self.layer_per_stage
decoder_layers.append( decoder_layers.append(
DecoderLayer( DecoderLayer(d_model=hidden_size,
d_model=hidden_size,
nhead=num_attention_heads, nhead=num_attention_heads,
dim_feedforward=intermediate_size, dim_feedforward=intermediate_size,
dropout=hidden_dropout_prob, dropout=hidden_dropout_prob,
...@@ -845,8 +841,7 @@ class GPTModel(nn.Layer): ...@@ -845,8 +841,7 @@ class GPTModel(nn.Layer):
bias_attr=None, bias_attr=None,
mesh_idx=mesh_index)) mesh_idx=mesh_index))
Decoder = TransformerDecoder Decoder = TransformerDecoder
self.decoder = Decoder( self.decoder = Decoder(decoder_layers,
decoder_layers,
num_hidden_layers, num_hidden_layers,
norm="LayerNorm", norm="LayerNorm",
hidden_size=hidden_size) hidden_size=hidden_size)
...@@ -863,40 +858,40 @@ class GPTModel(nn.Layer): ...@@ -863,40 +858,40 @@ class GPTModel(nn.Layer):
past_length = 0 past_length = 0
if cache is not None: if cache is not None:
past_length = paddle.shape(cache[0].k)[-2] past_length = paddle.shape(cache[0].k)[-2]
position_ids = paddle.arange( position_ids = paddle.arange(past_length,
paddle.shape(input_ids)[-1] +
past_length, past_length,
paddle.shape(input_ids)[-1] + past_length,
dtype='int64') dtype='int64')
position_ids = position_ids.unsqueeze(0) position_ids = position_ids.unsqueeze(0)
position_ids = paddle.fluid.layers.expand_as(position_ids, position_ids = paddle.fluid.layers.expand_as(
input_ids) position_ids, input_ids)
embedding_output = self.embeddings( embedding_output = self.embeddings(input_ids=input_ids,
input_ids=input_ids, position_ids=position_ids) position_ids=position_ids)
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor( auto.shard_tensor(input_ids,
input_ids,
dist_attr={ dist_attr={
"process_mesh": PP_MESH_LIST[0], "process_mesh":
"dims_mapping": [-1 for i in range(len(input_ids.shape))] PP_MESH_LIST[0],
"dims_mapping":
[-1 for i in range(len(input_ids.shape))]
}) })
if _global_parallel_strategy == "dp_pp": if _global_parallel_strategy == "dp_pp":
auto.shard_tensor( auto.shard_tensor(input_ids,
input_ids,
dist_attr={ dist_attr={
"process_mesh": DPPP_MESH_LIST[0], "process_mesh":
"dims_mapping": DPPP_MESH_LIST[0],
[0] + [-1 for i in range(len(input_ids.shape) - 1)] "dims_mapping": [0] +
[-1 for i in range(len(input_ids.shape) - 1)]
}) })
if _global_parallel_strategy == "dp_mp_pp": if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor( auto.shard_tensor(input_ids,
input_ids,
dist_attr={ dist_attr={
"process_mesh": DPMPPP_MESH_LIST[0], "process_mesh":
"dims_mapping": DPMPPP_MESH_LIST[0],
[0] + [-1 for i in range(len(input_ids.shape) - 1)] "dims_mapping": [0] +
[-1 for i in range(len(input_ids.shape) - 1)]
}) })
encoder_outputs = self.decoder( encoder_outputs = self.decoder(embedding_output,
embedding_output,
memory=None, memory=None,
tgt_mask=attention_mask, tgt_mask=attention_mask,
use_cache=use_cache, use_cache=use_cache,
...@@ -916,13 +911,13 @@ class GPTForPretraining(nn.Layer): ...@@ -916,13 +911,13 @@ class GPTForPretraining(nn.Layer):
gpt, gpt,
vocab_size=50304, vocab_size=50304,
hidden_size=768, hidden_size=768,
initializer_range=0.02, ): initializer_range=0.02,
):
super(GPTForPretraining, self).__init__() super(GPTForPretraining, self).__init__()
self.output_embeddings = nn.Embedding( self.output_embeddings = nn.Embedding(
vocab_size, vocab_size,
hidden_size, hidden_size,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(name="output_embeddings",
name="output_embeddings",
initializer=nn.initializer.Normal( initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))) mean=0.0, std=initializer_range)))
self.gpt = gpt self.gpt = gpt
...@@ -943,8 +938,9 @@ class GPTForPretraining(nn.Layer): ...@@ -943,8 +938,9 @@ class GPTForPretraining(nn.Layer):
encoder_outputs, cached_kvs = outputs[:2] encoder_outputs, cached_kvs = outputs[:2]
else: else:
encoder_outputs = outputs encoder_outputs = outputs
logits = paddle.matmul( logits = paddle.matmul(encoder_outputs,
encoder_outputs, self.output_embeddings.weight, transpose_y=True) self.output_embeddings.weight,
transpose_y=True)
if use_cache: if use_cache:
return logits, cached_kvs return logits, cached_kvs
else: else:
......
...@@ -38,6 +38,7 @@ PP_MESH_1 = auto.ProcessMesh([2, 3]) ...@@ -38,6 +38,7 @@ PP_MESH_1 = auto.ProcessMesh([2, 3])
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
def __init__(self, def __init__(self,
hidden_size=1024, hidden_size=1024,
intermediate_size=4 * 1024, intermediate_size=4 * 1024,
...@@ -45,42 +46,51 @@ class MLPLayer(nn.Layer): ...@@ -45,42 +46,51 @@ class MLPLayer(nn.Layer):
super(MLPLayer, self).__init__() super(MLPLayer, self).__init__()
d_model = hidden_size d_model = hidden_size
dim_feedforward = intermediate_size dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( weight_attr = paddle.ParamAttr(
mean=0.0, std=initializer_range)) initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
bias_attr = None bias_attr = None
self.word_embeddings = nn.Embedding( self.word_embeddings = nn.Embedding(
hidden_size, hidden_size,
hidden_size, hidden_size,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(name="word_embeddings",
name="word_embeddings",
initializer=nn.initializer.Normal( initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))) mean=0.0, std=initializer_range)))
self.linear0 = nn.Linear( self.linear0 = nn.Linear(d_model,
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) dim_feedforward,
self.linear1 = nn.Linear( weight_attr,
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) bias_attr=bias_attr)
self.linear2 = nn.Linear( self.linear1 = nn.Linear(dim_feedforward,
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) d_model,
weight_attr,
bias_attr=bias_attr)
self.linear2 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
def forward(self, input): def forward(self, input):
auto.shard_tensor( auto.shard_tensor(self.word_embeddings.weight,
self.word_embeddings.weight, dist_attr={
dist_attr={"process_mesh": PP_MESH_0, "process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]}) "dims_mapping": [0, -1]
auto.shard_tensor( })
self.linear0.weight, auto.shard_tensor(self.linear0.weight,
dist_attr={"process_mesh": PP_MESH_0, dist_attr={
"dims_mapping": [-1, 0]}) "process_mesh": PP_MESH_0,
auto.shard_tensor( "dims_mapping": [-1, 0]
self.linear1.weight, })
dist_attr={"process_mesh": PP_MESH_1, auto.shard_tensor(self.linear1.weight,
"dims_mapping": [0, -1]}) dist_attr={
auto.shard_tensor( "process_mesh": PP_MESH_1,
self.linear2.weight, "dims_mapping": [0, -1]
dist_attr={"process_mesh": PP_MESH_1, })
"dims_mapping": [0, -1]}) auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
w_out = self.word_embeddings(input) w_out = self.word_embeddings(input)
out = self.linear0(w_out) out = self.linear0(w_out)
gelu_out = F.gelu(out, approximate=True) gelu_out = F.gelu(out, approximate=True)
...@@ -98,19 +108,22 @@ def mlp_forward(train_program, start_program): ...@@ -98,19 +108,22 @@ def mlp_forward(train_program, start_program):
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
input = static.data(name="input", shape=[batch_size], dtype='int32') input = static.data(name="input", shape=[batch_size], dtype='int32')
label = static.data( label = static.data(name="label",
name="label", shape=[batch_size, 1], dtype='float32') shape=[batch_size, 1],
dtype='float32')
auto.shard_tensor(
input, dist_attr={"process_mesh": PP_MESH_0, auto.shard_tensor(input,
"dims_mapping": [-1]}) dist_attr={
auto.shard_tensor( "process_mesh": PP_MESH_0,
label, "dims_mapping": [-1]
dist_attr={"process_mesh": PP_MESH_1, })
"dims_mapping": [-1, -1]}) auto.shard_tensor(label,
dist_attr={
mlp = MLPLayer( "process_mesh": PP_MESH_1,
hidden_size=hidden_size, "dims_mapping": [-1, -1]
})
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
initializer_range=0.02) initializer_range=0.02)
...@@ -137,8 +150,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): ...@@ -137,8 +150,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
complete_train_program = completer.complete_forward_annotation( complete_train_program = completer.complete_forward_annotation(
train_program) train_program)
dist_context.block_state.parse_forward_blocks(complete_train_program) dist_context.block_state.parse_forward_blocks(complete_train_program)
params_grads = parallelizer._generate_backward( params_grads = parallelizer._generate_backward(complete_train_program,
complete_train_program,
startup_program, startup_program,
loss, loss,
parameter_list=None, parameter_list=None,
...@@ -171,8 +183,7 @@ def check_send_recv_result(dist_main_prog, rank_id): ...@@ -171,8 +183,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names[ if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names[
0]: 0]:
send_result = True send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[ if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[0]:
0]:
recv_result = True recv_result = True
return send_result and recv_result return send_result and recv_result
...@@ -206,6 +217,7 @@ def check_allgather(dist_main_program): ...@@ -206,6 +217,7 @@ def check_allgather(dist_main_program):
class TestMLPReshard(unittest.TestCase): class TestMLPReshard(unittest.TestCase):
def test_mlp_mppp(self): def test_mlp_mppp(self):
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
...@@ -230,29 +242,20 @@ class TestMLPReshard(unittest.TestCase): ...@@ -230,29 +242,20 @@ class TestMLPReshard(unittest.TestCase):
process_mesh = auto.ProcessMesh(mesh=[0, 3]) process_mesh = auto.ProcessMesh(mesh=[0, 3])
with static.program_guard(train_program, startup_program): with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32') x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor( x = auto.shard_tensor(x,
x,
dist_attr={ dist_attr={
"process_mesh": process_mesh, "process_mesh": process_mesh,
"dims_mapping": [0, -1] "dims_mapping": [0, -1]
}) })
w = paddle.static.data(name="w", shape=[4, 4], dtype='float32') w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
w = auto.shard_tensor( w = auto.shard_tensor(w,
w,
dist_attr={ dist_attr={
"process_mesh": process_mesh, "process_mesh": process_mesh,
"dims_mapping": [-1, -1] "dims_mapping": [-1, -1]
}) })
# y = paddle.distributed.shard_op(paddle.matmul, process_mesh, { y = paddle.distributed.shard_op(paddle.matmul,
# x.name: [-1, -1],
# w.name: [-1, -1]
# }, **{"x": x,
# "y": w})[0]
y = paddle.distributed.shard_op(
paddle.matmul,
dist_attr={ dist_attr={
"process_mesh": process_mesh, "process_mesh": process_mesh,
x: { x: {
...@@ -261,7 +264,7 @@ class TestMLPReshard(unittest.TestCase): ...@@ -261,7 +264,7 @@ class TestMLPReshard(unittest.TestCase):
w: { w: {
"dims_mapping": [-1, -1] "dims_mapping": [-1, -1]
} }
})(x, w)[0] })(x, w)
rank_id = 0 rank_id = 0
dist_context = DistributedContext() dist_context = DistributedContext()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册