未验证 提交 30b66f03 编写于 作者: Z zhaoyingli 提交者: GitHub

fix conflict (#44891)

上级 247002ec
......@@ -26,6 +26,7 @@ from .dist_attribute import get_op_dist_attr_field_keys
class DistributedOperator:
def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op
self._serial_inputs = {}
......@@ -248,6 +249,7 @@ class DistributedOperator:
class DistributedModule:
def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module
self._dist_attr = dist_attr
......@@ -265,6 +267,4 @@ class DistributedModule:
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op)
if isinstance(output, Variable):
output = [output]
return list(output)
return output
......@@ -47,6 +47,7 @@ paddle.seed(44)
class MyDataset(Dataset):
def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples
......@@ -61,6 +62,7 @@ class MyDataset(Dataset):
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
......@@ -69,43 +71,45 @@ class MLPLayer(nn.Layer):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.linear0 = nn.Linear(d_model,
dim_feedforward,
weight_attr,
bias_attr=bias_attr)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = auto.shard_op(
self.norm, dist_attr={"process_mesh": PP_MESH_0})(input)[0]
out = self.linear0(input)
out = auto.shard_op(self.norm, dist_attr={"process_mesh":
PP_MESH_0})(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = auto.shard_op(
self.linear1, dist_attr={"process_mesh": PP_MESH_1})(out)[0]
out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
PP_MESH_1})(out)
out = self.dropout(out)
out = self.linear2(out)
return out
def train():
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
dataset = MyDataset(batch_num * batch_size)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
......@@ -119,11 +123,10 @@ def train():
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
engine = Engine(
mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
engine.prepare(optimizer, loss)
engine.fit(dataset,
batch_size=batch_size,
......
......@@ -76,26 +76,27 @@ class MultiHeadAttention(nn.Layer):
if self.fuse:
assert self.kdim == embed_dim
assert self.vdim == embed_dim
self.qkv_proj = nn.Linear(
embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr)
self.qkv_proj = nn.Linear(embed_dim,
3 * embed_dim,
weight_attr,
bias_attr=bias_attr)
else:
self.q_proj = nn.Linear(
embed_dim,
embed_dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim,
embed_dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim,
embed_dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.out_proj = nn.Linear(
embed_dim, embed_dim, weight_attr=weight_attr, bias_attr=bias_attr)
self.q_proj = nn.Linear(embed_dim,
embed_dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.k_proj = nn.Linear(self.kdim,
embed_dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.v_proj = nn.Linear(self.vdim,
embed_dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.out_proj = nn.Linear(embed_dim,
embed_dim,
weight_attr=weight_attr,
bias_attr=bias_attr)
def _fuse_prepare_qkv(self, query):
mix_layer = self.qkv_proj(query)
......@@ -113,33 +114,30 @@ class MultiHeadAttention(nn.Layer):
"""
q = self.q_proj(query)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
self.q_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
self.q_proj.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
if isinstance(cache, self.StaticCache):
......@@ -167,62 +165,56 @@ class MultiHeadAttention(nn.Layer):
"""
k = self.k_proj(key)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
self.k_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
self.k_proj.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
v = self.v_proj(value)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
self.v_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
self.v_proj.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
......@@ -276,17 +268,18 @@ class MultiHeadAttention(nn.Layer):
else:
q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
cache)
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
product = layers.matmul(x=q,
y=k,
transpose_y=True,
alpha=self.head_dim**-0.5)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")
weights = F.dropout(weights,
self.dropout,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
......@@ -294,33 +287,30 @@ class MultiHeadAttention(nn.Layer):
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
self.out_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
self.out_proj.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1]
})
outs = [out]
if self.need_weights:
outs.append(weights)
......@@ -362,36 +352,37 @@ class TransformerDecoder(nn.Layer):
new_caches = []
self.checkpoints = []
if _global_parallel_strategy == "pp":
auto.shard_tensor(
output,
dist_attr={
"process_mesh": PP_MESH_LIST[0],
"dims_mapping": [-1 for i in range(len(output.shape))]
})
auto.shard_tensor(output,
dist_attr={
"process_mesh":
PP_MESH_LIST[0],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
if _global_parallel_strategy == "dp_pp":
auto.shard_tensor(
output,
dist_attr={
"process_mesh": DPPP_MESH_LIST[0],
"dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)]
})
auto.shard_tensor(output,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
if _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
output,
dist_attr={
"process_mesh": MPPP_MESH_LIST[0],
"dims_mapping":
[-1] + [-1 for i in range(len(output.shape) - 1)]
})
auto.shard_tensor(output,
dist_attr={
"process_mesh":
MPPP_MESH_LIST[0],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
output,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[0],
"dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)]
})
auto.shard_tensor(output,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
for i, mod in enumerate(self.layers):
if cache is None:
if use_cache:
......@@ -400,11 +391,12 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0]
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx],
"process_mesh":
PP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
......@@ -413,11 +405,12 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0]
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
......@@ -426,11 +419,12 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0]
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
......@@ -439,11 +433,12 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0]
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
......@@ -456,41 +451,47 @@ class TransformerDecoder(nn.Layer):
new_caches.append(new_cache)
else:
if _global_parallel_strategy == "pp":
output = auto.shard_op(
mod,
dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0]
output = auto.shard_op(mod,
dist_attr={
"process_mesh":
PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx],
"process_mesh":
PP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
elif _global_parallel_strategy == "dp_pp":
output = auto.shard_op(
mod,
dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0]
output = auto.shard_op(mod,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
elif _global_parallel_strategy == "mp_pp":
output = auto.shard_op(
mod,
dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0]
output = auto.shard_op(mod,
dist_attr={
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
......@@ -499,11 +500,12 @@ class TransformerDecoder(nn.Layer):
mod,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)[0]
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
......@@ -517,8 +519,9 @@ class TransformerDecoder(nn.Layer):
if _global_parallel_strategy == "pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx]})(
output, memory, tgt_mask, use_cache, cache)
dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache,
cache)
auto.shard_tensor(
output,
dist_attr={
......@@ -535,7 +538,8 @@ class TransformerDecoder(nn.Layer):
auto.shard_tensor(
output,
dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)]
})
......@@ -548,7 +552,8 @@ class TransformerDecoder(nn.Layer):
auto.shard_tensor(
output,
dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1] + [-1 for i in range(len(output.shape) - 1)]
})
......@@ -561,7 +566,8 @@ class TransformerDecoder(nn.Layer):
auto.shard_tensor(
output,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx],
"process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)]
})
......@@ -619,17 +625,20 @@ class TransformerDecoderLayer(nn.Layer):
self.normalize_before = normalize_before
weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
self.self_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
mesh_idx=self.mesh_idx)
self.linear1 = nn.Linear(
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
self.linear2 = nn.Linear(
dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2])
self.self_attn = MultiHeadAttention(d_model,
nhead,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
mesh_idx=self.mesh_idx)
self.linear1 = nn.Linear(d_model,
dim_feedforward,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.linear2 = nn.Linear(dim_feedforward,
d_model,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
......@@ -652,72 +661,65 @@ class TransformerDecoderLayer(nn.Layer):
if self.normalize_before:
tgt = self.norm2(tgt)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
self.linear2.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
self.linear2.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1]
})
tgt = self.dropout2(
self.linear2(F.gelu(
self.linear1(tgt), approximate=True)))
self.linear2(F.gelu(self.linear1(tgt), approximate=True)))
tgt = residual + tgt
if not self.normalize_before:
tgt = self.norm2(tgt)
return tgt if use_cache is False else (tgt, incremental_cache)
def gen_cache(self, memory):
incremental_cache = self.self_attn.gen_cache(
memory, type=self.self_attn.Cache)
incremental_cache = self.self_attn.gen_cache(memory,
type=self.self_attn.Cache)
return incremental_cache
......@@ -737,17 +739,15 @@ class GPTEmbeddings(nn.Layer):
self.word_embeddings = nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
weight_attr=paddle.ParamAttr(name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.position_embeddings = nn.Embedding(
max_position_embeddings,
hidden_size,
weight_attr=paddle.ParamAttr(
name="pos_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
weight_attr=paddle.ParamAttr(name="pos_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, input_ids, position_ids=None):
......@@ -757,33 +757,29 @@ class GPTEmbeddings(nn.Layer):
position_ids = seq_length - ones
input_embedings = self.word_embeddings(input_ids)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(
self.word_embeddings.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[0],
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[0],
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
self.word_embeddings.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[0],
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[0],
"dims_mapping": [1, -1]
})
position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings
embeddings = self.dropout(embeddings)
......@@ -821,9 +817,10 @@ class GPTModel(nn.Layer):
self.pipline_mode = (pp_degree is not None and pp_degree > 1)
if self.pipline_mode:
self.layer_per_stage = num_hidden_layers // pp_degree
self.embeddings = GPTEmbeddings(
vocab_size, hidden_size, hidden_dropout_prob,
max_position_embeddings, type_vocab_size, self.initializer_range)
self.embeddings = GPTEmbeddings(vocab_size, hidden_size,
hidden_dropout_prob,
max_position_embeddings,
type_vocab_size, self.initializer_range)
decoder_layers = nn.LayerList()
for i in range(num_hidden_layers):
mesh_index = None
......@@ -831,25 +828,23 @@ class GPTModel(nn.Layer):
if self.layer_per_stage is not None:
mesh_index = i // self.layer_per_stage
decoder_layers.append(
DecoderLayer(
d_model=hidden_size,
nhead=num_attention_heads,
dim_feedforward=intermediate_size,
dropout=hidden_dropout_prob,
activation=hidden_act,
attn_dropout=attention_probs_dropout_prob,
act_dropout=hidden_dropout_prob,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)),
bias_attr=None,
mesh_idx=mesh_index))
DecoderLayer(d_model=hidden_size,
nhead=num_attention_heads,
dim_feedforward=intermediate_size,
dropout=hidden_dropout_prob,
activation=hidden_act,
attn_dropout=attention_probs_dropout_prob,
act_dropout=hidden_dropout_prob,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)),
bias_attr=None,
mesh_idx=mesh_index))
Decoder = TransformerDecoder
self.decoder = Decoder(
decoder_layers,
num_hidden_layers,
norm="LayerNorm",
hidden_size=hidden_size)
self.decoder = Decoder(decoder_layers,
num_hidden_layers,
norm="LayerNorm",
hidden_size=hidden_size)
self.checkpoints = []
def forward(self,
......@@ -863,44 +858,44 @@ class GPTModel(nn.Layer):
past_length = 0
if cache is not None:
past_length = paddle.shape(cache[0].k)[-2]
position_ids = paddle.arange(
past_length,
paddle.shape(input_ids)[-1] + past_length,
dtype='int64')
position_ids = paddle.arange(past_length,
paddle.shape(input_ids)[-1] +
past_length,
dtype='int64')
position_ids = position_ids.unsqueeze(0)
position_ids = paddle.fluid.layers.expand_as(position_ids,
input_ids)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids)
position_ids = paddle.fluid.layers.expand_as(
position_ids, input_ids)
embedding_output = self.embeddings(input_ids=input_ids,
position_ids=position_ids)
if _global_parallel_strategy == "pp":
auto.shard_tensor(
input_ids,
dist_attr={
"process_mesh": PP_MESH_LIST[0],
"dims_mapping": [-1 for i in range(len(input_ids.shape))]
})
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh":
PP_MESH_LIST[0],
"dims_mapping":
[-1 for i in range(len(input_ids.shape))]
})
if _global_parallel_strategy == "dp_pp":
auto.shard_tensor(
input_ids,
dist_attr={
"process_mesh": DPPP_MESH_LIST[0],
"dims_mapping":
[0] + [-1 for i in range(len(input_ids.shape) - 1)]
})
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(input_ids.shape) - 1)]
})
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
input_ids,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[0],
"dims_mapping":
[0] + [-1 for i in range(len(input_ids.shape) - 1)]
})
encoder_outputs = self.decoder(
embedding_output,
memory=None,
tgt_mask=attention_mask,
use_cache=use_cache,
cache=cache)
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(input_ids.shape) - 1)]
})
encoder_outputs = self.decoder(embedding_output,
memory=None,
tgt_mask=attention_mask,
use_cache=use_cache,
cache=cache)
self.checkpoints.extend(self.decoder.checkpoints)
return encoder_outputs
......@@ -912,19 +907,19 @@ class GPTForPretraining(nn.Layer):
"""
def __init__(
self,
gpt,
vocab_size=50304,
hidden_size=768,
initializer_range=0.02, ):
self,
gpt,
vocab_size=50304,
hidden_size=768,
initializer_range=0.02,
):
super(GPTForPretraining, self).__init__()
self.output_embeddings = nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.ParamAttr(
name="output_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
weight_attr=paddle.ParamAttr(name="output_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.gpt = gpt
def forward(self,
......@@ -943,8 +938,9 @@ class GPTForPretraining(nn.Layer):
encoder_outputs, cached_kvs = outputs[:2]
else:
encoder_outputs = outputs
logits = paddle.matmul(
encoder_outputs, self.output_embeddings.weight, transpose_y=True)
logits = paddle.matmul(encoder_outputs,
self.output_embeddings.weight,
transpose_y=True)
if use_cache:
return logits, cached_kvs
else:
......
......@@ -38,6 +38,7 @@ PP_MESH_1 = auto.ProcessMesh([2, 3])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
......@@ -45,42 +46,51 @@ class MLPLayer(nn.Layer):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
bias_attr = None
self.word_embeddings = nn.Embedding(
hidden_size,
hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.linear2 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
weight_attr=paddle.ParamAttr(name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.linear0 = nn.Linear(d_model,
dim_feedforward,
weight_attr,
bias_attr=bias_attr)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
self.linear2 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
def forward(self, input):
auto.shard_tensor(
self.word_embeddings.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]})
auto.shard_tensor(
self.linear0.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 0]})
auto.shard_tensor(
self.linear1.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
auto.shard_tensor(
self.linear2.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
w_out = self.word_embeddings(input)
out = self.linear0(w_out)
gelu_out = F.gelu(out, approximate=True)
......@@ -98,21 +108,24 @@ def mlp_forward(train_program, start_program):
hidden_size = 1024
sequence_len = 512
input = static.data(name="input", shape=[batch_size], dtype='int32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
auto.shard_tensor(
input, dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1]})
auto.shard_tensor(
label,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]})
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
label = static.data(name="label",
shape=[batch_size, 1],
dtype='float32')
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
......@@ -137,13 +150,12 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
complete_train_program = completer.complete_forward_annotation(
train_program)
dist_context.block_state.parse_forward_blocks(complete_train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
params_grads = parallelizer._generate_backward(complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
# logical partition
partitioner = Partitioner(dist_context, rank_id)
......@@ -171,8 +183,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names[
0]:
send_result = True
if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[
0]:
if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[0]:
recv_result = True
return send_result and recv_result
......@@ -206,6 +217,7 @@ def check_allgather(dist_main_program):
class TestMLPReshard(unittest.TestCase):
def test_mlp_mppp(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......@@ -230,38 +242,29 @@ class TestMLPReshard(unittest.TestCase):
process_mesh = auto.ProcessMesh(mesh=[0, 3])
with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor(
x,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [0, -1]
})
x = auto.shard_tensor(x,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [0, -1]
})
w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
w = auto.shard_tensor(
w,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [-1, -1]
})
# y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
# x.name: [-1, -1],
# w.name: [-1, -1]
# }, **{"x": x,
# "y": w})[0]
y = paddle.distributed.shard_op(
paddle.matmul,
dist_attr={
"process_mesh": process_mesh,
x: {
"dims_mapping": [-1, -1]
},
w: {
"dims_mapping": [-1, -1]
}
})(x, w)[0]
w = auto.shard_tensor(w,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [-1, -1]
})
y = paddle.distributed.shard_op(paddle.matmul,
dist_attr={
"process_mesh": process_mesh,
x: {
"dims_mapping": [-1, -1]
},
w: {
"dims_mapping": [-1, -1]
}
})(x, w)
rank_id = 0
dist_context = DistributedContext()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册